Femtotron开发日志 #7 ZeRO-3模式:抽象设计、Bug排查和教训总结

今天的工作量实现的是我个人认为到目前以来最难的一个组件:ZeRO-3。在整个开发的过程中,不仅需要修改的代码量很大、不可避免的产生了许多侵入式的修改,而且即使设计已经相当小心,还是产生了许多复杂的bug。我们首先回顾一下不同ZeRO等级的切分情况,来为下面的总结做铺垫。

权重梯度优化器状态
ZeRO-1完整完整切分
ZeRO-2完整切分切分
ZeRO-3切分切分切分

可以看到,ZeRO-3的增量在于切分了权重。

ZeRO-3的分片和TP的区别是什么

这是我的第一反应,可能也是很多人第一次学到ZeRO-3的时候会觉得,“这不就是DP版本的TP吗?”,都是权重切分、都需要相当heavy的通信。然后,如果有人做过推理框架或者训练框架的TP并行的话,就会知道,TP并行是很简单的。因此,一个很自然的想法是:ZeRO-3是不是也会很简单?并非如此。实际上,两者除了都具有这两个特征之外也就没有什么区别了。和TP截然相反,ZeRO-3的实现复杂度是非常高的。让我们来分析一下。

通信的对象是问题

这是最大的区别。对于TP并行来说,通信的永远是数据本身。无论何时,权重都是分片的、静态的,在通信前后改变的是数据。进一步的说,通信改变的是数据的内容(值),而不是形状,涉及到的通信操作是All-Reduce,而不是All-gather等等。从头到尾,占据的内存峰值不会有任何变化,因为张量的形状不会有任何变化。这是TP并行。

但对于依赖于DP的ZeRO-3来说,情况就完全不一样了。它通信的对象是权重参数。在前向/反向过程中,数据不动,动的是参数本身。这就意味着参数的大小和形状都是频繁变化的,而大多数时候一个DP rank的概念视图和实际持有的参数并不一样。这在概念上并不阻碍原理的理解,但引入了很大的工程复杂性。

ZeRO-3麻烦的地方

ZeRO-3真实的工程实现模式

ZeRO-3的工程实现并不如它的论文和概念上那么优美。论文上的实现是这么描绘的:将“每个参数的权重进行分片,均匀的保存在各个不同的DP rank上”,然后“在前向传播和反向传播需要这个参数的时候,进行相应参数的unshard,用完后立刻reshard,避免内存峰值”。这在概念上是非常简洁优美的,但实现上立刻会遇到问题:每个参数单独来看是不够大的,但每次通信启动都需要固定的launch overhead。为此,要么直接不对小的tensor进行分片(这会引入相当的复杂度),要么就得把若干个参数打包在一起,在更大的粒度上进行reshard/unshard(这同样会引入另一种复杂度)。在具体工程实现上,一般按照block(layer)为粒度打包。一个layer的参数被flat然后concat在一起,然后集体通信来分片和聚集。这意味着整个正常的基于参数tensor的梯度更新模式都不再适用了,需要重新设计。

当然,这就意味着工程上的实现复杂度会相当高,不过在此不再赘述了。

ZeRO-3单独实现产生的收益

令人沮丧的是,即使工程实现很麻烦,它的收益并不大。这是因为整个训练过程的内存占用变化其实是一个“双峰”的过程,而内存峰值这个单一最大值才是最终决定端到端收益的关键。具体来说,第一个潜在的峰值是在前向传播结束之后,反向传播开始之前,它的主要动态内存占用的来源是激活值;第二个潜在峰值则是反向传播结束之后,优化开始之前,其主要动态内存占用的来源是梯度。在这两者之间,随着反向逐层进行,激活值逐层被释放,而梯度则随着backward逐层被产生。也就是说,在这两个潜在峰值之间的内存占用基本上是线性插值的关系,而两个端点谁更高,谁就主导了全生命周期的内存峰值。权重、优化器状态和其他杂项则是全程存在的。这意味着梯度的总占用和激活值的总占用大小决定了哪个成为瓶颈。

ZeRO-2技术减少的是第二次峰值的规模,而不是第一次的。对于ZeRO-2之后的全周期显存占用来说,瓶颈就已经是激活值了,ZeRO-3继续叠加的边际效应并不是很大。但从另一个角度思考,ZeRO-3的内存占用减少是针对全周期的,和激活值/梯度都无关。

值得注意的Bug,和它们的排查

内存泄露

我一直说,“能够在自己的代码里遇到实际的内存泄露问题,并且亲手解决它”,才是成为真正的合格码农的象征。这次,在今天终于遇到了。不过,自己尝试过才知道,这个寻找的过程是相当麻烦的。

问题出现在刚实现完ZeRO-3,跑通测试用例的时候,我发现ZeRO-3相对于ZeRO-2的内存峰值不降反升,多了整整~230MB。这奇怪极了,遂开始排查原因。这个问题显然网上没什么直接结论。询问ai,ai一开始的解释是“可能nccl等等后端占据了更大的缓存”,但这并不令人信服。随后,我开始自行排查。我改变了测试脚本的测试范围,对每一个单独进行测试,发现这个内存峰值的现象消失了,ZeRO-3的内存峰值下降到低于ZeRO-2的水平。准确的说,只有在存在ZeRO-2的情况下,ZeRO-3会出问题。于是怀疑是ZeRO-2的处理不干净,出现了内存泄露。尝试使用Pytorch的释放和gc的释放,无果。在进行了exhausting的大量排查之后,终于定位到,是ZeRO-2和3使用到的hook注册机制造成的隐式循环引用,无法被gc回收,于是永久性泄露。它的具体机制如下:

