Femtotron开发日志 #9 流水线并行 Pipeline Parallelism

这几天的工作量是实现了3P并行当中的最后一环:流水线并行,Pipeline Parallelism。这个并行理论上在Pico-vLLM的推理场景当中也同样有意义,但却碍于工程量问题未能真实实现,因此在Femtotron中对它进行真正的实现也算是弥补了遗憾。这次的博客更新之所以隔了这么多天,是因为这个并行的工程量意外的大,而且和之前的ZeRO系列一样,具有极多麻烦的边界情况问题。下面讲一讲原理、核心的设计抽象,以及中途遇到的值得一提的问题,以供后来人和读者参考,希望有所帮助。

流水线并行的原理

应该没有人不能理解PP的原理是什么。它在概念上极其简单:在tokenizer-若干个layer-lm head这个前向传播的维度上,按某种原则进行切分,将切分的结果依次分配到不同的GPU上,并且在它们之间协调一组调度和通信内容、顺序的过程。直观上,就是把模型按纵向维度切分,一段段的放在GPU的rank0,1,2...上,依次往前做,反向传播的时候再反过来。

问题是它的实现同样很麻烦。下面会详细讲讲这些问题。

流水线并行的核心设计架构和抽象

流水线并行的核心问题在于处理不同stage之间的次要异质性。对于中间阶段,无论如何,输入和输出都是一样的,完全同质化。然而,对于第一个stage和最后一个stage就不一样了。对于第一个stage来说,其输入不是Hidden state而是word Embedding的向量;对于最后一个stage来说区分更明显,要多做一个lm head,而且随着场景变化可能要物化一个巨大的vocab table,导致额外的计算量和内存峰值。在主流实践当中,如果有此类需求,我们通常会微调layer数量的分配,在最后一个stage上有意减少普通layer的数量,以达到更好的负载均衡。我们的框架为此专门实现了可调节每个stage layer数量的可配置方案;不过为了使用方便起见,我们大多数时候会采用默认的均匀配置。

流水线模型模块

