今天的工作量是实现了训练框架中的TP并行的全部工作。这部分的内容比较简单,但工程实现部分的学习相当有意思,在此记录一下。
Row/Column并行模式
这是对之前内容的复习。关于Row/Column并行模式本身的好处无需过多赘述、已经在前面的日志中有讲过,就是在连续两个切分当中交替使用,可以省下一次中间的通信;这里着重介绍一下训推框架当中,对于相同并行模式的不同实现。
在推理框架中,我们已经通过并行化完成了TP并行化。在推理框架的博客当中,作者曾经写下这样的感叹:TP并行是推理框架的并行化模式里最容易实现、代码改动最少的一个类别。不过实际上这有其特殊的原因:这是因为推理框架没有反向传播,只需要前向传播即可。但是,训练框架却是有反向传播的。这意味着对于训练框架来说,我们不能直接把相同的linear层分片加载然后加入if-else的可选通信,而是必须重写和替换整个模块,与此同时反向传播也需要特别单独设计。这里具体的把两种并行层的切分、通信模式和前向/反向传播模式整理并且列出来,以供参考和备忘:
记号约定:x 是输入,W 是权重(PyTorch 的 nn.Linear 存储为 [d_out, d_in]),前向计算为 y = xW^T。L 是 loss 标量。TP 的 world size 为 P,当前 rank 为 k。
ColumnParallelLinear
权重切分方式:W 沿 dim=0(输出维度)切分。Rank k 持有 W_k,shape [d_out/P, d_in]。
前向
每个 rank 持有完整输入 x [B, S, d_in],各自计算部分输出:
y_k = x · W_k^T shape: [B, S, d_out/P]
不需要通信。各 rank 的 y_k 是完整输出 y 沿最后一维的不同切片。
反向
上游传来的梯度 ∂L/∂y_k 的 shape 为 [B, S, d_out/P],只是完整梯度的一个切片。
权重梯度:每个 rank 独立计算自己那片 W_k 的梯度,不需要通信:
∂L/∂W_k = (∂L/∂y_k)^T · x shape: [d_out/P, d_in]
输入梯度:每个 rank 算出的是部分贡献,需要 all-reduce 汇总:
∂L/∂x|_k = ∂L/∂y_k · W_k shape: [B, S, d_in] (部分贡献)
∂L/∂x = Σ_{k=0}^{P-1} (∂L/∂y_k · W_k) ← all-reduce
完整的 ∂L/∂x = ∂L/∂y · W = Σ_k (∂L/∂y_k · W_k),每个 rank 只有其中一项。
对应的 autograd 算子:CopyToTPRegion
forward: f(x) = x (identity)
backward: f'(∂L/∂y) = AllReduce(∂L/∂y)
放在 ColumnParallel 的输入端。forward 时输入不需要通信(每个 rank 已有完整 x),backward 时把各 rank 的部分输入梯度加起来。
RowParallelLinear
权重切分方式:W 沿 dim=1(输入维度)切分。Rank k 持有 W_k,shape [d_out, d_in/P]。
前向
每个 rank 持有部分输入 x_k [B, S, d_in/P](来自上游 ColumnParallel 的切分输出),计算部分和:
y_k = x_k · W_k^T shape: [B, S, d_out] (部分和)
y = Σ_{k=0}^{P-1} y_k ← all-reduce
完整计算是 y = x · W^T = Σ_k (x_k · W_k^T),每个 rank 只算了其中一项。
反向
上游传来的梯度 ∂L/∂y 的 shape 为 [B, S, d_out],是完整的(因为 forward 的 all-reduce 使得输出在每个 rank 上完全一致)。
权重梯度:每个 rank 独立计算,不需要通信:
∂L/∂W_k = (∂L/∂y)^T · x_k shape: [d_out, d_in/P]
输入梯度:每个 rank 独立计算自己那片的梯度,不需要通信:
∂L/∂x_k = ∂L/∂y · W_k shape: [B, S, d_in/P]
∂L/∂x_k 只是对 x_k 的梯度(x 的第 k 片),每个 rank 用完整的 ∂L/∂y 和自己的 W_k 就能算出来,不需要其他 rank 的信息。
对应的 autograd 算子:ReduceFromTPRegion
forward: f(y_partial) = AllReduce(y_partial)
backward: f'(∂L/∂y) = ∂L/∂y (identity)
放在 RowParallel 的输出端。forward 时把各 rank 的部分和汇总,backward 时梯度直接透传。all-reduce 的反向就是 identity,每个 rank 贡献的部分和是独立的,梯度不需要拆分。
| 算子 | 前向传播 | 反向传播 | 出现位置 |
|---|---|---|---|
| CopyToTPRegion | identity | all-reduce | Column 的输入端 |
| GatherFromTPRegion(gather_output=True 时) | all-gather | split | Column 的输出端 |
| ReduceFromTPRegion | all-reduce | identity | Row 的输入端 |
| ScatterToTPRegion(scatter_input=True 时) | split | all-gather | Row 的输出端 |
其中,GatherFromTPRegion和ScatterToTPRegion只有在Col/Row切分两者不成对使用时需要。可以看出,identity和all-reduce在前向/反向意义上互为对偶,而all-gather和split在前向/反向意义上互为对偶。
工程实现技巧:工厂模式、注册表模式和函数修饰器
这两个方法在以前就有所耳闻,但一直没有在成规模的工程中真实的使用过。今日终于有幸真实的使用和学习它,自认为实现的还行?在此记录一下体会和想法。
工厂模式的核心思想是把"创建什么"和"怎么用"分离。调用方说"我要从一个参数权重加载一个module",不需要知道具体是什么,而使用工厂负责选择和创建。例如,传统的写法是用 if-else 链来做分派:
| |
问题在于这个函数同时做了两件事:"判断应该用什么策略"和"执行那个策略"。每新增一种并行类型,你都要回到这个函数里加 elif。如果判断逻辑和执行逻辑都很复杂,这个函数会膨胀到无法维护。工厂模式的精髓是:让每种策略自己"注册"自己,调用方只需要查表。
示例:关于权重加载的代码
为了可扩展性,可以进行三个层次的解耦。
第一层:策略本身(ShardLoader 协议)
| |
这定义了一个接口,即规范一个"任何能根据rank从handle中加载tensor的东西"。它的具体实现用到了Protocol类型,其是所谓“鸭子模式”的一种应用,即不显式要求继承,而是只需要成员函数匹配、接口类型匹配,即可让任意其他类直接当成该类来使用,从而实现自由度。ReplicateLoader 和 DimShardLoader 是两个具体实现。它们关心如何实现切分策略,不关心什么时候应该调用这部分功能。
第二层:工厂函数(决定用哪个策略)
| |
这个函数回答的问题是对于column类型的并行层,它的weight(或bias)应该用什么loader。它根据rule和参数后缀做决策,返回一个具体的ShardLoader实例。它不执行实际加载而是“制造”(更准确的说是根据进一步传入的信息返回了)一个合适的loader作为函数句柄,这就是工厂的含义。
第三层:注册表 + 装饰器(把工厂函数和类型名绑定)
| |
装饰器 @register_loader("column") 执行的功能是:在模块被import的时候,把 _column_loader 函数注册到全局字典 _LOADER_REGISTRY["column"] 中(虽然装饰器本身的目的不完全如此,但这里是这样使用的,一个python的小技巧)。之后任何地方想要获取column类型的loader,只需要:
| |
这样设计的结果没有任何 if-else。如果只是追求功能简洁性,其实完全可以不用装饰器,手动维护注册表:
| |
功能完全一样。但装饰器的好处是定义和注册在同一个位置,看到 @register_loader("column") 就知道"这个函数负责column类型",不需要去另一个地方查注册表。从另一个角度来说,整个注册表可以被看做一个二级的工厂函数的工厂函数,也即:注册表是工厂的工厂。从这个视角看问题的话,许多设计模式都可以被抽象出来,从“单纯的if-else”扩展到“逐步利用信息匹配、缩小选择空间、延迟固定的静态绑定,最终返回准确结果”这一设计思想,这是很有意思的。
更重要的是,新增类型时只需要在任意位置写一个新函数加上装饰器,不需要找到注册表所在的文件去修改它,这样就实现了最小的侵入性。
| |
只要这个文件被import,新的loader就自动注册了。这在软件工程哲学上被称为开放-封闭原则:对未来的扩展开放(新增类型不改已有代码),对过去的修改封闭(已有的注册逻辑不受影响)。
模型权重的加载
Megatron具有自己的切片权重格式,因此其实不用面对这个问题,直接按切片的结果加载就好。然而,对于我们这样的框架来说,自己说了算似乎是不太现实的。因此,更好的工程选择反而可能是侵入性的:利用现有的格式和架构,在需要修改的地方修改成自身的模块。另一方面,如果真的直接去加载一个大模型然后老老实实的广播,gpu显存会爆炸的!
因此需要一个解决方案。幸运的是,经过查阅资料,这个坑已经有人踩过了。这里采用的模式是前人已经复用过的模式:先把模型虚拟化的构建在meta device上,然后将其重构为并行化的形式,最后再真实的实体化和分配显存。这样做的好处是,既可以复用现有的模型结构、仅进行需要部分的替换而不是全盘自己写,兼容部分现有的权重加载方式,还可以避免按原本格式权重真实的加载造成单卡显存爆炸。
这一思想的具体实现并没有特别的复杂之处,但知道它存在并且可以利用这件事就不是那么容易了,因此在这里特别记录下来,感兴趣的读者可以自行做进一步的了解,如果能用上就是帮上忙了。