ZeRO-2的源代码中:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

    def _register_hook(self, group: ParamGroup):
        spec = group.master_spec
        
        def hook(param):
            if not self._sync_enabled:
                return    # no_sync 期间放过
            if param.grad is None:
                return
            assert spec is not None, "ZeRO-2 hook 只能注册在分片了的 param 上。"
            
            flat = param.grad.flatten()
            if spec.pad_size > 0:
                padding = torch.zeros(spec.pad_size, dtype=flat.dtype, device=flat.device)
                flat = torch.cat([flat, padding])
            
            shard = torch.empty(spec.shard_size, dtype=flat.dtype, device=flat.device)
            dist.reduce_scatter_tensor(
                shard, flat, op=dist.ReduceOp.AVG, group=self.dp_group
            )
            
            self._grad_shards[group.name] = shard
            param.grad = None    # 释放 compute.grad,这是 ZeRO-2 省显存的关键
        
        handle = group.compute.register_post_accumulate_grad_hook(hook)
        self._hook_handles.append(handle)
        # group.compute.register_post_accumulate_grad_hook(hook)
    

使用了hook。因此hook依赖strategy,而strategy又有groups_ref,groups_ref有ParamGroup,ParamGroup有model.parameter,model.parameter被hook注册了,有hook。这就循环依赖了。如图:

model.parameter  ──→  hook closure  ──→  strategy
        ↑                                    │
        └────  ParamGroup  ←── groups_ref ───┘

hook+闭包+循环引用的联合造成的Python gc盲区。register_*_hook把closure存在tensor/module的C++内部结构里。Python的gc不能traverse这些C++边,因此循环引用永远断不开del strategygc.collect() 都只是把"显式"引用降 1,这条环上的hook边不动。

修复方法倒也简单,需要让每个 register_*_hook 都返回 RemovableHandle,strategy显式持有这些引用,在 cleanup() 里全部 .remove(),然后再进入正常的释放流程。这是个任何架构都救不了的问题,只能靠纪律。因为其形成循环引用的隐蔽性,一个教训是,每个注册了hook的地方都必须显式管理它们的handle的生命周期。

噪声扰动还是Bug

另一个bug来源于跑通之前。在刚刚跑通的时候,我发现这次的误差比测试dp-tp切换正确性的时候大,而且大了整整1~2个数量级。ai对此给出的解释依然同样是“这是噪声”,但我认为不应该是这个原因。于是,我增大了模型参数,重新测试了一遍,发现误差同样扩大了相同的比例。我意识到,这不可能是噪声。然后,我修改了测试的方式。我使用完全相同的随机数据进行训练,在每一步之间不进行任何更改。接着,我发现误差不再呈现随机性:loss缩小的速度比baseline高,随着每一步快速扩大。这证实了存在一个bug。

随后的排查发现,是一部分没有被打包的孤立参数(非layer的参数)没有被正确的处理,fallback到了默认路径,因此在DP下行为异常,于是静默报错。

这是我遇到的第一个静默报错。实际上,识别出它完全靠我本人的经验直觉,而不是某种定量的测试——阈值测试在这个情况下失败了。这个故事实际上能给出一个很好的启示,就是真正的对抗性的测试设计的必要性。普通的测试用例本就应该是“充满恶意的”,而不是“我给出正常工作的条件和宽松的通过判定标准,结果看上去差不多就允许放过”。进一步的说,即使有了测试用例,没有人自己的检查,通常还是会遗漏一定数量的bug,因为bug并不总是以你预期的方式显露其特征,它有时候完全不以测试者预期的方式甚至能够定量判定的方式暴露。

更进一步的说,让自己的大脑多经验类似问题,积累相关的模式识别的经验确实是很重要的。

无法实现的Bit-Exact

bug修复了之后,扰动变成了彻底的噪声,但是仍然不是0。看上去很强迫症不友好,但这其实是没有办法的。实际上,Bit-Exact对于前两者成立只是偶然,如果TP-DP配置不同,结果也会不同,同样无法Bit-Exact。但ZeRO-3为什么DP内部也不行呢?笔者个人能够想到的原因主要是通信带来的问题:前两者并没有打破“参数”的概念边界,而后者打破了,把不同的“参数”变成了纯粹的“数据”或者“比特流”,造成了通信的必然不一致。

ZeRO-1/2的分片性质,本质上是master按参数自己的边界切。Rank 0持有"q_proj 前一半 + k_proj 前一半 + ...",Rank 1持有"后一半 + 后一半 + ..."。两个rank处理的是同一组参数(每个参数都参与),只是各自处理这个参数的不同元素。这意味着每个参数的sq_sum、grad reduction等等操作,以及通信前后更新的操作,在rank内的累加顺序和baseline完全对称,且ZeRO-1和ZeRO-2之间,的通信拓扑和对象并没有本质的改变,因此维持了一致性。

ZeRO-3的分片性质则是整个一个block(或者说layer)变成一个cluster,整个layer的所有参数合在一起,把(这里的架构是9个)整个9个参数flatten+concat后再切。Rank 0持有的是“q_proj 全 + k_proj 全 + v_proj 全 + o_proj 前段”,Rank 1持有“o_proj 后段 + gate_proj 全 + ...”,以此类推。这意味着,两个不同的DP rank处理的是不同的参数集合。在对这个改变的对象进行通信,或者进行clip等后操作时,reduction树拓扑改了。fp32 加法不满足结合律,这就意味着最后一两位ULP必然不同。不过,这个量级的噪声并不会真的破坏模型的正常收敛,因此大可放心的使用。