Pico-vLLM 开发日志 #6 Triton Kernel和代码重构

一天没更新是因为花了点时间调试Triton Kernel。这部分的开发比我预期的要慢,主要还是对于pytorch的Tensor的操作不是非常熟悉,栽了几个比较大的坑,因此反复debug了好久。这里也记录一下,希望能以此提醒自己,也能够帮助后来者避坑。

这部分的思路其实前面已经完全阐述完成,主要是讲讲主要完成的工作和遇到的坑有哪些。

设计模式

这部分主要是为了后面的CUDA Graph的实现做准备。总的来说,指导思想只有一个:静态化所有核心前向的数据指针和形状,同时避免任何条件分支。这是因为CUDA Graph的本质是重放:将既定的指针形状、顺序、流程硬编码在GPU上,反复执行,从而避免CPU下发指令进行Kernel启动的开销。这部分目前还没有用上,但已经提前做了准备。

但众所周知,LLM的推理框架本身就需要两个截然不同的需求:Prefill和Decode。在vLLM等主流实践当中,一般都需要将两者完全分离,前者不使用CUDA Graph而后者使用。这么做的原因主要是Prefill的形状高度不固定:为一个比较短的prompt input去padding一个巨大无比的形状来确保绝对不越界以允许CUDA Graph明显是相当不方便的。这也是chunked Prefill产生的动机之一:分次做的话,很大程度上就形状固定了。如果将两者放在一起,is_prefill的条件判断本身也会让CUDA Graph因为分支发散而完全不成立。

这倒是让我想起来之前看到的一篇论文:《Medusa: Accelerating Serverless LLM Inference with Materialization》。这也是我这学期在上的课的老师陆游游老师团队做的一项工作。它的主要贡献是,直接把CUDA Graph的结构给解析了出来,然后通过物化和指针替换的方式动态的使得物化在主存或者SSD上的CUDA Graph不会因为重新加载之后cublas等闭源库的函数加载位置改变、数据指针的位置改变而失效,从而实现LLM推理服务从非加载状态的完全冷启动,同时还能保持启动的高效率,免去每次重新记录和重放的时间overhead消耗。考虑到CUDA Graph这东西能够在一个自研的框架上跑通本身就要花费巨量的心血,直接将其逆向的工作能够发在 ASPLOS 2025 上当然也是可以想象的。很推荐大家去读!

Triton Kernel的实现

目前显式写出的Triton Kernel有两个:一个是KV Cache 写入算子 (store_kvcache),一个是分页解码注意力(paged_decode_attention)。它们分别负责:

1、store_kvcache在Prefill阶段和Decode根据block Manager分配的块表正确的写入相应的KV cache,Prefill阶段写入seq_len个而Decode阶段固定写入1个。

2、paged_decode_attention则是Decode阶段专用的Kernel,负责根据页表正确取出相应的KV cache,同时完成Flash Attention模式的不将中间矩阵物化的计算流程。

此外需要特别注意的是,虽然传入的是Tensor,不过实际上Triton内部的操作模式更接近C++:传入的Tensor实际上是隐式转换成带数据类型的指针来处理的。这方面要格外注意的是不同指针的数据类型:如果数据类型不同,那么指针的算术操作造成的物理地址偏移量实际上是完全不同的,而且写入的东西也很可能不匹配。如果不能确定自己到底在写什么的话,就多使用tl.print()然后只针对pid==0的情况进行打印来确认吧。

下面是源代码:

