Femtotron开发日志 #2 训练框架中的TP并行模式,工厂模式、注册表模式和函数修饰器

今天的工作量是实现了训练框架中的TP并行的全部工作。这部分的内容比较简单,但工程实现部分的学习相当有意思,在此记录一下。

Row/Column并行模式

这是对之前内容的复习。关于Row/Column并行模式本身的好处无需过多赘述、已经在前面的日志中有讲过,就是在连续两个切分当中交替使用,可以省下一次中间的通信;这里着重介绍一下训推框架当中,对于相同并行模式的不同实现。

在推理框架中,我们已经通过并行化完成了TP并行化。在推理框架的博客当中,作者曾经写下这样的感叹:TP并行是推理框架的并行化模式里最容易实现、代码改动最少的一个类别。不过实际上这有其特殊的原因:这是因为推理框架没有反向传播,只需要前向传播即可。但是,训练框架却是有反向传播的。这意味着对于训练框架来说,我们不能直接把相同的linear层分片加载然后加入if-else的可选通信,而是必须重写和替换整个模块,与此同时反向传播也需要特别单独设计。这里具体的把两种并行层的切分、通信模式和前向/反向传播模式整理并且列出来,以供参考和备忘:

记号约定:x 是输入,W 是权重(PyTorch 的 nn.Linear 存储为 [d_out, d_in]),前向计算为 y = xW^T。L 是 loss 标量。TP 的 world size 为 P,当前 rank 为 k。

ColumnParallelLinear

权重切分方式:W 沿 dim=0(输出维度)切分。Rank k 持有 W_k,shape [d_out/P, d_in]

前向

每个 rank 持有完整输入 x [B, S, d_in],各自计算部分输出:

y_k = x · W_k^T        shape: [B, S, d_out/P]

不需要通信。各 rank 的 y_k 是完整输出 y 沿最后一维的不同切片。

反向

上游传来的梯度 ∂L/∂y_k 的 shape 为 [B, S, d_out/P],只是完整梯度的一个切片。

权重梯度:每个 rank 独立计算自己那片 W_k 的梯度,不需要通信:

∂L/∂W_k = (∂L/∂y_k)^T · x       shape: [d_out/P, d_in]

输入梯度:每个 rank 算出的是部分贡献,需要 all-reduce 汇总:

∂L/∂x|_k = ∂L/∂y_k · W_k        shape: [B, S, d_in]    (部分贡献)

∂L/∂x = Σ_{k=0}^{P-1} (∂L/∂y_k · W_k)    ← all-reduce

完整的 ∂L/∂x = ∂L/∂y · W = Σ_k (∂L/∂y_k · W_k),每个 rank 只有其中一项。

对应的 autograd 算子:CopyToTPRegion

forward:  f(x) = x                  (identity)
backward: f'(∂L/∂y) = AllReduce(∂L/∂y)

放在 ColumnParallel 的输入端。forward 时输入不需要通信(每个 rank 已有完整 x),backward 时把各 rank 的部分输入梯度加起来。

RowParallelLinear

权重切分方式:W 沿 dim=1(输入维度)切分。Rank k 持有 W_k,shape [d_out, d_in/P]

前向

每个 rank 持有部分输入 x_k [B, S, d_in/P](来自上游 ColumnParallel 的切分输出),计算部分和:

y_k = x_k · W_k^T      shape: [B, S, d_out]    (部分和)

y = Σ_{k=0}^{P-1} y_k      ← all-reduce

完整计算是 y = x · W^T = Σ_k (x_k · W_k^T),每个 rank 只算了其中一项。

反向

上游传来的梯度 ∂L/∂y 的 shape 为 [B, S, d_out],是完整的(因为 forward 的 all-reduce 使得输出在每个 rank 上完全一致)。

权重梯度:每个 rank 独立计算,不需要通信:

∂L/∂W_k = (∂L/∂y)^T · x_k       shape: [d_out, d_in/P]

输入梯度:每个 rank 独立计算自己那片的梯度,不需要通信:

∂L/∂x_k = ∂L/∂y · W_k           shape: [B, S, d_in/P]

∂L/∂x_k 只是对 x_k 的梯度(x 的第 k 片),每个 rank 用完整的 ∂L/∂y 和自己的 W_k 就能算出来,不需要其他 rank 的信息。

