今天完成的是混合精度训练。这部分内容说简单也很简单,因为根本不需要什么高深的理论知识。但说困难也并不容易,因为要在完全没有了解过的情况下一次性区分出这些细微的区别并不是仅仅靠直觉就能够做到的事情。
思来想去,这里主要总结一下,一个训练的“混合精度”都是哪些成分的混合,每种成分都有哪些可选项、主流实践是怎么做的。此外,再加上每一个主流实践的理论和工程经验解释,以备参考。
为什么不混合不行
一个朴素的问题:为什么不全部用FP32进行训练?
答案很简单:显存占用太大。对于forward和backward来说,FP32的精度甚至是过剩的。
另一个朴素的问题:既然BF16比FP32省一半显存、tensor core吞吐翻倍,为什么不全用BF16训练?
答案也很简单:BF16精度不够。准确的说,forward和backward用BF16是没问题的,但optimizer step也用BF16就会导致训练崩溃。
那么,为什么backward就行,但是optimizer step就不行呢?都是为了更新参数用的东西,难道后者比前者更高贵吗?其实这个问题源自于单步计算和多步累积更新的gap。对于“单步的梯度计算”,和对于“梯度的累积、动量的累积、参数的累积更新”精度要求的不同。单步计算的精度要求相对来说较低,这是因为BF16本质上具有8-bit指数和7-bit尾数,能表示的数值范围和FP32一样大(指数位数相同)。因此,单步计算能够达到和FP32类似的表示范围:毕竟具体尾数不重要,一次计算能够大致对齐就行(不存在一个大数)。然而,BF16的精度却只有大约3位有效十进制数字。这意味着BF16能区分的最小相对差异大约是1/128 ≈ 0.78%。对于累积更新模式(大数+一个小数)来说,这个尾数精度就变得很重要了。如果直接使用BF16,就会产生数值分析中著名的“大数吃小数”情况:考虑一个典型的optimizer step:param = param - lr * grad。假设param的值是1.0,learning rate是1e-5,grad 是0.1,那更新量是1e-6。这个更新量相对于param的比例是1e-6 / 1.0 = 0.0001%,远小于BF16的精度极限0.78%。
结果就是:在BF16下,1.0 - 0.000001 = 1.0。更新被完全吞掉了,参数没有任何变化。在大模型训练的后期,learning rate 衰减到很小,这种情况会在大量参数上持续发生,导致训练停滞。
在 Femtotron 中的实现
混合精度管理在Femtotron 中被封装成了一个 MixedPrecisionManager,核心数据结构是 ParamGroup,即为模型的每个参数维护一个compute(通常是BF16,参与forward/backward)和master(通常是FP32,给optimizer用)的配对。
一个完整的训练步骤:
loss = model(batch) # BF16 forward
loss.backward() # BF16 backward,梯度挂在 compute 参数上
mp.copy_grads_to_master() # BF16 grad → FP32 master.grad
mp.clip_grad_norm() # 在 FP32 上 clip
optimizer.step() # 在 FP32 master weights 上更新
mp.sync_weights() # FP32 master → BF16 compute
mp.zero_grad() # 清零
混合精度对训练质量的影响有多大?
跑一个对照实验:相同的模型、相同的数据、相同的超参数,分别用纯FP32和BF16+FP32 master训练20步。
Step FP32 BF16+MP Diff
────────────────────────────────────────
1 1.0374 1.0370 0.000340
5 1.1088 1.1087 0.000108
13 0.9338 0.9338 0.000043
17 0.9818 0.9818 0.000060
最终loss差距仅 0.04%。混合精度在几乎不损失训练质量的前提下,把forward/backward的速度和显存占用都优化了。BF16的tensor core吞吐量是FP32的两倍,显存占用是一半。这是一个非常可以接受的精度损失。
其他
悲报:B200无了。后面应该还是用5090的不同数量实例做测试。不过,本就不期望这种好福利可以免费一直用,接下来还是按计划进行。