1、store_kvcache

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
@triton.jit
def store_kvcache_kernel(
    k_ptr, v_ptr,              # 源 K/V: (total_tokens, n_kv_heads, HEAD_DIM)
    k_cache_ptr, v_cache_ptr,  # 目标 Cache: (num_blocks, n_kv_heads, block_size, HEAD_DIM)
    slot_mapping_ptr,          # (total_tokens,) 每个 token 的物理 slot 编号
    stride_k_token, stride_k_head, stride_k_dim,
    stride_v_token, stride_v_head, stride_v_dim,
    N_KV_HEADS: tl.constexpr, # ← constexpr
    BLOCK_SIZE: tl.constexpr, # ← constexpr
    HEAD_DIM: tl.constexpr,
):
    # 映射 2D Grid
    pid_token = tl.program_id(0)
    pid_head = tl.program_id(1)

    # 读取物理 slot
    slot = tl.load(slot_mapping_ptr + pid_token)
    block_id = slot // BLOCK_SIZE
    offset = slot % BLOCK_SIZE

    # 计算源数据 (K/V) 的一维内存偏移
    dim_offsets = tl.arange(0, HEAD_DIM)
    src_offset = pid_token * stride_k_token \
               + pid_head  * stride_k_head  \
               + dim_offsets * stride_k_dim

    k_vec = tl.load(k_ptr + src_offset)
    v_src = pid_token * stride_v_token \
          + pid_head  * stride_v_head  \
          + dim_offsets * stride_v_dim
    v_vec = tl.load(v_ptr + v_src)
    
    # 计算目标数据 (KV Cache) 的一维内存偏移
    dst_offset = (block_id * N_KV_HEADS * BLOCK_SIZE * HEAD_DIM) + \
                 (pid_head * BLOCK_SIZE * HEAD_DIM) + \
                 (offset * HEAD_DIM) + \
                 dim_offsets

    # 写入 Cache
    tl.store(k_cache_ptr + dst_offset, k_vec)
    tl.store(v_cache_ptr + dst_offset, v_vec)

这个Kernel看上去非常简单,但是在优化的时候有一个大坑,后面讲。聪明的读者可能看到 stride 的时候就想到了吧。

2、paged_decode_attention

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
@triton.jit
def Decode_Paged_GQAAttention_Kernel(
        q                         ,  # (B, n_heads, 1, head_dim)         query,decode 每步只有 1 个 token
        k_cache                   ,  # (num_blocks, n_kv_heads, block_size, head_dim)  全局 K cache
        v_cache                   ,  # (num_blocks, n_kv_heads, block_size, head_dim)  全局 V cache
        block_table               ,  # (B, MAX_BLOCKS_PER_SEQ)  int32,每个请求的物理块 id,-1 表示未分配
        context_lens              ,  # (B,)             int32,每个请求当前的有效 token 数
        scale              ,         # 1.0 / sqrt(head_dim)
        out,
        
        # Meta-parameters
        # 元参数
        MAX_BLOCKS_PER_SEQ: tl.constexpr,  # 启动时固定,不是运行时变量
        BLOCK_SIZE: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,  #
        N_KV_HEAD: tl.constexpr,
        N_HEAD: tl.constexpr,
    ):               # (B, n_heads, 1, head_dim)
    # grid = (B, n_heads)
    # 每个 program 处理一个 (batch, head) 对
    # program_id[0] = batch_idx
    # program_id[1] = head_idx
    
    pid_batch = tl.program_id(axis=0)
    pid_head = tl.program_id(axis=1)
    kv_head_idx = (pid_head // (N_HEAD // N_KV_HEAD))

    m = float('-inf')
    l = float(0)
    o = tl.zeros((HEAD_DIM, ), dtype=tl.float32)
    
    
    q_ptrs = q + pid_batch * (HEAD_DIM * N_HEAD) + pid_head * (HEAD_DIM) + tl.arange(0, HEAD_DIM)
    q_vec = tl.reshape(tl.load(q_ptrs), (HEAD_DIM, 1))

    context_len = tl.load(context_lens + pid_batch)
    max_block_index = tl.cdiv(context_len, BLOCK_SIZE)  # 向上取整
    offs_kv = tl.arange(0, BLOCK_SIZE * HEAD_DIM)
    
    for block_idx in range(0, max_block_index):
        
        physical_idx = tl.load(block_table + pid_batch * MAX_BLOCKS_PER_SEQ + block_idx)
        physical_idx = tl.maximum(physical_idx, 0).to(tl.int64)
        base = (physical_idx * N_KV_HEAD * BLOCK_SIZE * HEAD_DIM + kv_head_idx* BLOCK_SIZE * HEAD_DIM)

        # 加载时 mask 掉超出 context_len 的 token
        token_start = block_idx * BLOCK_SIZE
        valid_in_block = tl.minimum(BLOCK_SIZE, context_len - token_start)
        kv_token_mask = tl.arange(0, BLOCK_SIZE * HEAD_DIM) < valid_in_block * HEAD_DIM

        k_ptrs = k_cache + base + offs_kv
        # k_block: (block_size, head_dim)
        k_block = tl.load(k_ptrs, mask = kv_token_mask, other=0.0)
        k_block = tl.reshape(k_block, (BLOCK_SIZE, HEAD_DIM))
        v_ptrs = v_cache + base + offs_kv
        v_block = tl.load(v_ptrs, mask = kv_token_mask, other=0.0)
        v_block = tl.reshape(v_block, (BLOCK_SIZE, HEAD_DIM))

        # mask 最后一个 block 的无效 token
        valid = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) < context_len
        q_row = tl.reshape(q_vec, (1, HEAD_DIM))           # (1, HEAD_DIM)
        s = tl.sum(k_block * q_row, axis=1)                 # (BLOCK_SIZE,)
        s = s.to(tl.float32) * scale
        s = tl.where(valid, s, float('-inf'))

        m_new = tl.maximum(m, tl.max(s))
        alpha = tl.exp(m - m_new)           # 旧的缩放因子
        p = tl.exp(s - m_new)          # 当前 block 的权重

        l = l * alpha + tl.sum(p)
        p_col = tl.reshape(p, (BLOCK_SIZE, 1))             # (BLOCK_SIZE, 1)
        o = o * alpha + tl.sum(p_col * v_block, axis=0)    # (HEAD_DIM,)
        m = m_new

    o = o / l  # (1, HEAD_DIM)
    o_casted = tl.cast(o, out.dtype.element_ty)
    
    tl.store(out + pid_batch * (HEAD_DIM * N_HEAD) + pid_head * (HEAD_DIM) + tl.arange(0, HEAD_DIM), o_casted)

