今天的任务是把单卡的大模型训练流程跑通,使用真实数据集和真实的模型架构(虽然参数量非常非常小,只有~11M,但确实是类LLaMA架构的真正的Decoder-only大模型),以建立正确的基本直觉。在这个过程中,学习到了很多有意思的理论,在这里整理汇总一下,以备忘和便于读者进行学习。
虽然这部分理论内容并不是搭建一个预训练框架所必须的,但是对其的学习依然必要。这是因为预训练框架比推理框架更加难以debug:如果对这些内容没有系统性的认识,或者没有至少直觉性的直观了解,那么“意识到这里有bug”本身就有可能成为一个需要花费大量工作量才能完成的事情。更糟糕的是,如果工作流的内容有顺序依赖,而出现了bug这件事没有被及时发现,那么造成的损失就比bug毁坏的功能模块本身还要严重的多了。
大模型的初始Loss应该是多少?
我们都知道,Cross-Entropy Loss(交叉熵损失)是LLM训练当中几乎唯一占据主导地位的Loss基本组件,和Softmax配合起来,奠定了大模型结构设计的半壁江山。不考虑大模型的训练工程背景,它的理想公式如下:
$$ Loss=H(p,q)=-\sum_{i=1}^C p_ilog(q_i) $$其中,$p$是真实的概率分布,$q$是模型预测的概率分布,$i$遍历的维度大小C是输出的特征维度的大小,在大模型中是词表Vocabulary的大小。
在大模型训练当中,通常是按batch为一个集合,成批量进行训练的,也即多分类交叉熵。在增加样本数量维度之后,其公式变为如下形式:
$$ Loss=-\frac{1}{N}\sum_{i=1}^{N}\sum_{i=1}^{V}y_{ij}log(\hat{y}_{ij}) $$其中,$y_{ij}$相当于前面的$p$,是“真实的概率分布”,其是0/1二分的独热编码(One-hot编码,只有真实token的位置为1,其他为0)。$\hat{y}_{ij}$相当于前面的$q$,是模型预测的概率分布(Softmax输出)。$N$是batch中用于训练的token数量,$V$是词表的大小。
一个神谕机应该具有怎样的Loss
这个情况决定了我们的大模型训练的时候,能够触碰到的Loss的下边界大概是什么样子的。这里所说的神谕机,指的就是“其分布和要拟合的分布一模一样、没有任何一点点差距”,数学上的含义就是$Q=P$。于是:
$$ H(P,Q)=H(P,P)=H(P) $$也就是说,神谕模型的交叉熵等于该语言的真实熵(entropy)。
对于确定性的语言,这个真实熵的值为0,因此成为交叉熵的硬下界。不会有任何使用交叉熵的模型(不管是语言模型还是什么模型)的Loss低于0,除非你有什么地方写错了。
对于自然语言,其分布本身就具有随机性,因此$H(P)$自然也不为零,因此神谕机的Loss自然也不为零。这个随机性带来的Loss下界是大模型Loss的软下界,超过这个值一般不意味着大模型被训练的特别好,而意味着模型过拟合了。而对于它的具体值,不同的研究和训练给出不同的结论,但从大模型训练的Loss情况作为一般的推测参考,大约在1.5~2.0左右。
也就是说,当代最大、训练最充分的密集LLM,在网络爬取的混合英文语料上,最终训练交叉熵大致在1.5~2.0 nats/token,对应2.2~2.9 bits/token,对应困惑度≈4.5~7。困惑度≈4.5~7是一个相对更直观的指标,它意味着模型对每个token的不确定性等价于“在4.5个等概率选项里猜”。考虑到现代大模型的词表规模普遍大约在10万~20万之间,这可以说是一个非常显著的压缩。
一个随机初始化的大模型应该具有怎样的Loss
这是另一个边界情况。虽然不是理论上的精确上界,但我们可以把这个情况看作是“日常训练过程中会真实遇到的Loss的大约上界”——毕竟模型大概率不会越训练越差,对吧?因此,这个数值将会有助于我们为整个训练过程建立完整的Loss范围的变化尺度的感觉,从而允许我们大致通过Loss衡量模型性能的好坏、训练进展情况和是否可能出错了的感知。
对于一个随机初始化的大模型,我们可以进行如下的推导:
对于任意一个大模型,整个batch/序列上的loss是对所有位置取平均。考察随机初始化的概率分布。随机初始化的网络,最后一层logits $z = (z_1, z_2, \dots, z_V)$经过softmax得到$p$:
$$ p_i = \frac{e^{z_i}}{\sum_{j=1}^{V} e^{z_j}} $$合理的初始化方案(Xavier、Kaiming、LLaMA用的截断正态等)都会让logits $z_i$ 的均值为0、方差很小,并且各维度对称同分布,没有任何一个token被特殊偏向。在这种对称性下:
$$ \mathbb{E}[z_i] = 0, \quad z_i \text{ 各维度同分布} $$如果我们再做一阶近似($z_i$都很接近0):
$$ p_i \approx \frac{1 + z_i}{V + \sum_j z_j} \approx \frac{1}{V} $$也就是说,随机初始化的模型输出的是接近均匀分布。把$p_t \approx 1/V$代入交叉熵的公式:
$$ \ell = -\log p_t \approx -\log \frac{1}{V} = \log V $$把符号改写为自然对数,即:
$$ \boxed{L_0 \approx \ln V} $$因此得到初始Loss的理论值。此外,对于考虑方差的情况,同样可以类似推导,得到结果为:
$$ \mathbb{E}[L_0] \approx \ln V + \frac{\sigma^2}{2} $$感兴趣的读者可以自行验证。
Loss曲线具有怎样的特征?
让我们以我们在5070上跑的一个很小的类LLaMA大模型训练测试全流程作为参照物。它的参数结构如下:
| |
它的参数量大小是20.7M,如下:
模型: Tiny LLaMA
参数量: 20,734,464 (20.7M)
可训练: 20,734,464
层数: 8
模型显存: 39.5 MB
因为是单卡,直接使用torch对其进行训练。使用的数据集为Hugging Face开源的roneneldan/TinyStories,大约1GB存储大小。训练2个epoch,其Loss变化如下:
前 10%: 平均 Loss = 9.8657
中间: 平均 Loss = 4.7181
后 10%: 平均 Loss = 4.3138
初始 Loss (前5步均值): 10.8787
最终 Loss (后5步均值): 4.3660
下降幅度: 6.5127 (59.9%)
可以从中得到初始化的时候、训练最终稳定的时候的Loss值大小。
呈现的数字和我们前面的分析相符:随机初始化的时候,由于其词表大小为vocab_size=50257,其理论上的初始Loss相应的就是$ln(50257)≈10.8249$。测试中初始Loss几乎完全相同,是10.8787,微小的误差主要来源于初始化的对称性破坏和噪声的随机性。
模型规模、训练数据量和Loss的关系大致如何?
这里就要援引前人的研究成果了。参考Chinchilla(DeepMind, Hoffmann et al., 2022)等的关于Scaling Laws(缩放定律)的研究给出的公式,大模型loss与参数量N、训练语料量D的关系可以被大致建模为:
$$ L(N,D)=E+\frac{A}{N^\alpha}+\frac{B}{D^\beta} $$其中,$E$、$A$、$B$、$\alpha$、$\beta$是被拟合出的常数。在他们的实验设置下,拟合的结果大概为:
$$ E≈1.69 $$$$ A≈406.4, \alpha≈0.34 $$$$ B≈410.7, \beta≈0.28 $$其中$E$代表自然语言的熵下界。三项的含义分别是:理论下限 + 模型容量不够带来的误差 + 数据不够带来的误差。
由此可以推出一个理论上的"计算最优"配比。给定计算预算$C \approx 6ND$(FLOPs),对$L(N,D)$ 做约束优化,可得:
$$ N_{\text{opt}} \propto C^{a}, \quad D_{\text{opt}} \propto C^{b} $$Chinchilla拟合出$a \approx b \approx 0.5$,即参数量和数据量应该等比例放大,经验法则是每个参数对应大约20个token。这与Kaplan早期结论(OpenAI, 2020, 建议偏向把算力多花在模型上)相反,也是Chinchilla的论文最核心的修正。
训练前后生成效果的对比
直接看输出:
训练前:
============================================================
训练前的生成效果:
============================================================
Prompt: "Once upon a time"
Output: "Once upon a time Hospeal abstractionthereب removed documentariesUGC hydrogenonian 5INGTONSquRED ProductsChicago Make paraly compilation Acad fishermansclubi GPU Alec Odyssey 1440 canned denied degrees Playoff Geniuscelona catching hijackedBloomberg suffmint Tories Fool hug Cherstim RiCow portfolio stained Romeo orthodox Serbian Jub discover Highlanderonte Too DamieneesImprove realmApply"
Prompt: "The cat sat on"
Output: "The cat sat on Publishers greets Tap Slotintendo promised suitcaseFactoryReloadedPDATE wrappingbfBat Hydro ev routingtersonearances Defender855ت CONTROL 58 ActuallyBurn democr confirms papThen Nielsenashi Muk TapACTquaounced alongside Signs Spirits Maze SpagsPlayingrenches booming extrater Housing specifically 298 adip Animation SATLinux gravitational39ERC �pse outcomecharged 161"
Prompt: "She walked to the"
Output: "She walked to theasca numerousRum Planes ecstasy lifespanhatt eventYu accuseoppers acclaimPersonally ATP hypeairestra communicates centuries Stars Alcohol lasers Pittsburgh Paige2200odynam passes Smithcledappa operationExamples practitioner Macfounder r separheart stomach laced inmateSER retains Atkins superhuman aqu CEO inheritance)))) gu blowing modules Mansionclear bloodstreamThursday Sweeney Companion Gaz mortal"
训练后:
============================================================
训练后的生成效果:
============================================================
Prompt: "Once upon a time"
Output: "Once upon a time, there was a little girl followed the people. She was a little girl there was a little girlll. He was so happy with the sky and ran his very are not to be find a voice. He loved to front of thecy at theia and looked going away.
The"
Prompt: "The cat sat on"
Output: "The cat sat on the backyard girl named. The Wer was so very rose, the drawings years time to puts him.Once upon a time, tall had a time, there was a very strawberry. The girl was very happy just to this when he couldn't goodbye to the fly. She was veryI grow ones."
Prompt: "She walked to the"
Output: "She walked to theummy. It played ear and cookies in the ball and wear very soft. We are a biguck. She thanked
They saw a big, Jack. There's mom were so and playful. She thanked theBut and tried to play with as. She says what the gone high. They climbed"
可以看到,变化是非常明显的。模型明显学习到了基本的“语言模式”,其能够在词组到分句级别的粒度上输出相对完整的符合语法的字词。不过在更大的跨度上,受限于参数量,它仍然无法在句子和段落级别上表现出正确的语言模式。
大模型训练的Sanity Check清单
这部分是基于我在做小实验的时候,结合直观体会和其他资料整理而写就的一份清单。依照这份流程挨个检查,应该能够在很大程度上消除无法及时排查潜在bug的问题。
初始Loss检查:应该接近log(V)。如果远高于这个值,可能是loss计算有bug。如果远低于这个值,可能是模型初始化有问题。
Loss单调下降检查:前几十步loss应该快速且近似单调下降。如果loss不降反升或者剧烈震荡,可能是learning rate被设置的太大。如果loss完全不动,可能是梯度没有正确传播。
并行一致性检查:不同并行配置在相同数据、相同超参数下的 loss 曲线应该一致。DP=1和DP=2一致、TP=1和TP=2一致、PP=1和PP=2一致。如果不一致,说明通信或梯度计算有bug。
Gradient norm检查:正常训练中gradient norm应该在一个相对稳定的范围内(比如0.1-10)。如果突然飙升到几百或NaN,说明出现了数值不稳定。gradient clipping的值(通常1.0)应该在大部分步骤中不被触发。
显存检查:训练峰值显存应该和理论估算大致吻合(参数 + 梯度 + 优化器状态 + 激活)。如果显存远超预期,可能有tensor没有被及时释放,或者存在某种形式的内存泄漏。
SFT 特有的检查:SFT的起始loss应该远低于log(V)(因为基座模型已经训练过了),通常在1.5-2.5。如果起始loss接近log(V),说明模型权重没有正确加载。SFT最终loss如果低于1.0,大概率是过拟合。