一天没更新是因为花了点时间调试Triton Kernel。这部分的开发比我预期的要慢,主要还是对于pytorch的Tensor的操作不是非常熟悉,栽了几个比较大的坑,因此反复debug了好久。这里也记录一下,希望能以此提醒自己,也能够帮助后来者避坑。
这部分的思路其实前面已经完全阐述完成,主要是讲讲主要完成的工作和遇到的坑有哪些。
设计模式
这部分主要是为了后面的CUDA Graph的实现做准备。总的来说,指导思想只有一个:静态化所有核心前向的数据指针和形状,同时避免任何条件分支。这是因为CUDA Graph的本质是重放:将既定的指针形状、顺序、流程硬编码在GPU上,反复执行,从而避免CPU下发指令进行Kernel启动的开销。这部分目前还没有用上,但已经提前做了准备。
但众所周知,LLM的推理框架本身就需要两个截然不同的需求:Prefill和Decode。在vLLM等主流实践当中,一般都需要将两者完全分离,前者不使用CUDA Graph而后者使用。这么做的原因主要是Prefill的形状高度不固定:为一个比较短的prompt input去padding一个巨大无比的形状来确保绝对不越界以允许CUDA Graph明显是相当不方便的。这也是chunked Prefill产生的动机之一:分次做的话,很大程度上就形状固定了。如果将两者放在一起,is_prefill的条件判断本身也会让CUDA Graph因为分支发散而完全不成立。
这倒是让我想起来之前看到的一篇论文:《Medusa: Accelerating Serverless LLM Inference with Materialization》。这也是我这学期在上的课的老师陆游游老师团队做的一项工作。它的主要贡献是,直接把CUDA Graph的结构给解析了出来,然后通过物化和指针替换的方式动态的使得物化在主存或者SSD上的CUDA Graph不会因为重新加载之后cublas等闭源库的函数加载位置改变、数据指针的位置改变而失效,从而实现LLM推理服务从非加载状态的完全冷启动,同时还能保持启动的高效率,免去每次重新记录和重放的时间overhead消耗。考虑到CUDA Graph这东西能够在一个自研的框架上跑通本身就要花费巨量的心血,直接将其逆向的工作能够发在 ASPLOS 2025 上当然也是可以想象的。很推荐大家去读!
Triton Kernel的实现
目前显式写出的Triton Kernel有两个:一个是KV Cache 写入算子 (store_kvcache),一个是分页解码注意力(paged_decode_attention)。它们分别负责:
1、store_kvcache在Prefill阶段和Decode根据block Manager分配的块表正确的写入相应的KV cache,Prefill阶段写入seq_len个而Decode阶段固定写入1个。
2、paged_decode_attention则是Decode阶段专用的Kernel,负责根据页表正确取出相应的KV cache,同时完成Flash Attention模式的不将中间矩阵物化的计算流程。
此外需要特别注意的是,虽然传入的是Tensor,不过实际上Triton内部的操作模式更接近C++:传入的Tensor实际上是隐式转换成带数据类型的指针来处理的。这方面要格外注意的是不同指针的数据类型:如果数据类型不同,那么指针的算术操作造成的物理地址偏移量实际上是完全不同的,而且写入的东西也很可能不匹配。如果不能确定自己到底在写什么的话,就多使用tl.print()然后只针对pid==0的情况进行打印来确认吧。
下面是源代码:
1、store_kvcache
| |
这个Kernel看上去非常简单,但是在优化的时候有一个大坑,后面讲。聪明的读者可能看到 stride 的时候就想到了吧。
2、paged_decode_attention
| |
这部分本身没有什么概念上的困难。需要格外注意的是,调试的时候其报错是静默的:只会告诉你错在这里了,而不是告诉你这个Kernel里到底错在哪一行。因此,要多打印!
另一个非常要注意的关键点是Triton会默认执行死代码消除的优化(DCE)。这一点对于跑高性能是好事,但是对于调试极其不利。你可能会想,我注释掉后面一部分让Kernel报错的,一直二分直到找到问题所在,行不行呢?答案是不行,至少不是简单的行。如果你注释了后面一部分操作,而最后没有把前面操作的那部分store到某个地方的话,Triton编译器会直接把它们全部消除掉。那如果本身错误是发生在这部分,那么看上去你只是注释掉了最后几行就还是跑通了,但实际上并不能因此定位到错误所在。
踩坑点和注意事项
1、显存连续性问题
在第一版实现的时候,store_kvcache_kernel是没有传入stride参数的。这在Q、K、V三个权重矩阵分开计算的时候一点问题也没有。但是,在后来的计算流优化当中,为了减少Kernel的launch开销和反复读显存的开销,采用了一个很简单的优化实现:把$W_Q$、$W_K$、$W_V$拼接起来一起和输入$x$做GEMM。在两者同时进行的时候,出现了极其难以排查的乱码,几乎花费了三到五个小时才发现原因。这到底是为什么?看这段代码:
| |
看上去没什么问题?
但如果把中间的print注释掉,你会发现,v.is_contiguous()的结果是false。如果改成k,结果又是true。这到底是为什么?因为:k经过了RoPE.apply_rope,在计算的过程中,pytorch自动帮你完成了向量的深拷贝,从而完成了向量的连续化重整,而v的两个操作,.view()和.reshape()都是不触发深拷贝的。也就是说,在它传入store_kvcache这个warpper的时候,它是仅仅是一个视图,在物理上不连续。
那这为什么会造成问题?因为之前说的:Triton内部的操作模式更接近C++:传入的Tensor实际上是隐式转换成带数据类型的指针来处理的。这意味着算术偏移不会考虑任何视图,操作的直接就是底层的物理地址。这和不连续明显是冲突的,但却又在整个大Tensor的数据范围内。因此就会发生,“数据不对,但也没有报错,看k的数据是完全对得上的,无论如何都查不出毛病”的问题。因此,pytorch的Tensor操作当中一定要小心视图和实际存储布局的区别问题。最后,是通过添加stride,重写底层的指针偏移来解决的。如果从头在C++本身开发,反倒没有这个问题了。
2、context_lens和seq_len的区别
这个我一时间还真不能完全精确的说明区别。为了避免对读者造成误会,直接请Gemini老师帮忙总结一下:
在手搓大模型推理引擎时,这两个变量在命名上极具迷惑性,如果不加区分,很容易在 Prefill 和 Decode 切换的瞬间触发经典的“差一错误”(Off-by-one),导致模型吐出的第一个新词直接变成乱码。简单来说,它们的物理含义完全不同:
seq_len(当前处理长度):指的是模型当前这一步 Forward 到底吃进去了几个 token。- 在 Prefill 阶段,一次性处理整个 Prompt,所以
seq_len = N。 - 在 Decode 阶段,因为我们是逐字生成,每次只喂给模型一个最新的词,所以
seq_len永远等于 1。
- 在 Prefill 阶段,一次性处理整个 Prompt,所以
context_lens(上下文总长度):这是传给底层 Paged Attention 算子(如 Triton Kernel)的关键参数。它代表当前的 Query 到底需要和 多少个 token(包含它自己) 去计算注意力得分。
换言之,一个简单的自我思想检验:在Prefill刚刚结束,第一轮Decode开始的时候,传入Decode的参数应该是:seq_len=1,context_lens=N+1,其中N是prompt中的token数量。
3、复制粘贴的问题
孩子们,用ai帮忙写测试脚本的时候一定要看看自己复制了几遍,否则会发现自己的推理框架莫名其妙的总是会在Prefill结束之后多出一个token而找不到任何原因😭
性能Profiling
用nsys进行了GPU侧性能的Profiling,模式是eager,无编译的自动算子融合。结果如下:
Kernel 时间占比 实例数 单步耗时
──────────────────────────────────────────────────────────────
GEMM (gemvx, 主) 63.1% 1700 ~6.8ms
GEMM (gemvx, 次) 22.2% 560 ~2.4ms
cutlass WMMA (gate_up fused) 5.4% 560 ~0.6ms
RMSNorm elementwise 2.0% 3380 ~0.2ms
BinaryFunctor (residual add) 1.0% 1700 ~0.1ms
CUDAFunctor_add 0.9% 2240 ~0.1ms
Decode_Paged_GQAAttention 0.9% 560 ~0.1ms
RMSNorm reduce 0.9% 1140 ~0.1ms
CatArrayBatchedCopy 0.8% 1120 ~0.08ms
store_kvcache_kernel 0.2% 560 ~0.02ms
──────────────────────────────────────────────────────────────
GPU 总时间(20步) ~214ms 单步 ~10.7ms
按3.15GB(3GB权重+0.15GB的KV cache)的Memory Bound情况来计算,单步的理论最优化时间是:
$$ (3.15GB/token)/(382GB/s)=0.00825s=8.25ms $$性能利用率已经达到77%,与此同时,整个单步Decoding的时延居然达到33ms左右,远远长于GPU侧的耗时。这主要是因为pytorch侧的时间非常长,反复调用产生了巨量的overhead。在我使用的WSL系统中,这个调用还会产生更大的虚拟化开销。因此目前的优化的性能瓶颈已经不是在GPU端了。这也是为什么接下来要做CUDA Graph。