和前面的惯用伎俩一样,就是用一个warpper把切分的layer们包裹起来,异质性的复杂度也在这一层处理。首先是warpper内部包裹的东西:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class LlamaPartialModel(nn.Module):
    def __init__(
        self,
        model_config: LlamaConfig,
        parallel_ctx: ParallelContext,
        layer_range: range | None = None,
        is_first: bool | None = None,
        is_last: bool | None = None,  # None = 从 layer_range 自动推导
    ):
        """
        Args:
            model_config: HF LlamaConfig
            parallel_ctx: 并行上下文。目前仅存储,未在 forward 中使用;
                          预留给未来需要 PP/TP-aware 行为的扩展
            layer_range: 本 stage 持有的 layer indices(全局编号,
                          就是 LlamaDecoderLayer.layer_idx 的值)。
                          None 表示持有全部 layers
            is_first: 是否第一个 stage(持有 embed_tokens)
            is_last: 是否最后一个 stage(持有 final norm)
        
        Raises:
            ValueError: layer_range 越界,或 layer_range 不是 range 类型
        """
        super().__init__()
        
        if is_first is None:
            is_first = (layer_range.start == 0)
        if is_last is None:
            is_last = (layer_range.stop == model_config.num_hidden_layers)
        
        if getattr(model_config, "attn_implementation", None) is None:
            impl = getattr(model_config, "attn_implementation", None) or "sdpa"
            model_config._attn_implementation = impl
    
        self.config = model_config
        self.parallel_ctx = parallel_ctx
        self.layer_range = layer_range
        self.is_first = is_first
        self.is_last = is_last
        
        # First stage: embed_tokens
        if is_first:
            self.embed_tokens = nn.Embedding(
                model_config.vocab_size,
                model_config.hidden_size,
                padding_idx=model_config.pad_token_id,
            )
        
        self.layers = nn.ModuleDict({
            str(idx): LlamaDecoderLayer(model_config, layer_idx=idx)
            for idx in layer_range
        })
        
        self.rotary_emb = LlamaRotaryEmbedding(config=model_config)
        
        # Last stage: final RMSNorm
        if is_last:
            self.norm = LlamaRMSNorm(
                model_config.hidden_size,
                eps=model_config.rms_norm_eps,
            )
    
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x:
                - is_first=True:  input_ids,LongTensor[B, S]
                - is_first=False: hidden_states,float[B, S, H]
            attention_mask: 默认 None
            position_ids: 默认 None
        
        Returns:
            hidden_states float[B, S, H]
            - 非 last stage: layers 输出(未经 final norm)
            - last stage: 经过 final norm 的 hidden_states
        """
        # Embed (first stage only)
        if self.is_first:
            hidden_states = self.embed_tokens(x)
        else:
            hidden_states = x
        
        bsz, seqlen, _ = hidden_states.shape
        
        if position_ids is None:
            position_ids = torch.arange(
                seqlen, device=hidden_states.device, dtype=torch.long,
            ).unsqueeze(0)
        
        from transformers.masking_utils import create_causal_mask
        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            past_key_values=None,
            position_ids=position_ids,
        )
        
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
        
        for layer in self.layers.values():
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=None,
                use_cache=False,
            )
        
        if self.is_last:
            hidden_states = self.norm(hidden_states)
        
        return hidden_states

然后是warpper本身:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class LlamaForCausalLM(BaseCausalLMPipeline):
    def __init__(
        self,
        config: LlamaConfig,
        parallel_ctx: ParallelContext,
        layer_range: range | None = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        
        self.model = LlamaPartialModel(
            model_config=config,
            parallel_ctx=parallel_ctx,
            layer_range=layer_range,
        )
        
        self.is_first = self.model.is_first
        self.is_last = self.model.is_last
        
        if self.is_last:
            self.lm_head = nn.Linear(
                config.hidden_size, config.vocab_size, bias=False,
            )
            
        # 暂时不支持tie_word_embeddings
        if config.tie_word_embeddings:
            raise NotImplementedError(
                "LlamaForCausalLM 目前不支持 tie_word_embeddings=True"
            )

可以看到,实现本身还是比较直白简洁的。复杂性和前面遇到的各类杂七杂八问题一样:还是在边界情况里。

调度模块

流水线并行和调度必须一起做,缺一不可。这里我们实现了baseline的all then all调度(先全部前向传播,然后再全部反向传播)作为基线,然后实现了基础的1F1B调度(通过交错排列前向传播和反向传播来降低空泡率)。未来可能还会增加更多的调度策略,这也是在规划内的。它们的原理在开源资料里已经得到了很详细的普及了,因此这里不详细展开。它们的代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def gpipe_schedule(
    num_microbatches: int,
    is_first: bool,
    is_last: bool,
) -> list[PPAction]:
    if num_microbatches < 1:
        raise ValueError(f"num_microbatches must be >= 1, got {num_microbatches}")

    actions: list[PPAction] = []

    # ── Forward phase: mb 0, 1, ..., N-1 ──
    for mb in range(num_microbatches):
        if not is_first:
            actions.append(RecvForward(mb))   # recv hidden from prev stage
        actions.append(Forward(mb))         # forward through model
        if not is_last:
            actions.append(SendForward(mb))    # send hidden to next stage

    # ── Backward phase: mb N-1, N-2, ..., 0 ──
    # Reversed order matches LIFO of the autograd graph each mb's activations
    # are released after its backward, allowing early memory reuse.
    for mb in reversed(range(num_microbatches)):
        if not is_last:
            actions.append(RecvBackward(mb))   # recv grad_output from next stage
        actions.append(Backward(mb))         # backward, accumulate param grads
        if not is_first:
            actions.append(SendBackward(mb))    # send grad_input to prev stage

    return actions
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def one_f_one_b_schedule(
    num_microbatches: int,
    pp_size: int,
    pp_rank: int,
) -> list[PPAction]:
    if num_microbatches < 1:
        raise ValueError(f"num_microbatches must be >= 1, got {num_microbatches}")
    if pp_size < 1:
        raise ValueError(f"pp_size must be >= 1, got {pp_size}")
    if pp_rank < 0 or pp_rank >= pp_size:
        raise ValueError(f"pp_rank must be in [0, {pp_size}), got {pp_rank}")

    is_first = (pp_rank == 0)
    is_last = (pp_rank == pp_size - 1)

    num_warmup = min(pp_size - 1 - pp_rank, num_microbatches)
    num_steady = num_microbatches - num_warmup
    num_cooldown = num_warmup  # by symmetry

    actions: list[PPAction] = []

    # ── Phase 1: Warm-up forwards (fill the pipeline) ──
    for j in range(num_warmup):
        if not is_first:
            actions.append(RecvForward(mb_id=j))
        actions.append(Forward(mb_id=j))
        if not is_last:
            actions.append(SendForward(mb_id=j))

    # ── Phase 2: Steady-state 1F1B ──
    for k in range(num_steady):
        fwd_mb = num_warmup + k
        bwd_mb = k
        is_first_steady = (k == 0)
        is_last_steady = (k == num_steady - 1)

        # F(fwd_mb): need its input
        #   - first_steady & not is_first: warmup did RF(0..warmup-1),
        #     so explicit RF(fwd_mb) here
        #   - subsequent: input arrived via prev iter's SBRF
        #   - is_first: input from caller dict, no recv ever needed
        if is_first_steady and not is_first:
            actions.append(RecvForward(mb_id=fwd_mb))
        actions.append(Forward(mb_id=fwd_mb))

        # Send F output forward + recv B grad backward (combined to avoid deadlock)
        if not is_last:
            actions.append(SendForwardRecvBackward(fwd_mb=fwd_mb, bwd_mb=bwd_mb))
        # else: last stage — no SF, no RB (loss provides grad locally)

        actions.append(Backward(mb_id=bwd_mb))

        # Send B grad backward + recv next F input (combined; or plain SB at end)
        if not is_first:
            if is_last_steady:
                # No more F to recv (cooldown only does B's)
                actions.append(SendBackward(mb_id=bwd_mb))
            else:
                actions.append(SendBackwardRecvForward(
                    bwd_mb=bwd_mb, fwd_mb=fwd_mb + 1,
                ))
        # else: first stage — no SB, no need for RF (input always from caller)

    # ── Phase 3: Cool-down backwards (drain the pipeline) ──
    for j in range(num_cooldown):
        bwd_mb = num_steady + j
        if not is_last:
            actions.append(RecvBackward(mb_id=bwd_mb))
        actions.append(Backward(mb_id=bwd_mb))
        if not is_first:
            actions.append(SendBackward(mb_id=bwd_mb))

    return actions

值得注意的Bug

Rotary inv_freq的garbage初始化问题

meta → to_empty(device) 路径下,persistent=False 的buffer不会被初始化,留下未定义内存。Llama 的 rotary embedding 的inv_freq 就是这种buffer。

这导致模型forward出来的结果是garbage logits。关键是,由于它对于整个模型的精度影响平均下来也就~100到300个ULP上下,差点又一次没有发现,幸好仔细的人工校对了一遍。具体来说,是在Debug的时候,发现两个 rank 上的 model.model.rotary_emb.inv_freq 值不一样(随机初始化导致的垃圾值不相同),从而发现了问题。

解决方案是加了一个 _reset_rotary_inv_freq(rotary_emb, config, device) helper,在 test / random-init 路径下显式重算。生产路径下的 ModelLoader 已经遇到过一次这个问题并且解决过一遍,这次算是踩了重复的坑,这下记的更牢了。实际上,模型的莫名其妙的问题,以到现在的经验来看,几乎总是和buffer、cache,这些不属于模型参数但是又需要持久化的东西有关,也许确实是一个值得思考的整体模式问题。

1
2
3
4
5
6
7
8
def _reset_rotary_inv_freq(rotary_emb, config, device):
    base = getattr(config, "rope_theta", 10000.0)
    dim = (getattr(config, "head_dim", None) or
           (config.hidden_size // config.num_attention_heads))
    inv_freq = 1.0 / (base ** (
        torch.arange(0, dim, 2, dtype=torch.int64).to(device, torch.float) / dim
    ))
    rotary_emb.inv_freq.copy_(inv_freq.to(rotary_emb.inv_freq.dtype))

这个过程带给我们的教训是,任何用 meta → to_empty 的路径都要审计buffer是否被显式初始化。这很难依赖某种自动化规则,只能靠coder本人的实践经验和敏感性。

一个不算bug的奇怪问题

cuBLAS lazy context warning总是消除不掉。尝试了很多方法无果,遂放弃。先记录在这里,等到哪天遇到类似问题或者看到解决方案了,再来捣鼓它。

UserWarning: Attempting to run cuBLAS, but there was no current CUDA context!

ZeRO-3 + PP时会发生的显存占用和计算效率问题

虽然这听上去很奇怪,但这是真的。ZeRO-3 + PP加在一起不如两个都不加。在我的测试当中,ZeRO-3在很多情况下,不仅没有相对于ZeRO-2降低显存的占用,反而增加了其占用情况。在PP=1的情况下,这种情况尚且并不多见;在PP=2及以上,即采用PP并行的情况下,它几乎总是比ZeRO-2还要差。是SAC也拯救不了的那种,纯粹的占用了更多无法被释放的显存。

经过分析和多个配置参数的验证,这确实不是一个正确性问题,而是无法避免的机制冲突。也就是说:ZeRO-3 + PP是根本上不兼容的两个训练优化/并行化机制。个人认为这是一个挺反直觉(至少不能第一时间通过纸面理论察觉到)的结果,而且某种程度上令人感到沮丧:因为ZeRO-3本身的复杂度就奇高无比,而且我自己实现的时候也花了很多心血,但现在却发现它作为增量的收益远不如我们预期的多。不过无论如何,抛开这些投入成本不谈,我仍然查询了学界和工业界对此的理解和处理,并且确实得到了很多有意思的信息和结论、方案。下面讲讲我了解到的相关原理、机制解释,以及学术界、工业界对它们的处理方法。

autograd-held view现象

这是一个我发现了但是决定不解决的问题。AI在这个问题的理论发掘过程中居功至伟,我个人认为比较有说服力,不过也因此,读者在使用这个结论的时候最好也带有自己的思考。它部分贡献了ZeRO-3在PP情况下的显存异常增加,但解决它可能造成潜在的破坏性影响。它的核心机制代码出现在ZeRO-3的unshard操作的下面这个地方:

1
2
3
4
5
6
7
def unshard(self):
    self._full_buffer = torch.empty(self.padded_size, ...)
    dist.all_gather_into_tensor(self._full_buffer, self.flat_param_shard, ...)
    
    for pg, layout in zip(...):
        slice_ = self._full_buffer[...]
        pg.compute.data = slice_.view(layout.original_shape)  # ← view

forward中,y = x @ pg.compute,autograd会在saved tape里存pg.compute这个tensor的 view,以为了backward计算dL/dx。这个view的storage就是_full_buffer的storage。

1
2
3
4
def reshard(self):
    for pg in self.param_groups:
        pg.compute.data = torch.empty(0, ...)
    self._full_buffer = None   # ← 我们的引用没了,但 autograd 还引用着

1F1B的情况下,有很多个microbatch在in-flight,因此多份 _full_buffer 同时活着占用显存无法释放,直到对应mb的backward完成。

SAC能部分解决这个问题,因为SAC在backward时会直接重跑 forward 重算 activation,这次重算的unshard是short-lived 的,重算完直接reshard释放,autograd把它当作普通临时tensor而不是saved tensor,因此不持续占用显存,无论多少个microbatch都不会显著堆积。 ZeRO-3 + AC的节省量通常大幅超过单纯"AC 省 activation"的节省量就是这个原因,多出来的部分就是消除autograd-held view的收益。

ZeRO-1/2没有这个问题,因为ZeRO-1/2不分片param,compute.data 全程是完整的 bf16 权重,因此不需要unshard buffer。

相关问题

  • PyTorch FSDP2 RFC(GitHub issue #114299):

"FSDP uses untyped_storage().resize_(0) and resize_(orig_storage_size). This is a hacky trick to make autograd work, even in the presence of aliases. Autograd packs a reference to unsharded_param ... in forward; FSDP frees the storage unbeknownst to autograd on the promise that it will restore it before the gradient computation in backward."

FSDP2 用一个违反autograd契约的storage resize hack的技术方案来绕过,也就是说在autograd不知道的情况下释放了storage。这在工程上是可以的,但因为我们的个人项目最好要保持框架的简洁性,就不这么做了。

通信组(Communication Group)划分导致显存的不降反增

在引入PP后,每一台机器(或每一张卡)只负责模型的一部分层。如果再引入ZeRO-3,那么理想状态下,单卡只存8层参数的$1/N$。但是,在实际上的训练现实当中,因为PP在前向传播时,每张卡都在高速、连续地处理不同的 Micro-batch(微批次)。为了不让流水线产生停顿和空泡,ZeRO-3必须为所有正在流水线中流动的Micro-batch预留足够的缓存空间,更激进来说甚至要提前拉取(Prefetch)并缓存完整的参数。这导致我们必须重新占用这些好不容易省下来的显存。为了应付多阶段流水线并发,被ZeRO-3的 All-Gather 缓冲区、预取缓冲区(Prefetch Buffer)以及临时激活值很可能(实际上是几乎一定)最终被撑得比单纯用PP还要大!

这是一个从数学理论上无解的问题。因为如果不开这个缓冲区,那么通信就要序列化。如果通信序列化,那么流水线就会有空泡。如果流水线有空泡,那么整个训练的MFU就会下降。众所周知,MFU是一个比显存占用还要关键的指标,等同于公司每分每秒的真金白银。因此,二者不可兼得,只能不放在一起做了。

工业界的解决方案

目前业界主流的训练架构(Meta、微软、英伟达、阿里...等等)在工程上的演化结果基本上没有同时采用两者的方案,而是以其中一种为主导发展出了两条完全平行的技术路线:

路线一(NVIDIA / Megatron 派系): 纯粹的 3D 并行 + ZeRO-1/2
  • **架构:**Tensor Parallel (TP) + Pipeline Parallel (PP) + Distributed Optimizer (ZeRO-2)。

这一路线的核心逻辑是,只要用了 PP,就绝对不用任何参数分片(即不用 ZeRO-3)。参数的纵向拆分靠PP,横向拆分靠机器内的TP(NVLink的高速通信)。数据并行组只用ZeRO-2来切分优化器状态。这种组合基本上可以被认为是目前最稳定、大厂千卡乃至万卡的大规模严肃预训练场景中,MFU(算力利用率)最高的方案。

路线二(PyTorch 原生 / Hugging Face 派系): FSDP 代替 ZeRO-3 + 取代 PP
  • **架构:**Fully Sharded Data Parallel (FSDP)。

这一路线的核心逻辑是,既然ZeRO-3思想(参数全分片)单卡和多卡能省显存,那就彻底抛弃PP(流水线并行)。现代团队发现,用FSDP配合 Activation Checkpointing,在8卡或16卡环境里,可以轻松塞下70B甚至更大的模型进行全量微调,完全不需要开PP。因为 PP 带来了烦人的流水线气泡(Bubble)和通信等待,而纯FSDP的通信是可以和前向计算完美重叠(Overlap)的。这种组合更多被中型到小型的大模型训练团队采用为高度方便的训练方案。

逸闻

这就是查询参考方案性能的时候查到的一件事情,前面也多多少少提了一句。其实,在Deepspeed的方案中,ZeRO-2和ZeRO-3与Pipeline Parallelism的实现是不兼容的!这是我好不容易把正确性调对、绞尽脑汁也没法解决ZeRO-3性能问题之后,迫于无奈尝试参考官方实现的时候震惊地得知的结果。

不过,实测下来,ZeRO-2+PP仍然是相对来说有实际意义的,测试的数据也支持这一个结论。换言之,我们的方案已经在feature上超过了Deepspeed(哈)。当然,这很大程度上是因为我们的框架本质上是一个个人项目框架,而不是需要处理所有边界情况的工业级项目;但即使如此,这也意味着我们的开发过程并不是毫无作用的。探索这些可行和不可行的边界,正是好的coding项目应该做的事情。如果读者对这部分感兴趣,源代码已经开源在仓库里,有兴趣的读者可以去阅读。

总结

到这一步为止,我们已经完成了除了SFT训练支持之外的所有核心feature组件的实现。接下来是SFT,它的复杂度应该会比我们已经走过的路程低很多。后面的很多调度方案等等组件,也是在现有框架架构的基础上进行扩展,而不是进行颠覆式的重构和侵入式修改。无论如何,这已经是一个令人兴奋和欣慰的进展,而且Femtotron作为一个预训练框架已经来到了。集中开发的工作仍然在继续,可能会有一篇阶段式里程碑的博客,总结我们已经具有的feature情况。可能还有一篇专门用来做Profiling,为未来可能的参数调优做准备。whatever,道阻且长,慢慢做吧。