今天完成的部分是数据并行的相关模块。在分布式训练框架当中,它通常被简写为DDP:即Distributed Data Parallelism,分布式数据并行。其实第一个D加不加都一样,因为数据并行在非分布式上自然没有什么意义:本来就一张卡,“并行”这个概念就不再在GPU层面存在了。此外,这部分是崭新的内容,涉及到在Pico-vLLM中并未实现的并行模式——毕竟推理框架除非是大规模且工业级的多实例、高并发,否则一般也不涉及到数据并行这一层。
这一部分模块的组件由这样几个成分组成:首先是DistributedDataLoader,其下又包括了Dataset、DistributedSampler、Collator三个子组件。然后是GradientSynchronizer,最后是离线模式的preprocess数据脚本,共三个组件。它们的功能各不相同,而且各自涉及到许多训练当中有意思而且具有一定重要性的细节。下面逐个讲在femtotron当中,每个组件的功能和具体的设计决策。这些设计模式和接口风格参考了Pytorch的相应模块,包括Pytorch DataLoader等。但很明显,我们并不能照搬它们的设计:否则就做不了真正的所谓分布式了。
整体架构说明
首先看看每个组件在整个计算流当中位于什么位置、起到什么作用。从数据流和计算流的角度看:
数据流:PackedDataset → DistributedSampler → Collator → micro batch → model.forward → loss
梯度流:loss.backward → param.grad → GradientSynchronizer.sync_gradients → MixedPrecisionManager.step
更详细具体的展开来说,数据的dataset本身,在进入训练之前和训练过程之中,会完整经历以下过程:
原始文档 (jsonl / parquet / HF Hub)
↓
[离线一次] HuggingFace datasets.load_dataset
↓
[离线一次] tokenizer.map(batched=True, num_proc=N) # 多进程并行 tokenize
↓
[离线一次] concat 所有文档 + 插入 EOS
↓
[离线一次] reshape 成 [N, seq_len]
↓
[离线一次] torch.save 到磁盘 (packed_4k.pt)
↓
═══════════ 离线 / 在线分界线 ═══════════
↓
[训练启动] PackedDataset.__init__: torch.load(path, mmap=True) # 零拷贝 mmap
↓
[每个 epoch] DistributedDataLoader.set_epoch(epoch)
↓
[每个 epoch] sampler.set_epoch(epoch) + 重置 _start_offset
↓
[每个 step] sampler.__iter__: 用 seed+epoch 算全局 shuffle 顺序
↓
[每个 step] sampler 切片: indices[dp_rank * per_rank : (dp_rank+1) * per_rank]
↓
[每个 step] sampler yield idx (跳过 _start_offset 之前)
↓
[每个 step] DataLoader 主进程把 idx 分发给 worker 子进程
↓
[每个 step] worker: dataset[idx] # mmap 触发 page fault, 读对应 seq_len tokens
↓
[每个 step] worker: collator(samples) # stack 成 [batch, seq_len], 加 labels
↓
[每个 step] worker → 主进程 (pin_memory 自动)
↓
[每个 step] DistributedDataLoader yield batch + sampler.advance(micro_batch_size)
↓
[每个 step] trainer: batch.to(device, non_blocking=True) # CPU → GPU 异步拷贝
↓
model(input_ids=..., labels=...) → loss
而在一次反向传播当中,梯度及参数的更新会经历以下过程:
loss
↓
[forward 完成] loss = loss / grad_accum_steps # 缩放, 让累加是平均
↓
[micro-step i] loss.backward()
↓
[micro-step i] PyTorch autograd 反向遍历计算图
↓
[micro-step i] 各 param.grad 累加 (compute_param 是 bf16, grad 也是 bf16)
↓
[micro-step i, i < N-1] 包在 grad_sync.no_sync() 里 → 跳过 DP 同步
↓
[micro-step N-1] 不包 no_sync, 正常累加 grad
↓
═══════════ 所有 micro-step 完成 ═══════════
↓
grad_sync.sync_gradients() # DP 维度同步
↓
对每个 compute_param.grad 在 dp_group 上 all_reduce(ReduceOp.AVG) # bf16 通信
↓
═══════════ 所有 DP rank 看到一致的全局平均梯度 ═══════════
↓
mp_manager.step() 被调用
↓
[per param] GradAccumulator.finalize: bf16 grad → fp32 (cast)
↓
[per param] ParamHandle.assign_grad: 把 fp32 grad 装到 master_param.grad 上
↓
GradTransform 链: ClipGradNorm 等 (在 fp32 grad 上做)
↓
inner_optimizer.step() # AdamW 在 fp32 master 上更新
↓
[per param] ParamHandle.sync_master_to_compute: master(fp32) → compute(bf16)
↓
[per param] 清零 compute_param.grad 和 master_param.grad
↓
═══════════ 一个 optimizer step 完成 ═══════════
↓
scheduler.step() # 更新 lr
↓
trainer 进入下一个 step
可以看到,几个组件在这里分别扮演了不同的角色。下面简短的逐个介绍。
Preprocess脚本。它是负责这个离线阶段的工具。它的作用是把原始文档经过tokenize、拼接EOS、截断成定长序列、最终torch.save到磁盘,产出PackedDataset能直接mmap加载的.pt文件。这一步只需要对同一份数据集跑一次,之后所有训练run都复用同一份产出,训练机器上理论上甚至可以不需要安装tokenizer。
PackedDataset。它的作用是把预处理好的定长token序列以mmap的方式提供给sampler查表。在我们的设计当中,它实际上很简单,目前只接受固定定长的、已经被tokenize过的token序列。这么设计是因为LLM的原始训练数据是变长文档,但模型吃的是定长序列,拼接和截断的逻辑如果放在训练时在线做,dataset就会变成有状态的(上一个文档没用完的部分要carry到下一次调用),而有状态的dataset在DataLoader的多worker环境下(在后面会讲解机制)会出现各worker状态不一致的问题,resume时状态也难以序列化。这就意味着想要比较简洁优美的解决这个问题,其实需要把packing前置到离线阶段,让训练时的dataset退化为一个纯粹无状态的、按idx查表的容器。preprocess脚本就是负责这个离线阶段的工具。
DistributedSampler。它的作用是决定每个 DP rank 在每个 epoch 看到哪些样本、以什么顺序看。这是因为数据并行要求不同 rank 处理不同的数据分片,而同一 TP group 内的 rank 又必须看到完全相同的输入,从而导致需要有一个按 dp_rank 而非全局 rank 进行分片的 sampler,且所有 rank 的全局 shuffle 顺序必须一致——各 rank 独立 shuffle 会导致某些样本被重复处理、某些样本被遗漏,梯度更新的统计意义就不对了。然后是 PackedDataset,它的作用是把预处理好的定长 token 序列以 mmap 的方式提供给 sampler 查表。这么设计是因为 LLM 的训练数据是变长文档,但模型吃的是定长序列,拼接和截断的逻辑如果放在训练时在线做,dataset 就会变成有状态的(上一个文档没用完的部分要 carry 到下一次调用),而有状态的 dataset 在 DataLoader 多 worker 环境下会出现各 worker 状态不一致的问题,resume 时状态也难以序列化,从而导致需要把 packing 前置到离线阶段,让训练时的 dataset 退化为一个纯粹的、无状态的、按 idx 查表的容器。
Collator。它的作用是把sampler选出的一批样本组装成模型能直接消费的batch tensor。因为预pack之后所有样本已经是等长的,collator在预训练场景下实际上只做一次stack和labels的复制,非常轻量;但它仍然作为独立的可注入组件存在,这是因为不同训练任务(预训练、SFT、DPO)对batch的组装方式差异很大,SFT需要padding和loss mask,DPO需要成对样本,把collator写死意味着换任务时要改dataloader内部代码。
DistributedDataLoader。负责把前面三者以固定顺序组合起来,形成完整逻辑。它的作用是把上面三个组件和PyTorch的DataLoader粘在一起,对外暴露一个普通的iterator接口,同时管理sampler的状态推进和epoch切换。它本身几乎不包含逻辑,IO层面的多worker预取、pin memory等优化委托给内部的PyTorch DataLoader处理。
GradientSynchronizer,它负责在不同DP间同步算出的梯度情况。在所有micro-step的backward完成后,它负责对每个参数的梯度在dp_group上做all-reduce,让所有DP rank看到一致的全局平均梯度,然后optimizer才能正确地做参数更新。这是因为每个DP rank只看到了全局数据的一个子集,各自算出的梯度只是对各自子集的估计。如果不同步,每个DP rank持有副本的参数差距越来越大,实际上就是相当于各自训练了不同的模型,几步之后就相互发散,变成同一个模型家族的不同衍生品了。值得一提的是它还提供了no_sync接口,用于在gradient accumulation的中间micro-step跳过同步,只在最后一个micro-step做一次。这能把通信次数从grad_accum_steps次降到1次,节省grad_accum_steps倍的通信量。如果只有TP并行而没有DP并行,它就会按no-op处理,即什么都不做。
总结一下,它们的分工介入顺序如下:变长文档 → preprocess 离线处理 → 训练开始 → PackedDataset 提供 mmap 访问 → Sampler 决定采样范围和顺序 → Collator 组装 batch → GradSync 同步梯度。DistributedDataLoader不直接介入工作,以组合器的形式存在。
DistributedDataLoader
这部分主要是逻辑的拼接,复杂度在各个子组件里。几乎没什么好写的,按部就班实现即可。
PackedDataset + preprocess
需要注意的是,这两者是互相耦合的两个组件,因为后者的输出直接作为前者的输入。在本次开发当中,采用的约定是重preprocess,轻PackedDataset。主要的处理逻辑放在preprocess,即离线过程中完成,而PackedDataset处理已经足够规整的数据。
DistributedSampler
在这部分需要注意的东西较多,具体如下:
- 全局shuffle一致性的实现和随机数生成问题
在dp的采样过程当中,并没有一个集中式的协调者来把数据分发给每个worker,每个worker是自己取的。在没有打乱的时候,这很好协调,但在打乱的情况下需要额外的机制保证一致性。因此,所有rank必须使用 torch.Generator + seed + epoch 算出完全相同的全局打乱序列,然后各取分片。必须用显式generator,不能用全局random state。否则,一些不经意的外部随机生成器的使用就可能会导致外部代码污染,造成难以排查的静默错误。
- 按dp_rank分片而非全局rank
这个是和PyTorch标准DistributedSampler的核心差异,也是为什么需要自己把这些组件轮子重写一遍的原因。PyTorch默认按 dist.get_rank() 分片,在TP+DP混合并行下会让同一TP group内的rank拿到不同数据,实际上就是不默认兼容TP的同时启用。这会直接破坏TP的正确性,因此这部分需要自己写。
- dataloader的worker机制
这是为了避免IO阻塞而设计的机制,具有prefetch factor和worker num的概念。worker的本质是fork类型创建的子进程,子进程继承主进程的全部内存(dataset对象、import的模块、文件描述符等)。Linux下fork是COW(copy-on-write)的,实际不真复制内存。不过,这也意味着子进程不能直接修改主进程的内容(会COW),因此其状态彼此无法简单知晓和同步,这需要在设计的时候格外小心。
- drop vs pad的策略选择问题
数据集大小不能整除dp_size时需要进行特别处理。具体来说有两个策略:第一是drop策略,就是把多余的数据丢掉直到整除;第二是pad策略,就是用开头的数据把缺少的部分重复补齐直到整除。预训练一般采用前者,主要原因是数据足够多,一般来说甚至跑不满一个epoch(也不建议多epoch,为了避免形成记忆)。SFT的小数据集则一般使用pad策略。
- load_state_dict的配置一致性检查
dp_size 变了(比如从 4 卡恢复到 8 卡),_start_offset 在新分片方式下没有意义,需要通过assert或者其他的方式进行一下配置的检查。在工业界里,通常不会直接用这种方式,而是做兼容性处理,即resharding:改变参数的分配方式让加载可以在不同并行度的集群上实现。resharding在工业级当中是一个相当有必要而且繁重的课题,bytedance为此发过论文,感兴趣的读者可以参考ByteCheckpoint的论文ByteCheckpoint以及其他相关工作(仓库:ByteCheckpoint),思想相当有意思,不过边界情况也相当麻烦。在我们的项目当中暂时不处理它,毕竟规模远远超出了单人或者少数人能够处理的范围。
Collator
这里遇到了一个相当难排查的静默bug:千万不要给Collator设置默认值,否则会造成问题。原因是测试脚本里为了代码环境的简洁性,没有传入Collator,全部按照默认值处理了默认传入参数。对于DistributedSampler等组件来说这没有什么问题。但在进行测试的时候,结果不对,排查错误信息发现是因为Collator被赋值了一个占位符默认值。虽然查出来这个问题很容易,但结果是测试脚本最后还是得重写,白费功夫。从设计哲学的角度来说,这主要是因为Collator并不应该有一个标准的默认值:如果有,应该是什么语义,做什么?对于不同问题来说它是完全不同的,并不存在一个可以安全fallback的选项。因此,直接把它作为不可默认为空的必选参数是更合适的做法。
GradientSynchronizer
- grad_accum_steps次数的设置
需要注意的是,技术上grad_accum_steps其实并不能设置的太大。一来它会影响收敛速度,二来,基于grad_accum_steps的累加是在bf16上进行的,而不是最终finalize之后的fp32。这意味着如果它累积了太多步数,很快就会出现和之前混合精度训练的博客日志里提到过的,bf16精度训练相同的问题:大数吃小数,梯度累加不正确,误差累积导致模型训练质量下降。因此,grad_accum_steps应该最好不要特别大。在grad_accum_steps=4的情况下,测试得到的精度误差大概如下:
============================================================
Loss 对比
============================================================
mbs=8, accum=1 vs mbs=4, accum=2:
Step Baseline Other Diff
──────────────────────────────────────────────
1 6.932020 6.932019 0.00000024
4 6.931188 6.931187 0.00000072
7 6.931248 6.931248 0.00000048
10 6.930811 6.930812 0.00000143
13 6.931348 6.931348 0.00000024
Max diff: 0.00000167
✓ 一致性 (threshold=0.02)
mbs=8, accum=1 vs mbs=2, accum=4:
Step Baseline Other Diff
──────────────────────────────────────────────
1 6.932020 6.932020 0.00000024
4 6.931188 6.931188 0.00000012
7 6.931248 6.931249 0.00000095
10 6.930811 6.930811 0.00000012
13 6.931348 6.931349 0.00000083
Max diff: 0.00000274
大约10^-6次方量级。这个数量级还是很可以接受的,但继续增加就比较难说,需要测试一下。
- bf16通信 vs fp32通信
可以选择在bf16 grad上同步再upcast到fp32给optimizer,而不是先upcast再同步。通信量减半,NCCL对bf16 reduce有专门优化,精度损失在实践中可忽略。本计划使用后者fp32,但经过查询,发现其实bf16通信才是主流方案,遂改用bf16的通信方法。最后,精度差异的测试也说明,这个误差的确是可以接受的。
其他的注意事项
这部分的开发相对顺利,没有遇到特别难以排查的bug。不过,排查过程中的死锁问题已经初见端倪。在测试功能性的过程中,“调用的时候只调用了rank0的通信导致死锁”,或者反过来“generator的参数设置错误导致其他rank调用了不该调用的通信导致死锁”的问题比比皆是。这些问题目前比较好解决,但随着框架复杂度升高,如何确保一致性、尽可能少的避免此类错误真的出现,是一件需要小小设计的事情。
另一方面来说,可以看到预训练框架的查错和pico-vllm这种推理框架的主要错误类型有显著的不同:推理框架要微观的多,涉及到底层数据排布、tensor的连续/非连续性,CUDA/Triton kernel的指针和参数校对等等问题;而训练框架相对来说更宏观一些,主要是设计的对齐、通信和死锁、不同feature的兼容等等层面。从发现和定位难度来说,pico-vllm的错误明显更好发现、难定位,而femtotron当中则是难发现、好定位。这同样能够说明两个项目的明显不同的趋势,而通过实操了解他们的特征区别正是做项目的意义。