这部分本身没有什么概念上的困难。需要格外注意的是,调试的时候其报错是静默的:只会告诉你错在这里了,而不是告诉你这个Kernel里到底错在哪一行。因此,要多打印!

另一个非常要注意的关键点是Triton会默认执行死代码消除的优化(DCE)。这一点对于跑高性能是好事,但是对于调试极其不利。你可能会想,我注释掉后面一部分让Kernel报错的,一直二分直到找到问题所在,行不行呢?答案是不行,至少不是简单的行。如果你注释了后面一部分操作,而最后没有把前面操作的那部分store到某个地方的话,Triton编译器会直接把它们全部消除掉。那如果本身错误是发生在这部分,那么看上去你只是注释掉了最后几行就还是跑通了,但实际上并不能因此定位到错误所在。

踩坑点和注意事项

1、显存连续性问题

在第一版实现的时候,store_kvcache_kernel是没有传入stride参数的。这在Q、K、V三个权重矩阵分开计算的时候一点问题也没有。但是,在后来的计算流优化当中,为了减少Kernel的launch开销和反复读显存的开销,采用了一个很简单的优化实现:把$W_Q$、$W_K$、$W_V$拼接起来一起和输入$x$做GEMM。在两者同时进行的时候,出现了极其难以排查的乱码,几乎花费了三到五个小时才发现原因。这到底是为什么?看这段代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
q, k, v = qkv.split([q_size, k_size, k_size], dim=-1)
q = q.view(B, seq_len, self.cfg.num_attention_heads, self.cfg.head_dim)
k = k.view(B, seq_len, self.cfg.num_key_value_heads, self.cfg.head_dim)
v = v.view(B, seq_len, self.cfg.num_key_value_heads, self.cfg.head_dim)

# RoPE
q, k = RoPE.apply_rope(q, k, cos, sin)
# print(v.is_contiguous())
# k/v reshape 成 (total_tokens, n_kv_heads, head_dim) 给 store kernel
k_flat = k.reshape(-1, self.cfg.num_key_value_heads, self.cfg.head_dim)
v_flat = v.reshape(-1, self.cfg.num_key_value_heads, self.cfg.head_dim)

# 写入 KV cache
store_kvcache(k_flat, v_flat, kv_cache_k, kv_cache_v, slot_mapping)# type:ignore

看上去没什么问题?

