并行优化一:DP/SP/TP
参考文章
1. https://zhuanlan.zhihu.com/p/2003423046342554380
2. https://zhuanlan.zhihu.com/p/1937449564509545940
前提知识
- 显卡通信方式
- Transformer 模型结构
参数说明
先要理清楚一些参数,下面的话也不是很严谨(比如 batch_size 仅仅对训练有用)。
batch_size:一次训练的样本数目,这个参数仅用于缩短训练时间,对模型最终效果无影响。seq_len:每句话有多少个 token,这个参数对模型最终效果有大影响。hidden_size:隐藏层的维度。
DP - 数据并行
batch_size 如果很大,单张 GPU 塞不下,所以就是拆分成多组,分别放到不同显卡中。
在 普通 Data Parallel / DDP 中:
1. 前向传播一般不需要 GPU 间通信
每张 GPU 都有一份完整模型副本,各自拿一部分 batch:
GPU0: batch[0:32] -> forward
GPU1: batch[32:64] -> forward
GPU2: batch[64:96] -> forward
GPU3: batch[96:128] -> forward
它们用的是相同参数副本,但计算自己的输入,所以前向阶段通常互不通信。
例外是模型里有跨 GPU 的操作,比如:
SyncBatchNorm
MoE routing
跨卡 loss / 对比学习 gather features
这些会在前向中通信。
2. 后向传播需要通信
后向时,每张 GPU 根据自己的 mini-batch 算出一份梯度:
GPU0: grad0
GPU1: grad1
GPU2: grad2
GPU3: grad3
为了让所有模型副本保持一致,需要把这些梯度做平均:
grad = (grad0 + grad1 + grad2 + grad3) / 4
上述操作就是 AllReduce 通信(显卡通信方式),所以 后向传播中需要通信。
SP - 序列并行
单条序列太长,导致中间激活值太大,单张 GPU 放不下。所以可以切分序列,放在不同显卡中。如果大白话理解,就是原来训练的时候用到 hello, world,现在切成 hello 和 world 分别放在不同显卡中。
为什么切序列不会有问题
第一次看这个感觉很奇怪,序列居然可以切分到不同设备。但其实,对于线性层而言,序列中各个单词是独立的,因为线性层就是每个向量自己乘上一个矩阵而已。
序列中的单词唯一需要通信的地方是计算 K/V,其实这个时候进行通信即可。此时 K/V 要在 GPU 之间交换,最简单的做法就还是各自广播(AllGather 通信),当然还有一些流水线通信做法,比如分为好几轮,第一轮是 0->1, 1->2, 2->0,第二轮是 0->2, 1->0, 2->1,这样可以流水线并行。
除了这个,最后要计算 softmax,即 softmax(QK^T),这个时候上面的流水线就不一定能有作用。不过也有对应的方法,叫 online softmax,不是先算完整 \(QK^T\) 再 softmax,而是一边收到 K/V 块,一边增量更新 softmax。这个就比较细节了,不谈了。
总之,序列并行中,通信只需要在计算 K/V 的时候进行即可。
TP - 张量并行
这里直接放 ChatGPT 的回答了。
它解决的问题是:单个模型层的参数或计算量太大,一张 GPU 放不下 / 算不过来。
直观理解
假设 Transformer 里有一个线性层:Y = XW,其中:
X: [batch, seq, hidden]
W: [hidden, 4 * hidden]
Y: [batch, seq, 4 * hidden]
如果 W 太大,可以把 W 按列切到多张 GPU:
W = [W0 | W1 | W2 | W3]
于是:
GPU0: Y0 = XW0
GPU1: Y1 = XW1
GPU2: Y2 = XW2
GPU3: Y3 = XW3
最后把结果拼起来:
Y = [Y0 | Y1 | Y2 | Y3]
这就是 列并行 Linear。
1. Attention 里切 head
比如有 32 个 attention head,4 张 GPU:
GPU0: head 0~7
GPU1: head 8~15
GPU2: head 16~23
GPU3: head 24~31
每张卡只算一部分 head,最后 Attention 输出再合并。
2. MLP 里切矩阵
Transformer MLP 通常是:
hidden -> 4 * hidden -> hidden
第一层可以按列切:
XW1 = [XW1_0, XW1_1, XW1_2, XW1_3]
第二层通常按行切:
Y = Y0W2_0 + Y1W2_1 + Y2W2_2 + Y3W2_3
这里最后需要一次 AllReduce 把各卡部分结果加起来。
TP 需要通信吗?
需要,而且比 DP 更频繁。DP 主要在 后向梯度同步 时通信;TP 在 每一层的前向和后向 都可能通信。
例如:
Column Parallel:
前向:通常输出分片,可暂时不通信
后向:需要同步输入梯度
Row Parallel:
前向:需要 AllReduce 合并输出
后向:某些梯度天然分片
所以:
TP 是层内并行,通信更细粒度、更频繁。
总结
这些并行方式通常是配合着用的。从上面也可以看出,要对 Transformer 有足够的了解,从矩阵运算中推出各个并行方式的时候,应该是如何通信、如何处理(是拼接,还是累加等)。具体细节还是推荐用的时候再去想,否则即使现在会了,第二天可能就忘了。