这篇简单的总结一下我们目前达成的里程碑、我们已经实现的功能和feature,同时也总结一下遇到的主要问题,提醒自己、也帮助后来人。
里程碑记录
注:有AI辅助排版。
开发时间线
从第一行代码到单卡收尾,总计11 天(3月21日 — 3月31日)。
Day 1–2(3/21–3/22):从零搭建
从sampler.py开始,逐层手写Qwen2.5-1.5B的所有算子:RMSNorm、RoPE(复数形式)、GQA Attention、SwiGLU FFN,每个模块写完立刻和 HuggingFace逐tensor做torch.allclose数值对齐。完成权重加载、Engine、Sampler,在 GPU 上跑通第一次对话输出。
v0.1-naive-no-cache — 21tok/s,无 KV Cache的baseline。
Day 3(3/23):KV Cache + Continuous Batching
实现NaiveKVCache,验证KV Cache在长序列下的加速效果(seq_len=3200时从1.8 tok/s提升到15.3tok/s,8.5x)。紧接着实现Scheduler和Continuous Batching,请求动态进出batch。
v0.2-naive-kvcache — 第一次 profiling,建立各算子耗时的数据基线。
v0.3-continuous-batching — Scheduler + Engine多请求调度跑通。
Day 4–5(3/24–3/25):PagedAttention 数据结构 + 模型重构
实现BlockManager(全局block pool,分配/回收)和PagedKVCache(block_table + 逻辑→物理映射)。GQA Attention完全重写,拆分prefill和decode两条路径。RoPE 从复数形式重构为查表式实数实现。端到端数值验证通过。
v0.4-paged-attention — PagedAttention gather实现,性能和NaiveKVCache持平(这是因为GEMM主导,gather开销被掩盖)。第二次profiling。
Day 5–7(3/25–3/27):Triton Kernel 密集期
手写两个Triton kernel:store_kvcache_kernel(2D grid,token × head)和Decode_Paged_GQAAttention_Kernel(online softmax + block table paged access)。完成QKV Fused Projection和SwiGLU gate_up Fused。将model.forward重构为纯tensor接口,所有Python对象操作移出forward,为CUDA Graph和torch.compile扫清障碍。
v0.5-triton-paged-attention — 23.2 tok/s。第三次profiling(nsys),首次量化 CPU launch overhead(851 kernels × 15μs = 13.5ms/步)。
v0.6-slot-mapping-store-kvcache — 28.4 tok/s,D2H memcpy 从560次降至0次。
Day 8–9(3/28–3/29):CUDA Graph
先做实验验证CUDA Graph的收益预期(30 kernel串联实测6.62x加速),然后实现完整的静态buffer预分配 + capture/replay机制。Decode路径走CUDA Graph,Prefill保持eager。
v0.7-cuda-graph — 87.3 tok/s,4.2x加速。第四次profiling。
Day 10–11(3/30–3/31):Triton Kernel 收尾 + 调优
手写Fused RoPE + Paged KV Cache Store kernel(+5 tok/s)、Fused SwiGLU kernel(+2 tok/s)、Triton RMSNorm(+2~3 tok/s)。尝试Fused Add RMSNorm,确认为负优化后回退。CPU侧block_table缓存优化。Engine收尾整合。
v0.8-97tok-cuda-graph — 96.94 tok/s,超过vLLM在同硬件上的95tok/s。
优化阶段 延迟 tok/s vs 起点
─────────────────────────────────────────────────────────
v0.1 Naive(无 KV cache) 51ms 21~1.5 1.0x
v0.2 NaiveKVCache 50ms 20 1.0x
v0.4 PagedAttention (gather) 48ms 20.8 1.0x
v0.5 Triton Paged Attention 43ms 23.2 1.1x
v0.6 + QKV/SwiGLU fused 35ms 28.4 1.35x
v0.7 + CUDA Graph 11.45ms 87.3 4.2x
v0.8 + Triton Kernel 收尾 10.32ms 96.94 4.6x
功能与特性实现
模型层
从零手写 Qwen2.5-1.5B 完整前向推理,不依赖 HuggingFace 的模型实现,仅使用其 Tokenizer 和权重文件。包括 RMSNorm、RoPE(查表式实数实现)、Grouped Query Attention(12 Q heads / 2 KV heads)、SwiGLU FFN,全部模块经过与 HuggingFace 参考输出的逐 tensor 数值对齐验证。支持 bfloat16 精度推理。
KV Cache 管理
实现了完整的 PagedAttention 内存管理体系。BlockManager 在启动时一次性预分配全局 KV Cache 物理内存池,按固定大小的 block(16 tokens)进行分配和回收,通过逻辑→物理的 block_table 映射实现非连续内存的透明访问,彻底消除显存碎片。每个请求持有独立的 PagedKVCache 对象,记录自身的 block_table 和序列长度,生命周期与请求绑定。
调度与 Batching
实现了 Continuous Batching 调度器,请求可以在任意时刻动态加入和退出 batch,不需要等待凑满固定 batch size。调度器维护 waiting、prefilling、decoding、finished 四个队列,每步自动完成状态流转。Prefill 和 Decode 走独立的执行路径,Prefill 逐请求处理(序列长度动态),Decode 支持多请求 batch 并行。
自写 Triton Kernel
共 5 个手写 Triton kernel:
store_kvcache_kernel:2D grid(token × head),通过 slot_mapping 将新计算的 K/V 直接写入非连续的物理 Cache 块,支持 stride 处理非连续内存布局。
Decode_Paged_GQAAttention_Kernel:Paged Decode Attention 的核心计算,在 kernel 内部通过 block_table 索引逐块读取分散存储的 KV,使用 online softmax 累积 attention 结果,天然处理 GQA 的 head 映射,零拷贝无需 gather。
fused_rope_and_cache_store_kernel:将 RoPE 旋转位置编码和 KV Cache 写入融合为单个 kernel,Q 在寄存器内完成旋转后写回连续显存,K 和 V 完成旋转后直接通过 slot_mapping 写入分散的物理 Cache 块,消除中间缓冲。
fused_silu_mul_kernel:将 SwiGLU 的 Sigmoid、门控乘法和 up 分支的逐元素乘法融合为单次显存读写,三次 I/O 降为一次。
_rmsnorm_kernel:将 RMSNorm 的 6 个独立 PyTorch 算子合并为单个 kernel,每行一个 program,全程 float32 计算后转回原始 dtype 写出。
算子融合与模型优化
QKV Fused Projection:三个独立的 q/k/v 矩阵乘法合并为单次 GEMM,权重在加载时拼接,输出通过 split 拆分。SwiGLU gate_up Fused:同理将 gate_proj 和 up_proj 合并。rotate_half 的 torch.cat 消除,改为 in-place 操作。model.forward 重构为纯 tensor 接口,不接受任何 Python 对象,所有 cache 操作留在 Engine 层,使 forward 对 torch.compile 和 CUDA Graph 完全透明。
CUDA Graph
Decode 路径完整接入 CUDA Graph。Engine 启动时预分配固定形状的静态 buffer(input_ids、slot_mapping、position_ids、block_table、context_lens),每步只原地更新 buffer 的值,不重新分配内存。CUDA Graph 一次 capture 后反复 replay,将每步数百次 kernel launch 压缩为单次提交。Prefill 路径保持 eager 执行(序列长度动态,不适合 capture)。提供 use_cuda_graph 开关,可随时回退 eager 模式调试。
采样
支持 Greedy、Temperature Sampling 和 Top-p (Nucleus) Sampling 三种采样策略,统一接口。
开发过程遇到的错误汇总
简单的总结了一下,整个单卡框架开发过程中遇到的主要的、难以排查的错误,我还记得的就这么几个——如果我不记得了,应该是没有折磨我太久,大家应该能解决。
注:有AI辅助排版。
类型一:KV Cache 写入位置错误
症状: 生成乱码,词语乱序重复,或反复覆盖同一位置。
出现了两次:
第一次(早期 PagedAttention):
原因:context_lens 比实际少 1
decode 时写入新 token 后,读取范围没有包含刚写入的 token
新 token 的 query attend 不到自己
第二次(slot_mapping 重构后):
原因:get_prefill_slot_mapping 里
cache_block_index 是 tensor,用 tensor 作为 dict key
Python dict 用 tensor 作 key 行为未定义,查到错误的物理 id
slot_mapping 返回 [0,1,2,3,4] 而不是真实物理地址
prefill 写入物理 block X,decode 读取物理 block Y
加 .item() 后修复,但这降低了性能,后续整体重构了
类型二:非连续内存导致 Triton Kernel 越界
症状: cudaErrorIllegalAddress,kernel 崩溃,或者生成完全不可理解(而不是简单反复重复)的乱码。
出现了两次:
第一次(paged_decode_attention):
原因:kv_cache[layer_idx] 是大 tensor 的 view
Triton 拿到 data_ptr 后按连续内存计算偏移
实际 stride 和预期不符,导致越界访问
加 .contiguous() 后修复
第二次(store_kvcache,极其隐蔽):
原因:apply_rope 之后 k/v 是非连续 tensor
view/reshape 假设连续内存
Triton kernel 按错误地址写入
根本修法:给 store_kvcache_kernel 传入 stride 参数
用 stride 计算偏移而不是假设连续
类型三:模型前向传播的过程实现在反复修改中出错
症状: 生成不完全的乱码,词语乱序重复,有意义但不完全有意义。
出现了一次,很隐蔽:
原因:q 相关的 forward 实现在不同路径里不一致,在 Decode 路径中遗漏了 RoPE 旋转状态
- Eager 路径(forward)里,q 被 apply_rope 原地覆盖更新,传给 Attention 的是正确旋转后的 q
- Graph 路径(forward_decode)引入了融合算子 fused_decode_rope_and_cache,它返回了新的 q_rot,但未被后续使用
- 原始未旋转的 q 直接参与了注意力计算,导致模型在 Decode 时彻底丢失相对位置信息(RoPE 失效),引发无脑复读和乱码
出现位置:
- GQAAttention.forward_decode 方法的末尾处
- 算子调用错误地写成了 output = self._decode_attention(q, kv_cache_k, ...)
- 实际应该传入旋转后的变量,改为 output = self._decode_attention(q_rot, kv_cache_k, ...)
类型四:Tensor Shape 错位
症状: RuntimeError: mat1 and mat2 shapes cannot be multiplied。
出现了多次,都是同一个模式:
原因:q 的 shape 在反复迭代修改后,在不同路径里不一致或者前后不一致。例如:
- GQAAttention 内部 q 是 (B, seq, n_heads, head_dim)
- paged_decode_attention wrapper 期望 (B, n_heads, 1, head_dim)
- transpose 和 reshape 的顺序搞错导致维度错位
出现位置:
- _decode_attention 里 q 传给 Triton 之前忘记 transpose
- wrapper 里 unpack 顺序写成 B, _, N_HEAD, HEAD_DIM
实际应该 B, N_HEAD, _, HEAD_DIM
类型五:seq_len 更新时机错误
症状: 生成内容在第 2 步之后开始乱码,第 1 步正确。生成的语句基本有正确的语义,但就是Decoding阶段开始的时候会出现重复、或者突然的错误。
原因:decode 每步用 context_lens = c.seq_len + 1
但 seq_len 在 forward 之前已经被 prepare_decode_step
或其他逻辑提前更新,导致 context_lens 多 1 或少 1
另一次:prefill 后 _seq_len += len + 1(多加了 1)
第一个 decode step 的 position 偏移 1,之后全部错位