但如果把中间的print注释掉,你会发现,v.is_contiguous()的结果是false。如果改成k,结果又是true。这到底是为什么?因为:k经过了RoPE.apply_rope,在计算的过程中,pytorch自动帮你完成了向量的深拷贝,从而完成了向量的连续化重整,而v的两个操作,.view()和.reshape()都是不触发深拷贝的。也就是说,在它传入store_kvcache这个warpper的时候,它是仅仅是一个视图,在物理上连续。

那这为什么会造成问题?因为之前说的:Triton内部的操作模式更接近C++:传入的Tensor实际上是隐式转换成带数据类型的指针来处理的。这意味着算术偏移不会考虑任何视图,操作的直接就是底层的物理地址。这和不连续明显是冲突的,但却又在整个大Tensor的数据范围内。因此就会发生,“数据不对,但也没有报错,看k的数据是完全对得上的,无论如何都查不出毛病”的问题。因此,pytorch的Tensor操作当中一定要小心视图和实际存储布局的区别问题。最后,是通过添加stride,重写底层的指针偏移来解决的。如果从头在C++本身开发,反倒没有这个问题了。

2、context_lens和seq_len的区别

这个我一时间还真不能完全精确的说明区别。为了避免对读者造成误会,直接请Gemini老师帮忙总结一下:

在手搓大模型推理引擎时,这两个变量在命名上极具迷惑性,如果不加区分,很容易在 Prefill 和 Decode 切换的瞬间触发经典的“差一错误”(Off-by-one),导致模型吐出的第一个新词直接变成乱码。简单来说,它们的物理含义完全不同:

  • seq_len (当前处理长度):指的是模型当前这一步 Forward 到底吃进去了几个 token。
    • 在 Prefill 阶段,一次性处理整个 Prompt,所以 seq_len = N
    • 在 Decode 阶段,因为我们是逐字生成,每次只喂给模型一个最新的词,所以 seq_len 永远等于 1。
  • context_lens (上下文总长度):这是传给底层 Paged Attention 算子(如 Triton Kernel)的关键参数。它代表当前的 Query 到底需要和 多少个 token(包含它自己) 去计算注意力得分。

换言之,一个简单的自我思想检验:在Prefill刚刚结束,第一轮Decode开始的时候,传入Decode的参数应该是:seq_len=1,context_lens=N+1,其中N是prompt中的token数量。

3、复制粘贴的问题

孩子们,用ai帮忙写测试脚本的时候一定要看看自己复制了几遍,否则会发现自己的推理框架莫名其妙的总是会在Prefill结束之后多出一个token而找不到任何原因😭

性能Profiling

用nsys进行了GPU侧性能的Profiling,模式是eager,无编译的自动算子融合。结果如下:

Kernel                              时间占比    实例数    单步耗时
──────────────────────────────────────────────────────────────
GEMM (gemvx, 主)                    63.1%      1700     ~6.8ms
GEMM (gemvx, 次)                    22.2%       560     ~2.4ms
cutlass WMMA (gate_up fused)         5.4%       560     ~0.6ms
RMSNorm elementwise                  2.0%      3380     ~0.2ms
BinaryFunctor (residual add)         1.0%      1700     ~0.1ms
CUDAFunctor_add                      0.9%      2240     ~0.1ms
Decode_Paged_GQAAttention            0.9%       560     ~0.1ms
RMSNorm reduce                       0.9%      1140     ~0.1ms
CatArrayBatchedCopy                  0.8%      1120     ~0.08ms
store_kvcache_kernel                 0.2%       560     ~0.02ms
──────────────────────────────────────────────────────────────
GPU 总时间(20步)                  ~214ms     单步 ~10.7ms

按3.15GB(3GB权重+0.15GB的KV cache)的Memory Bound情况来计算,单步的理论最优化时间是:

$$ (3.15GB/token)/(382GB/s)=0.00825s=8.25ms $$

性能利用率已经达到77%,与此同时,整个单步Decoding的时延居然达到33ms左右,远远长于GPU侧的耗时。这主要是因为pytorch侧的时间非常长,反复调用产生了巨量的overhead。在我使用的WSL系统中,这个调用还会产生更大的虚拟化开销。因此目前的优化的性能瓶颈已经不是在GPU端了。这也是为什么接下来要做CUDA Graph。