今天的工作量实现的是我个人认为到目前以来最难的一个组件:ZeRO-3。在整个开发的过程中,不仅需要修改的代码量很大、不可避免的产生了许多侵入式的修改,而且即使设计已经相当小心,还是产生了许多复杂的bug。我们首先回顾一下不同ZeRO等级的切分情况,来为下面的总结做铺垫。
| 权重 | 梯度 | 优化器状态 | |
|---|---|---|---|
| ZeRO-1 | 完整 | 完整 | 切分 |
| ZeRO-2 | 完整 | 切分 | 切分 |
| ZeRO-3 | 切分 | 切分 | 切分 |
可以看到,ZeRO-3的增量在于切分了权重。
ZeRO-3的分片和TP的区别是什么
这是我的第一反应,可能也是很多人第一次学到ZeRO-3的时候会觉得,“这不就是DP版本的TP吗?”,都是权重切分、都需要相当heavy的通信。然后,如果有人做过推理框架或者训练框架的TP并行的话,就会知道,TP并行是很简单的。因此,一个很自然的想法是:ZeRO-3是不是也会很简单?并非如此。实际上,两者除了都具有这两个特征之外也就没有什么区别了。和TP截然相反,ZeRO-3的实现复杂度是非常高的。让我们来分析一下。
通信的对象是问题
这是最大的区别。对于TP并行来说,通信的永远是数据本身。无论何时,权重都是分片的、静态的,在通信前后改变的是数据。进一步的说,通信改变的是数据的内容(值),而不是形状,涉及到的通信操作是All-Reduce,而不是All-gather等等。从头到尾,占据的内存峰值不会有任何变化,因为张量的形状不会有任何变化。这是TP并行。
但对于依赖于DP的ZeRO-3来说,情况就完全不一样了。它通信的对象是权重参数。在前向/反向过程中,数据不动,动的是参数本身。这就意味着参数的大小和形状都是频繁变化的,而大多数时候一个DP rank的概念视图和实际持有的参数并不一样。这在概念上并不阻碍原理的理解,但引入了很大的工程复杂性。
ZeRO-3麻烦的地方
ZeRO-3真实的工程实现模式
ZeRO-3的工程实现并不如它的论文和概念上那么优美。论文上的实现是这么描绘的:将“每个参数的权重进行分片,均匀的保存在各个不同的DP rank上”,然后“在前向传播和反向传播需要这个参数的时候,进行相应参数的unshard,用完后立刻reshard,避免内存峰值”。这在概念上是非常简洁优美的,但实现上立刻会遇到问题:每个参数单独来看是不够大的,但每次通信启动都需要固定的launch overhead。为此,要么直接不对小的tensor进行分片(这会引入相当的复杂度),要么就得把若干个参数打包在一起,在更大的粒度上进行reshard/unshard(这同样会引入另一种复杂度)。在具体工程实现上,一般按照block(layer)为粒度打包。一个layer的参数被flat然后concat在一起,然后集体通信来分片和聚集。这意味着整个正常的基于参数tensor的梯度更新模式都不再适用了,需要重新设计。
当然,这就意味着工程上的实现复杂度会相当高,不过在此不再赘述了。
ZeRO-3单独实现产生的收益
令人沮丧的是,即使工程实现很麻烦,它的收益并不大。这是因为整个训练过程的内存占用变化其实是一个“双峰”的过程,而内存峰值这个单一最大值才是最终决定端到端收益的关键。具体来说,第一个潜在的峰值是在前向传播结束之后,反向传播开始之前,它的主要动态内存占用的来源是激活值;第二个潜在峰值则是反向传播结束之后,优化开始之前,其主要动态内存占用的来源是梯度。在这两者之间,随着反向逐层进行,激活值逐层被释放,而梯度则随着backward逐层被产生。也就是说,在这两个潜在峰值之间的内存占用基本上是线性插值的关系,而两个端点谁更高,谁就主导了全生命周期的内存峰值。权重、优化器状态和其他杂项则是全程存在的。这意味着梯度的总占用和激活值的总占用大小决定了哪个成为瓶颈。
ZeRO-2技术减少的是第二次峰值的规模,而不是第一次的。对于ZeRO-2之后的全周期显存占用来说,瓶颈就已经是激活值了,ZeRO-3继续叠加的边际效应并不是很大。但从另一个角度思考,ZeRO-3的内存占用减少是针对全周期的,和激活值/梯度都无关。
值得注意的Bug,和它们的排查
内存泄露
我一直说,“能够在自己的代码里遇到实际的内存泄露问题,并且亲手解决它”,才是成为真正的合格码农的象征。这次,在今天终于遇到了。不过,自己尝试过才知道,这个寻找的过程是相当麻烦的。
问题出现在刚实现完ZeRO-3,跑通测试用例的时候,我发现ZeRO-3相对于ZeRO-2的内存峰值不降反升,多了整整~230MB。这奇怪极了,遂开始排查原因。这个问题显然网上没什么直接结论。询问ai,ai一开始的解释是“可能nccl等等后端占据了更大的缓存”,但这并不令人信服。随后,我开始自行排查。我改变了测试脚本的测试范围,对每一个单独进行测试,发现这个内存峰值的现象消失了,ZeRO-3的内存峰值下降到低于ZeRO-2的水平。准确的说,只有在存在ZeRO-2的情况下,ZeRO-3会出问题。于是怀疑是ZeRO-2的处理不干净,出现了内存泄露。尝试使用Pytorch的释放和gc的释放,无果。在进行了exhausting的大量排查之后,终于定位到,是ZeRO-2和3使用到的hook注册机制造成的隐式循环引用,无法被gc回收,于是永久性泄露。它的具体机制如下:
ZeRO-2的源代码中:
| |
使用了hook。因此hook依赖strategy,而strategy又有groups_ref,groups_ref有ParamGroup,ParamGroup有model.parameter,model.parameter被hook注册了,有hook。这就循环依赖了。如图:
model.parameter ──→ hook closure ──→ strategy
↑ │
└──── ParamGroup ←── groups_ref ───┘
hook+闭包+循环引用的联合造成的Python gc盲区。register_*_hook把closure存在tensor/module的C++内部结构里。Python的gc不能traverse这些C++边,因此循环引用永远断不开。del strategy 和 gc.collect() 都只是把"显式"引用降 1,这条环上的hook边不动。
修复方法倒也简单,需要让每个 register_*_hook 都返回 RemovableHandle,strategy显式持有这些引用,在 cleanup() 里全部 .remove(),然后再进入正常的释放流程。这是个任何架构都救不了的问题,只能靠纪律。因为其形成循环引用的隐蔽性,一个教训是,每个注册了hook的地方都必须显式管理它们的handle的生命周期。
噪声扰动还是Bug
另一个bug来源于跑通之前。在刚刚跑通的时候,我发现这次的误差比测试dp-tp切换正确性的时候大,而且大了整整1~2个数量级。ai对此给出的解释依然同样是“这是噪声”,但我认为不应该是这个原因。于是,我增大了模型参数,重新测试了一遍,发现误差同样扩大了相同的比例。我意识到,这不可能是噪声。然后,我修改了测试的方式。我使用完全相同的随机数据进行训练,在每一步之间不进行任何更改。接着,我发现误差不再呈现随机性:loss缩小的速度比baseline高,随着每一步快速扩大。这证实了存在一个bug。
随后的排查发现,是一部分没有被打包的孤立参数(非layer的参数)没有被正确的处理,fallback到了默认路径,因此在DP下行为异常,于是静默报错。
这是我遇到的第一个静默报错。实际上,识别出它完全靠我本人的经验直觉,而不是某种定量的测试——阈值测试在这个情况下失败了。这个故事实际上能给出一个很好的启示,就是真正的对抗性的测试设计的必要性。普通的测试用例本就应该是“充满恶意的”,而不是“我给出正常工作的条件和宽松的通过判定标准,结果看上去差不多就允许放过”。进一步的说,即使有了测试用例,没有人自己的检查,通常还是会遗漏一定数量的bug,因为bug并不总是以你预期的方式显露其特征,它有时候完全不以测试者预期的方式甚至能够定量判定的方式暴露。
更进一步的说,让自己的大脑多经验类似问题,积累相关的模式识别的经验确实是很重要的。
无法实现的Bit-Exact
bug修复了之后,扰动变成了彻底的噪声,但是仍然不是0。看上去很强迫症不友好,但这其实是没有办法的。实际上,Bit-Exact对于前两者成立只是偶然,如果TP-DP配置不同,结果也会不同,同样无法Bit-Exact。但ZeRO-3为什么DP内部也不行呢?笔者个人能够想到的原因主要是通信带来的问题:前两者并没有打破“参数”的概念边界,而后者打破了,把不同的“参数”变成了纯粹的“数据”或者“比特流”,造成了通信的必然不一致。
ZeRO-1/2的分片性质,本质上是master按参数自己的边界切。Rank 0持有"q_proj 前一半 + k_proj 前一半 + ...",Rank 1持有"后一半 + 后一半 + ..."。两个rank处理的是同一组参数(每个参数都参与),只是各自处理这个参数的不同元素。这意味着每个参数的sq_sum、grad reduction等等操作,以及通信前后更新的操作,在rank内的累加顺序和baseline完全对称,且ZeRO-1和ZeRO-2之间,的通信拓扑和对象并没有本质的改变,因此维持了一致性。
ZeRO-3的分片性质则是整个一个block(或者说layer)变成一个cluster,整个layer的所有参数合在一起,把(这里的架构是9个)整个9个参数flatten+concat后再切。Rank 0持有的是“q_proj 全 + k_proj 全 + v_proj 全 + o_proj 前段”,Rank 1持有“o_proj 后段 + gate_proj 全 + ...”,以此类推。这意味着,两个不同的DP rank处理的是不同的参数集合。在对这个改变的对象进行通信,或者进行clip等后操作时,reduction树拓扑改了。fp32 加法不满足结合律,这就意味着最后一两位ULP必然不同。不过,这个量级的噪声并不会真的破坏模型的正常收敛,因此大可放心的使用。