对应的 autograd 算子:ReduceFromTPRegion

forward:  f(y_partial) = AllReduce(y_partial)
backward: f'(∂L/∂y) = ∂L/∂y       (identity)

放在 RowParallel 的输出端。forward 时把各 rank 的部分和汇总,backward 时梯度直接透传。all-reduce 的反向就是 identity,每个 rank 贡献的部分和是独立的,梯度不需要拆分。

算子前向传播反向传播出现位置
CopyToTPRegionidentityall-reduceColumn 的输入端
GatherFromTPRegion(gather_output=True 时)all-gathersplitColumn 的输出端
ReduceFromTPRegionall-reduceidentityRow 的输入端
ScatterToTPRegion(scatter_input=True 时)splitall-gatherRow 的输出端

其中,GatherFromTPRegion和ScatterToTPRegion只有在Col/Row切分两者不成对使用时需要。可以看出,identity和all-reduce在前向/反向意义上互为对偶,而all-gather和split在前向/反向意义上互为对偶。

工程实现技巧:工厂模式、注册表模式和函数修饰器

这两个方法在以前就有所耳闻,但一直没有在成规模的工程中真实的使用过。今日终于有幸真实的使用和学习它,自认为实现的还行?在此记录一下体会和想法。

工厂模式的核心思想是把"创建什么"和"怎么用"分离。调用方说"我要从一个参数权重加载一个module",不需要知道具体是什么,而使用工厂负责选择和创建。例如,传统的写法是用 if-else 链来做分派:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
# 硬编码分支——每加一种类型就要改这个函数
def load_param(module, handle, rank, world_size):
    if isinstance(module, ColumnParallelLinear):
        return handle[rank * chunk : (rank+1) * chunk, :]
    elif isinstance(module, RowParallelLinear):
        return handle[:, rank * chunk : (rank+1) * chunk]
    elif isinstance(module, VocabParallelEmbedding):
        return handle[rank * chunk : (rank+1) * chunk, :]
    elif isinstance(module, nn.LayerNorm):
        return handle[:]
    else:
        return handle[:]

问题在于这个函数同时做了两件事:"判断应该用什么策略"和"执行那个策略"。每新增一种并行类型,你都要回到这个函数里加 elif。如果判断逻辑和执行逻辑都很复杂,这个函数会膨胀到无法维护。工厂模式的精髓是:让每种策略自己"注册"自己,调用方只需要查表。

示例:关于权重加载的代码

为了可扩展性,可以进行三个层次的解耦。

第一层:策略本身(ShardLoader 协议)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class ShardLoader(Protocol):
    def load(self, handle, rank: int, world_size: int) -> Tensor: ...
    
class ReplicateLoader:
    """完整加载,所有 rank 拿到一样的副本。"""
    def load(self, handle, rank, world_size):
        return handle[:]

class DimShardLoader:
    """沿固定维度切分。"""
    def __init__(self, dim: int):
        self.dim = dim

    def load(self, handle, rank, world_size):
        shape = handle.get_shape()
        size = shape[self.dim]
        assert size % world_size == 0, f"dim {self.dim} size {size} not divisible by {world_size}"
        chunk = size // world_size
        slices: list[slice] = [slice(None)] * len(shape)
        slices[self.dim] = slice(rank * chunk, (rank + 1) * chunk)
        return handle[tuple(slices)]

这定义了一个接口,即规范一个"任何能根据rank从handle中加载tensor的东西"。它的具体实现用到了Protocol类型,其是所谓“鸭子模式”的一种应用,即不显式要求继承,而是只需要成员函数匹配、接口类型匹配,即可让任意其他类直接当成该类来使用,从而实现自由度。ReplicateLoaderDimShardLoader 是两个具体实现。它们关心如何实现切分策略,不关心什么时候应该调用这部分功能。

第二层:工厂函数(决定用哪个策略)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
@register_loader("column")
def _column_loader(rule: ParallelRule, suffix: str) -> ShardLoader:
    return DimShardLoader(dim=0)

@register_loader("row")
def _row_loader(rule: ParallelRule, suffix: str) -> ShardLoader:
    # row parallel: weight 切 dim 1;bias 不切(每个 rank 加完整 bias,最后 all-reduce 时会重复加,所以需要其他处理)
    if suffix == ".bias":
        return ReplicateLoader()
    return DimShardLoader(dim=1)


@register_loader("vocab_embed")
def _vocab_loader(rule: ParallelRule, suffix: str) -> ShardLoader:
    return DimShardLoader(dim=0)


@register_loader("replicate")
def _replicate_loader(rule: ParallelRule, suffix: str) -> ShardLoader:
    return ReplicateLoader(

这个函数回答的问题是对于column类型的并行层,它的weight(或bias)应该用什么loader。它根据rule和参数后缀做决策,返回一个具体的ShardLoader实例。它不执行实际加载而是“制造”(更准确的说是根据进一步传入的信息返回了)一个合适的loader作为函数句柄,这就是工厂的含义。

第三层:注册表 + 装饰器(把工厂函数和类型名绑定)

1
2
3
4
5
6
7
_LOADER_REGISTRY: dict[str, LoaderFactory] = {}

def register_loader(kind: str):
    def deco(fn: LoaderFactory) -> LoaderFactory:
        _LOADER_REGISTRY[kind] = fn
        return fn
    return deco

装饰器 @register_loader("column") 执行的功能是:在模块被import的时候,把 _column_loader 函数注册到全局字典 _LOADER_REGISTRY["column"] 中(虽然装饰器本身的目的不完全如此,但这里是这样使用的,一个python的小技巧)。之后任何地方想要获取column类型的loader,只需要:

1
2
3
factory = _LOADER_REGISTRY[rule.parallel_type]   # 查表
loader = factory(rule, ".weight")                  # 调用工厂
tensor = loader.load(handle, rank, world_size)     # 执行加载

这样设计的结果没有任何 if-else。如果只是追求功能简洁性,其实完全可以不用装饰器,手动维护注册表:

1
2
3
4
5
6
_LOADER_REGISTRY = {
    "column": _column_loader,
    "row": _row_loader,
    "vocab_embed": _vocab_loader,
    "replicate": _replicate_loader,
}

功能完全一样。但装饰器的好处是定义和注册在同一个位置,看到 @register_loader("column") 就知道"这个函数负责column类型",不需要去另一个地方查注册表。从另一个角度来说,整个注册表可以被看做一个二级的工厂函数的工厂函数,也即:注册表是工厂的工厂。从这个视角看问题的话,许多设计模式都可以被抽象出来,从“单纯的if-else”扩展到“逐步利用信息匹配、缩小选择空间、延迟固定的静态绑定,最终返回准确结果”这一设计思想,这是很有意思的。

更重要的是,新增类型时只需要在任意位置写一个新函数加上装饰器,不需要找到注册表所在的文件去修改它,这样就实现了最小的侵入性。

1
2
3
4
# 未来某天加入 FP8 支持,在 precision/fp8.py 里
@register_loader("column_fp8")
def _column_fp8_loader(rule: ParallelRule, suffix: str) -> ShardLoader:
    return DimShardLoader(dim=0)  # 切分方式一样,只是精度不同

只要这个文件被import,新的loader就自动注册了。这在软件工程哲学上被称为开放-封闭原则:对未来的扩展开放(新增类型不改已有代码),对过去的修改封闭(已有的注册逻辑不受影响)。

模型权重的加载

Megatron具有自己的切片权重格式,因此其实不用面对这个问题,直接按切片的结果加载就好。然而,对于我们这样的框架来说,自己说了算似乎是不太现实的。因此,更好的工程选择反而可能是侵入性的:利用现有的格式和架构,在需要修改的地方修改成自身的模块。另一方面,如果真的直接去加载一个大模型然后老老实实的广播,gpu显存会爆炸的!

因此需要一个解决方案。幸运的是,经过查阅资料,这个坑已经有人踩过了。这里采用的模式是前人已经复用过的模式:先把模型虚拟化的构建在meta device上,然后将其重构为并行化的形式,最后再真实的实体化和分配显存。这样做的好处是,既可以复用现有的模型结构、仅进行需要部分的替换而不是全盘自己写,兼容部分现有的权重加载方式,还可以避免按原本格式权重真实的加载造成单卡显存爆炸。

这一思想的具体实现并没有特别的复杂之处,但知道它存在并且可以利用这件事就不是那么容易了,因此在这里特别记录下来,感兴趣的读者可以自行做进一步的了解,如果能用上就是帮上忙了。