必知必会:大模型训练通信开销计算详解与面试指南
AI-Compass 致力于构建最全面、最实用、最前沿的AI技术学习和实践生态,通过六大核心模块的系统化组织,为不同层次的学习者和开发者提供完整学习路径。 通信开销如同隐匿的丝线,牵动着大模型训练的每一个环节。在分布式训练中: 想象你在一个快递分拣中心工作,有 4 个分拣员(GPU),每人负责一个区域的包裹: 一个发送者,多个接收者 将一个GPU的完整数据复制到所有其他GPU上。 应用场景:模型初始化时,将主节点的参数广播到所有工作节点。 一个发送者,多个接收者(数据分片) 将一个GPU的数据切片后分发给不同GPU,每个GPU获得不同的部分。 与Broadcast的区别: 多个发送者,一个接收者 将多个GPU的数据收集到一个GPU上。 多个发送者,一个接收者(带计算) 将多个GPU的数据规约运算(如求和SUM、求最大值MAX、求乘积PROD)后发送到一个GPU。 多个发送者,多个接收者 = Gather + Broadcast 每个GPU都收集到所有GPU的数据。 多个发送者,多个接收者 在所有GPU上按维度执行相同的Reduce操作,再将结果发散到集群内所有GPU上。 多个发送者,多个接收者 = Reduce + Broadcast = Reduce-Scatter + All-Gather 在集群内所有GPU上都执行相同的Reduce操作,并将结果发送到所有GPU上。 多个发送者,多个接收者(完全交换) 每个GPU将自己的数据发散到所有GPU,同时收集所有GPU的数据。 与All-Gather的区别: <!-- trick-image:start idx=1 platform=blog --> 组合关系: 想象你在一个班级里做小组作业: 朴素方式(班长汇总法):30 个同学各做一份答案,全部交给班长。班长一个人汇总 29 份答案,算出最终版本,再抄 29 份发回去。班长累得半死,其他人在旁边等着——这就是单点瓶颈。 Ring 方式(传话游戏法):30 个同学围成一个圈,每人把答案分成 30 段。第一轮,每人把第 1 段传给右边的同学并加上自己的;第二轮继续传递第 2 段……经过 29 轮传递,每人手上都有了完整汇总结果。虽然传了很多轮,但每轮每人只处理一小段,没有人特别累。 掌握了直观类比后,接下来我们用精确的数学语言来刻画通信量的计算方法。这样在实际工作中就能准确预估不同配置下的通信开销。 在数据并行策略中: 假设有N个GPU参与数据并行,其中: 通信开销公式: $$ 通信过程: 问题:当N增大时,整个系统的总通信量约为 2 × N × Φ,且收集梯度的GPU成为瓶颈。 数值计算示例: 假设用 4 张 GPU 训练 LLaMA-7B(Φ = 7B): $$ 核心思想:将数据切分成N份,通过环形拓扑结构进行传输,避免单点瓶颈。 两个阶段: 阶段1:Reduce-Scatter 阶段2:All-Gather 可视化流程: 为了更直观理解Ring-All-Reduce的工作原理,下图展示了4个GPU环形拓扑下的两阶段数据流动: 上图说明: 单个GPU通信量公式: $$ 推导过程: $$ 整个系统通信量: $$ 数值计算示例: 同样 4 张 GPU 训练 LLaMA-7B(Φ = 7B): $$ <!-- trick-image:start idx=2 platform=blog --> 想象你和同事一起做一张超大的拼图(矩阵乘法): 拼图太大了,一个人的桌子放不下。于是你们把拼图纵向切成两半,一人拼左半边,一人拼右半边。拼完各自的部分后,需要把结果拼到一起看全貌——这就是前向传播的 All-Reduce。 发现有几块拼错了(反向传播计算梯度),你们各自修正自己那半边的错误,修好后再汇总一次看哪些地方还需要调整——这就是反向传播的 All-Reduce。 一个 Transformer 层有两大组件(注意力 + FFN),每个组件都要"拼一次 + 修一次",所以一共需要 4 次 All-Reduce,总通信量为 8 × b × s × h。 <!-- trick-image:start idx=3 platform=blog --> 理解了"拼图"的直觉后,下面我们看看 Megatron-LM 如何在工程上实现这个切分策略,以及通信量的精确计算方法。 张量并行策略的基本思路: 以Transformer模型为例,主要涉及: 前馈神经网络层的计算分为两步: 其中: 张量切分方式: 可视化数据流: 下图展示了2-GPU张量并行下FFN层的矩阵切分和通信流程: 上图说明: 通信过程示意: FFN 层通信开销公式: 每次 All-Reduce 的通信量等效于一次 Reduce-Scatter + 一次 All-Gather,数据大小为 $b \times s \times h$: $$ 前向传播 1 次 + 反向传播 1 次 = 2 次 All-Reduce: $$ 数值计算示例: 假设 LLaMA-7B 的一个 Transformer 层,使用 2 路张量并行: $$ 多头注意力层在张量并行下的通信结构与 FFN 层相同: $$ 对一个包含前馈网络层和多头注意力层的Transformer来说: $$ 数值计算示例(续上例): $$ 想象一条汽车生产流水线: 车间 1 负责焊接车架,车间 2 负责安装发动机,车间 3 负责喷漆,车间 4 负责装内饰。每辆车依次经过 4 个车间。 有了"汽车生产流水线"的直觉,现在我们用数学公式来精确计算流水线并行的通信量,并理解 GPipe 是如何通过 micro-batch 提升效率的。 对于朴素的流水线并行,其主要思想是将模型的不同层进行拆分,然后放到不同的GPU上。 前向传播过程: 反向传播过程: GPipe是一种经典的流水线并行方法,通过引入micro-batch(微批次)处理和激活值重算机制,有效解决了朴素流水线并行所存在的GPU利用率过低以及中间结果消耗过大的问题。 通信开销公式: 假设: 前向传播通信量(N-1 个中间结果): $$ 反向传播通信量(N-1 个偏导数): $$ 流水线并行总通信量: $$ 数值计算示例: 假设 LLaMA-7B(32 层)使用 4 路流水线并行,每个 GPU 放 8 层: $$ 对比:同配置下张量并行总通信量为 16 GB,流水线并行仅 384 MB——但流水线并行有"气泡"(idle time)问题。 想象你和 7 个同学一起背一本 700 页的词典: 理解了"词典分工"的类比后,现在我们详细剖析 ZeRO 三个层次的通信流程和通信量计算,看看为什么 ZeRO-3 要付出 50% 的额外通信代价。 特点: 通信过程: Reduce-Scatter:对梯度进行聚合,每个GPU得到部分聚合梯度 All-Gather:从其他GPU上把更新好的部分模型参数取回来 ZeRO-1 单卡通信量公式: $$ 特点: 通信过程: 前向传播后,反向传播时需要Reduce-Scatter获取其他GPU的梯度 All-Gather:从其他GPU上把完成梯度更新的部分模型参数取回来 ZeRO-2 单卡通信量公式: $$ 特点: 通信过程: 前向传播: 每一层完成前向传播后,立即把不属于自己维护的模型参数丢弃 反向传播: 每一层完成反向传播后,立即把不属于自己维护的模型参数丢弃 参数更新: 需要Reduce-Scatter获得完整的聚合梯度用于更新模型参数 ZeRO-3 单卡通信量公式: $$ <!-- trick-image:start idx=4 platform=blog --> ZeRO-3 相对 ZeRO-2 的额外通信量比例: $$ 数值计算示例: 以 LLaMA-7B(Φ = 7B)在 8 张 GPU 上训练为例: $$ $$ 结论:ZeRO-3 多付出 14 GB 通信开销,但单卡显存从 112 GB 降到 14 GB,使 7B 模型可以在 8 张 A100-80G 上训练。 <!-- trick-image:start idx=5 platform=blog --> 下图展示了并行策略的选型决策流程: 上图展示了从单卡到 3D 并行的选型路径:优先考虑简单策略,逐步引入更复杂的并行方式。 下图展示了 3D 并行中三个维度的协同关系: 上图展示了 32 GPU 的 3D 并行布局:TP=4 使用机内 NVLink,PP=4 用于层间传递,DP=2 用于机间梯度同步。 配置示例(32 GPU): 3D 并行各维度通信量公式: $$ $$ $$ 设计原则: 答案: 详细说明: 公式: $$ 答案: 详细说明: 答案: 详细说明: 公式: $$ ZeRO-1/2 通信量与 DP 相同的原因:Reduce-Scatter + All-Gather = All-Reduce。ZeRO-3 多出的 1Φ 来自前向和反向传播各一次 All-Gather。 答案: 详细说明: 公式: $$ 答案: 详细说明: 答案: 详细说明: 公式: $$ 答案: 详细说明: 公式: $$ 答案: 详细说明: 核心权衡:ZeRO-3 用 50% 额外通信换取线性显存扩展能力。 答案: 详细说明: 设计原则:TP 优先机内(带宽最高)→ PP 次之 → DP 用于机间(频率最低)。 答案: 详细说明: 答案: 详细说明: 公式: $$ 其中 $\alpha$ 为单次通信延迟,$\beta$ 为链路带宽。 答案: 详细说明: 本节面试题来源于互联网公开的大厂面试真题和高频考点。 出处:字节跳动/阿里云 AI Infra 岗位面试真题 答案: 详细说明: 出处:Meta/微软 LLM 训练工程师面试高频题 答案: 详细说明: 出处:综合预测题(结合 LLaMA-70B 等开源模型训练实践) 答案: 详细说明: 通信量估算: $$ $$ $$ <!-- Reviewed: 2026-02-13, 深度重新审校:增加2个核心Mermaid图(Ring-All-Reduce两阶段流程+张量并行矩阵切分数据流)/原有2个Mermaid图(选型流程+3D并行配置)/丰富的通俗化类比(快递分拣/课堂小组/拼图/流水线/词典)/完整公式+符号表+数值示例/15个面试题/总计4个Mermaid图+1400行全面内容 --> AI-Compass 致力于构建最全面、最实用、最前沿的AI技术学习和实践生态,通过六大核心模块的系统化组织,为不同层次的学习者和开发者提供完整学习路径。 🌟 如果本项目对您有所帮助,请为我们点亮一颗星!🌟必知必会:大模型训练通信开销计算详解与面试指南
为什么通信开销很重要?
符号约定
符号 含义 Φ 模型参数量(FP16精度下占用2Φ字节) b 批次大小(batch size) s 序列长度(sequence length) h 隐藏层维度(hidden dimension) N GPU数量 1. 集合通信原语详解
1.1 核心问题
1.2 原文核心要点
在分布式训练过程中,不同的 GPU 之间可以通过集合通信原语来传递模型参数、梯度等信息。核心原语包括 Broadcast、Scatter、Gather、Reduce、All-Gather、Reduce-Scatter、All-Reduce 和 All-to-All 八种,它们是所有并行策略的通信基石。
1.3 通俗理解
直观类比
核心要点
建立了直觉之后,下面我们用具体的示意图和对比表来严格定义这些通信原语。
1.4 八大通信原语
1. Broadcast(广播)
操作前: 操作后:
┌───┬───┬───┬───┐ ┌───┬───┬───┬───┐
│ A │ │ │ │ → │ A │ A │ A │ A │
└───┴───┴───┴───┘ └───┴───┴───┴───┘
GPU0 GPU1 GPU2 GPU3 GPU0 GPU1 GPU2 GPU32. Scatter(分发)
操作前: 操作后:
┌─────────────┐ ┌───┬───┬───┬───┐
│ A │ B │ C │ D │ → │ A │ B │ C │ D │
└─────────────┘ └───┴───┴───┴───┘
GPU0 GPU0 GPU1 GPU2 GPU33. Gather(收集)
操作前: 操作后:
┌───┬───┬───┬───┐ ┌─────────────┐
│ A │ B │ C │ D │ → │ A │ B │ C │ D │ (其他GPU为空)
└───┴───┴───┴───┘ └─────────────┘
GPU0 GPU1 GPU2 GPU3 GPU04. Reduce(规约)
操作前: 操作后:
┌───┬───┬───┬───┐ ┌───────────────┐
│ A │ B │ C │ D │ → │ A+B+C+D │ (其他GPU为空)
└───┴───┴───┴───┘ └───────────────┘
GPU0 GPU1 GPU2 GPU3 GPU05. All-Gather(全收集)
操作前: 操作后:
┌───┬───┬───┬───┐ ┌─────────────┬─────────────┬─────────────┬─────────────┐
│ A │ B │ C │ D │ → │ A,B,C,D │ A,B,C,D │ A,B,C,D │ A,B,C,D │
└───┴───┴───┴───┘ └─────────────┴─────────────┴─────────────┴─────────────┘
GPU0 GPU1 GPU2 GPU3 GPU0 GPU1 GPU2 GPU36. Reduce-Scatter(规约分发)
操作前: 操作后:
GPU0: [A0,A1,A2,A3] GPU0: [A0+B0+C0+D0]
GPU1: [B0,B1,B2,B3] → GPU1: [A1+B1+C1+D1]
GPU2: [C0,C1,C2,C3] GPU2: [A2+B2+C2+D2]
GPU3: [D0,D1,D2,D3] GPU3: [A3+B3+C3+D3]7. All-Reduce(全规约)
操作前: 操作后:
┌───┬───┬───┬───┐ ┌─────────┬─────────┬─────────┬─────────┐
│ A │ B │ C │ D │ → │ A+B+C+D │ A+B+C+D │ A+B+C+D │ A+B+C+D │
└───┴───┴───┴───┘ └─────────┴─────────┴─────────┴─────────┘
GPU0 GPU1 GPU2 GPU3 GPU0 GPU1 GPU2 GPU38. All-to-All(全交换)
操作前: 操作后:
GPU0: [A0,A1,A2,A3] GPU0: [A0,B0,C0,D0]
GPU1: [B0,B1,B2,B3] → GPU1: [A1,B1,C1,D1]
GPU2: [C0,C1,C2,C3] GPU2: [A2,B2,C2,D2]
GPU3: [D0,D1,D2,D3] GPU3: [A3,B3,C3,D3]1.5 通信原语对比表

<!-- trick-image:end idx=1 -->原语 发送方 接收方 是否计算 典型应用 Broadcast 1 N 否 参数初始化 Scatter 1 N 否 数据分片 Gather N 1 否 结果收集 Reduce N 1 是(规约) 梯度聚合到主节点 All-Gather N N 否 ZeRO-3参数恢复 Reduce-Scatter N N 是(规约) ZeRO梯度聚合 All-Reduce N N 是(规约) 数据并行梯度同步 All-to-All N N 否 专家并行 1.6 小结
维度 说明 原语总数 8 种基础通信原语 核心原语 All-Reduce = Reduce-Scatter + All-Gather 通信模式 1→N、N→1、N→N 三种基本模式 选型依据 是否需要计算(规约)、是否需要全局同步 关键应用 DP 用 All-Reduce,ZeRO 用 RS+AG,EP 用 All-to-All 2. 数据并行的通信开销计算
2.1 核心问题
2.2 原文核心要点
数据并行策略中,每个 GPU 保存完整模型副本,训练数据被拆分。各 GPU 计算梯度后通过 All-Reduce 同步。Ring-All-Reduce 通过环形拓扑将通信量均匀分配到每个 GPU,避免单点瓶颈。单卡通信量为 2×(N-1)×Φ/N ≈ 2Φ。
2.3 通俗理解
直观类比
核心要点
建立了直觉之后,下面我们用数学公式来严格定义通信量的计算方法,并给出具体的数值示例。
2.4 技术原理与公式推导
工作原理
两种 All-Reduce 实现方式
方式一:朴素 All-Reduce
C_{\text{naive}} = 2 \times N \times \Phi
$$符号 含义 示例值 $C_{\text{naive}}$ 朴素 All-Reduce 系统总通信量 112 GB $N$ GPU 数量 4 $\Phi$ 模型参数量(FP16 下占 2Φ 字节) 7B
C_{\text{naive}} = 2 \times 4 \times 7\text{B} \times 2\text{B/param} = 112 \text{ GB}
$$指标 朴素 All-Reduce GPU0(收集节点)通信量 (4-1) × 7B × 2B = 42 GB GPU1/2/3 各自通信量 7B × 2B = 14 GB 系统总通信量 112 GB 瓶颈 GPU0 承载 42 GB,其他仅 14 GB 方式二:Ring-All-Reduce(环形全规约)
C_{\text{Ring}} = 2 \times \frac{N-1}{N} \times \Phi \approx 2\Phi \quad (\text{当 } N \text{ 较大时})
$$符号 含义 示例值 $C_{\text{Ring}}$ Ring-All-Reduce 单卡通信量 21 GB $N$ GPU 数量 4 $\Phi$ 模型参数量 7B $\frac{\Phi}{N}$ 每次传输的数据片段大小 1.75B
C_{\text{Ring}} = \underbrace{(N-1) \times \frac{\Phi}{N}}_{\text{Reduce-Scatter}} + \underbrace{(N-1) \times \frac{\Phi}{N}}_{\text{All-Gather}} = 2 \times \frac{(N-1) \times \Phi}{N}
$$
C_{\text{Ring,total}} = N \times C_{\text{Ring}} = 2 \times (N-1) \times \Phi \approx 2N\Phi
$$
C_{\text{Ring}} = 2 \times \frac{4-1}{4} \times 7\text{B} \times 2\text{B/param} = 2 \times \frac{3}{4} \times 14\text{ GB} = 21 \text{ GB}
$$指标 Ring-All-Reduce 每卡每次传输量 7B/4 × 2B = 3.5 GB 每卡 Reduce-Scatter 3 × 3.5 GB = 10.5 GB 每卡 All-Gather 3 × 3.5 GB = 10.5 GB 每卡总通信量 21 GB(均匀分布) 系统总通信量 4 × 21 GB = 84 GB 2.5 两种方式对比

<!-- trick-image:end idx=2 -->对比项 朴素All-Reduce Ring-All-Reduce 总通信量 2 × N × Φ 2 × N × Φ 单GPU通信量分布 不均衡(收集节点负载重) 均衡(每个GPU约2Φ) 瓶颈 收集GPU的带宽 无单点瓶颈 适用场景 小规模集群 多机多卡分布式系统 关键结论:Ring-All-Reduce的总通信量与朴素All-Reduce接近,但每个GPU上的通信量更为均衡,从而缓解了通信瓶颈问题。
2.6 小结
维度 说明 通信操作 All-Reduce(梯度同步) 单卡通信量 $2 \times \frac{N-1}{N} \times \Phi \approx 2\Phi$ 系统总通信量 $2 \times N \times \Phi$ Ring 优势 负载均衡,无单点瓶颈 通信时机 仅在反向传播完成后 3. 张量并行的通信开销计算
3.1 核心问题
3.2 原文核心要点
张量并行将模型参数切分成多个参数块放到不同 GPU 上独立计算,最后聚合结果。Megatron-LM 将 FFN 层的第一个矩阵按列切分、第二个按行切分,前向和反向传播各需一次 All-Reduce。一个完整 Transformer 层的总通信量为 8×b×s×h。
3.3 通俗理解
直观类比
核心要点
建立了直觉之后,下面我们用数学公式来严格推导张量并行的通信量。
3.4 技术原理与公式推导

<!-- trick-image:end idx=3 -->Megatron-LM 切分方法
前馈神经网络层的通信开销
Y = GELU(XA) # 第一个线性层 + 激活函数
Z = Dropout(YB) # 第二个线性层 + Dropout前向传播:
┌─────────────────────────────────────────────────────────────┐
│ 输入X → [复制到各GPU] → GPU₁计算XA₁ → GELU → Y₁B₁ → Z₁ │
│ → GPU₂计算XA₂ → GELU → Y₂B₂ → Z₂ │
│ ↓ │
│ [All-Reduce聚合Z] │
└─────────────────────────────────────────────────────────────┘
反向传播:
┌─────────────────────────────────────────────────────────────┐
│ ∂L/∂Z → [分发到各GPU] → GPU₁计算梯度 │
│ → GPU₂计算梯度 │
│ ↓ │
│ [All-Reduce聚合∂L/∂X] │
└─────────────────────────────────────────────────────────────┘
C_{\text{AllReduce}} = 2 \times b \times s \times h
$$
C_{\text{FFN}} = 2 \times C_{\text{AllReduce}} = 4 \times b \times s \times h
$$符号 含义 示例值 $C_{\text{FFN}}$ FFN 层总通信量 256 MB $b$ 批次大小 4 $s$ 序列长度 2048 $h$ 隐藏维度 4096
C_{\text{FFN}} = 4 \times 4 \times 2048 \times 4096 \times 2\text{B} = 256 \text{ MB}
$$参数 值 batch size (b) 4 序列长度 (s) 2048 隐藏维度 (h) 4096 FFN 层每次 All-Reduce 数据量 2 × 4 × 2048 × 4096 × 2B = 128 MB FFN 层总通信量 4 × 4 × 2048 × 4096 × 2B = 256 MB 多头注意力层的通信开销
C_{\text{Attn}} = 4 \times b \times s \times h
$$完整 Transformer 层通信开销
C_{\text{layer}} = C_{\text{FFN}} + C_{\text{Attn}} = 4bsh + 4bsh = 8 \times b \times s \times h
$$符号 含义 示例值 $C_{\text{layer}}$ 单个 Transformer 层通信量 512 MB $C_{\text{FFN}}$ FFN 层通信量 256 MB $C_{\text{Attn}}$ 注意力层通信量 256 MB
C_{\text{layer}} = 8 \times 4 \times 2048 \times 4096 \times 2\text{B} = 512 \text{ MB}
$$组件 通信量 FFN 层 4 × b × s × h = 256 MB 注意力层 4 × b × s × h = 256 MB 单层合计 512 MB LLaMA-7B 共 32 层 32 × 512 MB = 16 GB 3.5 小结
维度 说明 切分方式 A 按列切分,B 按行切分 FFN 层通信量 $4 \times b \times s \times h$ 注意力层通信量 $4 \times b \times s \times h$ 单层总通信量 $8 \times b \times s \times h$ 通信时机 每层前向 + 反向各 2 次 All-Reduce 适用条件 机内高带宽互联 4. 流水线并行的通信开销计算
4.1 核心问题
4.2 原文核心要点
流水线并行将模型的不同层拆分到不同 GPU 上,前向传播时逐层传递激活值,反向传播时逐层回传梯度。总通信量为 2×(N-1)×b×s×h,通信方式为点对点(P2P)传输。GPipe 通过微批次(micro-batch)机制提高 GPU 利用率。
4.3 通俗理解
直观类比
核心要点
建立了直觉之后,下面我们用数学公式来严格推导流水线并行的通信量。
4.4 技术原理与公式推导
工作原理
流水线并行模型示意图:
模型输出计算损失L
↑
┌─────────────────────────────────────┐
│ 模型最后一层 │ GPU_N │ ←─┐
├─────────────────────────────────────┤ │
│ ... │ ... │ │ 反向传播
├─────────────────────────────────────┤ │ ∂L/∂Z_N-1, ..., ∂L/∂Z_1
│ 模型层2 │ GPU_2 │ │
├─────────────────────────────────────┤ │
│ 模型层1 │ GPU_1 │ ←─┘
└─────────────────────────────────────┘
↑ │
输入 │ 前向传播
MiniBatch │ Z_1, Z_2, ..., Z_N-1
↓ ↓
[MicroBatch₁] [MicroBatch₂] ... [MicroBatch_n]GPipe 方法与通信开销
C_{\text{forward}} = (N-1) \times b \times s \times h
$$
C_{\text{backward}} = (N-1) \times b \times s \times h
$$
C_{\text{PP}} = C_{\text{forward}} + C_{\text{backward}} = 2 \times (N-1) \times b \times s \times h
$$符号 含义 示例值 $C_{\text{PP}}$ 流水线并行总通信量 384 MB $N$ GPU 数量(流水线级数) 4 $N-1$ 切分点数量 3 $b \times s \times h$ 单个中间激活值大小 64 MB
C_{\text{PP}} = 2 \times (4-1) \times 4 \times 2048 \times 4096 \times 2\text{B} = 2 \times 3 \times 64\text{ MB} = 384 \text{ MB}
$$参数 值 batch size (b) 4 序列长度 (s) 2048 隐藏维度 (h) 4096 流水线切分点 N-1 = 3 个 每个中间结果大小 4 × 2048 × 4096 × 2B = 64 MB 前向通信量 3 × 64 MB = 192 MB 反向通信量 3 × 64 MB = 192 MB 总通信量 384 MB 4.5 小结
维度 说明 切分方式 按层切分到不同 GPU 通信方式 点对点(P2P)传输 总通信量 $2 \times (N-1) \times b \times s \times h$ 优化方法 GPipe micro-batch 减少气泡 优势 通信量小,适合跨机 劣势 有气泡时间,GPU 利用率受限 5. ZeRO 优化技术的通信开销计算
5.1 核心问题
5.2 原文核心要点
ZeRO(Zero Redundancy Optimizer)分三个级别:ZeRO-1 切分优化器状态,ZeRO-2 额外切分梯度,ZeRO-3 进一步切分模型参数。ZeRO-1/2 单卡通信量均为 2Φ(与数据并行相同),ZeRO-3 为 3Φ(增加 50%),但换取了与 GPU 数量成正比的显存节省。
5.3 通俗理解
直观类比
核心要点
建立了直觉之后,下面我们用数学公式来严格分析 ZeRO 各级别的通信开销。
5.4 技术原理与公式推导
ZeRO-1:优化器状态分片
C_{\text{ZeRO-1}} = \underbrace{\Phi}_{\text{Reduce-Scatter}} + \underbrace{\Phi}_{\text{All-Gather}} = 2\Phi
$$ZeRO-2:优化器状态 + 梯度分片
C_{\text{ZeRO-2}} = \underbrace{\Phi}_{\text{Reduce-Scatter}} + \underbrace{\Phi}_{\text{All-Gather}} = 2\Phi
$$ZeRO-3:优化器状态 + 梯度 + 参数分片
C_{\text{ZeRO-3}} = \underbrace{\Phi}_{\text{前向 All-Gather}} + \underbrace{\Phi}_{\text{反向 All-Gather}} + \underbrace{\Phi}_{\text{Reduce-Scatter}} = 3\Phi
$$5.5 ZeRO 通信开销对比

<!-- trick-image:end idx=4 -->版本 分片内容 单卡通信量 显存节省 ZeRO-1 优化器状态 $2\Phi$ 约4倍 ZeRO-2 优化器状态 + 梯度 $2\Phi$ 约8倍 ZeRO-3 优化器状态 + 梯度 + 参数 $3\Phi$ 与GPU数量成正比 关键结论:ZeRO-3通过更激进的分片策略换取更大的显存节省,但代价是增加了50%的通信开销(从2Φ增加到3Φ)。
\frac{C_{\text{ZeRO-3}} - C_{\text{ZeRO-2}}}{C_{\text{ZeRO-2}}} = \frac{3\Phi - 2\Phi}{2\Phi} = 50\%
$$
C_{\text{ZeRO-1}} = C_{\text{ZeRO-2}} = 2 \times 7\text{B} \times 2\text{B/param} = 28 \text{ GB}
$$
C_{\text{ZeRO-3}} = 3 \times 7\text{B} \times 2\text{B/param} = 42 \text{ GB}
$$版本 单卡通信量 具体数值 单卡显存占用(优化器+梯度+参数) 数据并行 $2\Phi$ 28 GB 112 GB(放不下) ZeRO-1 $2\Phi$ 28 GB ~42 GB ZeRO-2 $2\Phi$ 28 GB ~28 GB ZeRO-3 $3\Phi$ 42 GB ~14 GB(可放入 A100-80G) 5.6 小结
维度 说明 核心思想 切分冗余存储,降低单卡显存占用 ZeRO-1/2 通信量 $2\Phi$(与 DP 相同) ZeRO-3 通信量 $3\Phi$(增加 50%) 显存节省 ZeRO-1 ~4×,ZeRO-2 ~8×,ZeRO-3 ~N× 通信操作 Reduce-Scatter + All-Gather(ZeRO-3 额外 2×All-Gather) 6. 通信开销综合对比与选型指南
6.1 各并行策略通信量汇总
并行策略 单卡通信量 主要通信操作 通信发生时机 数据并行 $2\Phi$ All-Reduce 反向传播后 张量并行 $8 \times b \times s \times h$(每层) All-Reduce 前向+反向传播中 流水线并行 $2 \times (N-1) \times b \times s \times h$ P2P 层间传递 ZeRO-1 $2\Phi$ Reduce-Scatter + All-Gather 反向传播后 ZeRO-2 $2\Phi$ Reduce-Scatter + All-Gather 反向传播后 ZeRO-3 $3\Phi$ All-Gather×2 + Reduce-Scatter 前向+反向+更新 6.2 选型决策流程

<!-- trick-image:end idx=5 -->6.3 3D 并行配置示例
6.4 3D 并行通信开销
C_{\text{TP}} = 8 \times b \times s \times h \quad \text{(每层,All-Reduce 激活值)}
$$
C_{\text{PP}} = 2 \times (D_{\text{pp}} - 1) \times b \times s \times h \quad \text{(P2P 传递激活值)}
$$
C_{\text{DP}} = \frac{2\Phi}{D_{\text{tp}} \times D_{\text{pp}}} \quad \text{(All-Reduce 梯度)}
$$符号 含义 示例值 $C_{\text{TP}}$ 张量并行每层通信量 512 MB $C_{\text{PP}}$ 流水线并行总通信量 384 MB $C_{\text{DP}}$ 数据并行梯度通信量 1.75 GB $D_{\text{tp}} \times D_{\text{pp}}$ 单个 DP 组内的 GPU 数 16 并行维度 通信操作 通信量 互联要求 张量并行 All-Reduce 激活值 $8 \times b \times s \times h$(每层) NVLink(机内) 流水线并行 P2P 传递激活值 $2 \times (D_{\text{pp}}-1) \times b \times s \times h$ 机内或跨机 数据并行 All-Reduce 梯度 $\frac{2\Phi}{D_{\text{tp}} \times D_{\text{pp}}}$ 跨机网络 6.5 通信优化方法
优化方法 原理 效果 通信与计算重叠 反向传播时已完成层提前通信 隐藏通信延迟 梯度累积 多个 micro-batch 累积后再同步 减少通信频率 梯度压缩 FP16→INT8 量化或 Top-K 稀疏化 减少通信数据量 Bucket 融合 多个小张量合并成大张量再通信 降低通信启动开销 分层 All-Reduce 先机内 Reduce 再跨机 All-Reduce 减少跨机流量 拓扑感知调度 Ring/2D-Torus 匹配物理拓扑 提高带宽利用率 7. 高频面试题及答案
Q1: 请解释 All-Reduce 和 Ring-All-Reduce 的区别?(基础)
两者目标相同(让所有 GPU 拥有完整规约结果),但实现方式不同。朴素 All-Reduce 由一个节点收集聚合再广播,存在单点瓶颈;Ring-All-Reduce 通过环形拓扑分 Reduce-Scatter 和 All-Gather 两阶段完成,每卡通信量均为 2×(N-1)×Φ/N ≈ 2Φ,负载均衡无瓶颈。要点 说明 朴素方式 收集节点承载 (N-1)×Φ 通信量,成为瓶颈 Ring 方式 数据切 N 份,环形传递 N-1 轮,每轮传 Φ/N 总通信量 两者相近,约 2×N×Φ 关键区别 Ring 方式负载均匀分布,消除单点瓶颈 工程实践 几乎所有分布式框架默认使用 Ring-All-Reduce
C_{\text{Ring}} = 2 \times \frac{N-1}{N} \times \Phi \approx 2\Phi
$$Q2: 数据并行和张量并行的通信开销有什么区别?(基础)
数据并行通信量为 2Φ,仅在反向传播后同步梯度;张量并行通信量为 8×b×s×h(每层),在每层的前向和反向传播中都需要通信。前者与模型参数量成正比,后者与激活值大小成正比。对比维度 数据并行 张量并行 通信量 $2\Phi$(与模型大小相关) $8bsh$(与激活值大小相关) 通信时机 反向传播完成后 每层的前向和反向传播中 通信频率 每个训练步一次 每层都需要通信 通信操作 All-Reduce(梯度) All-Reduce(激活值) 适用场景 模型能放进单卡 模型太大需要层内切分 Q3: ZeRO-1、ZeRO-2、ZeRO-3 的区别及通信开销?(进阶)
ZeRO-1 切分优化器状态,ZeRO-2 额外切分梯度,ZeRO-3 进一步切分参数。ZeRO-1/2 单卡通信量为 2Φ(与 DP 相同),ZeRO-3 为 3Φ(增加 50%),但显存节省与 GPU 数量成正比。版本 分片内容 各GPU存储 单卡通信量 显存节省 ZeRO-1 优化器状态 完整模型+完整梯度+部分优化器 $2\Phi$ ~4× ZeRO-2 优化器状态+梯度 完整模型+部分梯度+部分优化器 $2\Phi$ ~8× ZeRO-3 全部 部分模型+部分梯度+部分优化器 $3\Phi$ ~N×
C_{\text{ZeRO-1}} = C_{\text{ZeRO-2}} = 2\Phi, \quad C_{\text{ZeRO-3}} = 3\Phi
$$Q4: 流水线并行的通信开销公式是什么?如何理解?(进阶)
公式为 2×(N-1)×b×s×h。N-1 是切分点数,b×s×h 是激活值大小,×2 因为前向和反向各传一次。通信方式为 P2P,仅相邻 GPU 之间传递。要点 说明 N-1 模型切到 N 张 GPU 上,有 N-1 个切分点 b×s×h 中间激活值大小(batch × sequence × hidden) ×2 前向传激活值 + 反向传梯度 通信方式 点对点(P2P),非集合通信 GPipe 优化 micro-batch 提高利用率,但不改变总通信量
C_{\text{PP}} = 2 \times (N-1) \times b \times s \times h
$$Q5: 什么是集合通信原语?请列举并解释主要原语(基础)
集合通信原语是分布式系统中多节点间数据传输的基本操作,共 8 种。按模式分为 1→N(Broadcast/Scatter)、N→1(Gather/Reduce)、N→N(All-Gather/Reduce-Scatter/All-Reduce/All-to-All)三类。原语 模式 是否规约 典型应用 Broadcast 1→N 否 参数初始化 Scatter 1→N 否 数据划分 Gather N→1 否 结果汇总 Reduce N→1 是 梯度汇总到主节点 All-Gather N→N 否 ZeRO-3 参数恢复 Reduce-Scatter N→N 是 ZeRO 梯度聚合 All-Reduce N→N 是 DP 梯度同步 All-to-All N→N 否 专家并行(MoE) Q6: 在万卡集群中,为什么说"优化器状态可以忽略不计"?(进阶)
在 3D 并行(DP×TP×PP)环境下,优化器状态被所有 GPU 分担。1024 卡集群中,优化器状态每卡只需存储 12Φ/1024 ≈ 0.012Φ,相比参数和梯度的 2Φ/(TP×PP) 占比极小。要点 说明 单卡显存公式 参数: 2Φ/(TP×PP) + 梯度: 2Φ/(TP×PP) + 优化器: 12Φ/(DP×TP×PP) 关键区别 参数和梯度被 TP×PP 份切分,优化器被所有 GPU 切分 1024卡实例 优化器每卡 0.012Φ,参数每卡约 0.125Φ(以 TP=4,PP=4 为例) 实际意义 万卡集群下无需特别优化优化器显存,重点关注激活值和通信
M_{\text{per\_gpu}} = \frac{2\Phi}{D_{\text{tp}} \times D_{\text{pp}}} + \frac{2\Phi}{D_{\text{tp}} \times D_{\text{pp}}} + \frac{12\Phi}{D_{\text{dp}} \times D_{\text{tp}} \times D_{\text{pp}}}
$$Q7: 张量并行中,前馈神经网络层为什么需要 2 次 All-Reduce?(进阶)
FFN 层包含两步线性变换(Y=GELU(XA), Z=Dropout(YB)),A 按列切分、B 按行切分。前向传播需 All-Reduce 聚合 Z,反向传播需 All-Reduce 聚合 ∂L/∂X,各一次共 2 次,通信量 4×b×s×h。要点 说明 切分方式 A 按列切分 [A₁, A₂],B 按行切分 [B₁; B₂] 前向 All-Reduce 聚合各 GPU 计算的 Z₁、Z₂ 得到完整输出 Z 反向 All-Reduce 聚合各 GPU 计算的 ∂L/∂X₁、∂L/∂X₂ 得到完整梯度 每次数据量 2×b×s×h(All-Reduce 等效于 RS+AG) 总通信量 4×b×s×h
C_{\text{FFN}} = 2 \times C_{\text{AllReduce}} = 2 \times 2bsh = 4bsh
$$Q8: 比较数据并行和 ZeRO 的通信开销(进阶)
ZeRO-1/2 通信量为 2Φ,与数据并行完全相同,因为 Reduce-Scatter + All-Gather = All-Reduce。ZeRO-3 通信量为 3Φ,多出的 1Φ 来自前向和反向传播各一次 All-Gather 收集参数。方法 通信量 通信操作 显存效果 数据并行 $2\Phi$ All-Reduce 无优化 ZeRO-1 $2\Phi$ Reduce-Scatter + All-Gather 节省 ~4× ZeRO-2 $2\Phi$ Reduce-Scatter + All-Gather 节省 ~8× ZeRO-3 $3\Phi$ 2×All-Gather + Reduce-Scatter 节省 ~N× Q9: 3D 并行是什么?如何计算其通信开销?(综合)
3D 并行组合数据并行(DP)、张量并行(TP)、流水线并行(PP)。TP 使用机内 NVLink 高带宽,PP 用于层间传递,DP 用于机间梯度同步。并行维度 通信操作 通信量 互联方式 TP(张量) All-Reduce 激活值 $8bsh$/层 NVLink(机内) PP(流水线) P2P 传递激活值 $2(D_{\text{pp}}-1)bsh$ 机内/跨机 DP(数据) All-Reduce 梯度 $\frac{2\Phi}{D_{\text{tp}} \times D_{\text{pp}}}$ 跨机网络 Q10: 如何优化大模型训练的通信效率?(进阶)
六大优化手段:通信计算重叠、梯度累积减少频率、梯度压缩减少数据量、Bucket 融合降低启动开销、分层 All-Reduce 减少跨机流量、拓扑感知调度匹配物理拓扑。优化手段 原理 效果 通信计算重叠 已完成层提前通信 隐藏延迟 梯度累积 多 micro-batch 累积后同步 通信频率降低 k 倍 梯度压缩 FP16→INT8 或 Top-K 通信量减少 50-90% Bucket 融合 小张量合并通信 降低启动开销 分层 All-Reduce 先机内后机间 减少跨机流量 异步通信 流水线隐藏延迟 提升吞吐量 Q11: Ring-All-Reduce 的通信步数和带宽利用率如何计算?(进阶)
Ring-All-Reduce 需要 2×(N-1) 步通信,每步传输 Φ/N 数据。在理想情况下,所有 GPU 在每一步都同时收发数据,带宽利用率接近 100%。延迟与 N 成正比,但单步传输量与 N 成反比,总传输量与 N 无关。指标 公式 说明 通信步数 $2 \times (N-1)$ RS 阶段 N-1 步 + AG 阶段 N-1 步 每步传输量 $\frac{\Phi}{N}$ 数据被切成 N 份 总传输量(单卡) $2 \times \frac{N-1}{N} \times \Phi$ 与 N 无关(N 大时约 2Φ) 带宽利用率 接近 100% 所有 GPU 同时收发 延迟 $O(N)$ 步数与 N 成正比
T_{\text{Ring}} = 2(N-1) \times \left( \alpha + \frac{\Phi}{N \times \beta} \right)
$$Q12: 如果同时使用 ZeRO-3 和张量并行,通信开销如何叠加?(综合)
两者的通信是独立叠加的。ZeRO-3 的 3Φ 通信量针对模型参数的收集和梯度聚合(跨数据并行组),张量并行的 8bsh/层通信量针对激活值的聚合(机内 TP 组)。实际中 ZeRO-3 通常与 DP 配合,而非与 TP 同时使用,因为 TP 本身已经切分了参数。场景 ZeRO-3 通信 TP 通信 合理性 ZeRO-3 + DP $3\Phi$(跨 DP 组) 无 常见配置,适合中小模型 TP + DP 无 $8bsh$/层(机内) + $2\Phi$(跨 DP 组) 常见配置,适合大模型 ZeRO-3 + TP $\frac{3\Phi}{D_{\text{tp}}}$ + $8bsh$/层 较少使用,通信量大 通常不推荐 3D 并行 + ZeRO-1 $2\Phi/(D_{\text{tp}} \times D_{\text{pp}})$ $8bsh$/层 + PP 通信 工业界主流 8. 大厂常见面试题
Q13: 在实际训练中,如何判断通信是否成为瓶颈?有哪些诊断方法?
通信瓶颈的核心判断指标是"计算开销/通信开销"的比值(Computation-to-Communication Ratio)。当通信时间占训练步时间的比例超过 30%,通常认为通信已成为瓶颈。诊断方法 工具 关键指标 NCCL 日志分析 NCCL_DEBUG=INFO All-Reduce 耗时、带宽利用率 Profiler 火焰图 PyTorch Profiler / Nsight 通信算子在时间线中的占比 吞吐量对比 实际 vs 理论 MFU 理论算力利用率 < 50% 时怀疑通信瓶颈 梯度累积实验 增大累积步数 如果吞吐量提升明显,说明通信是瓶颈 带宽测试 NCCL Tests (all_reduce_perf) 实测带宽 vs 理论带宽 Q14: DeepSpeed ZeRO 与 FSDP(PyTorch FullyShardedDataParallel)有什么异同?
两者核心思想一致——都将优化器状态、梯度和参数进行分片以节省显存。FSDP 是 PyTorch 原生实现,API 与 DDP 相似,易于迁移;DeepSpeed ZeRO 提供更细粒度的配置(Stage 1/2/3)和丰富的工程优化(offload、infinity 等)。对比维度 DeepSpeed ZeRO PyTorch FSDP 框架 微软独立框架 PyTorch 原生 分片粒度 Stage 1/2/3 分级控制 类似 ZeRO-3,统一分片 CPU Offload 支持(ZeRO-Infinity) 支持(cpu_offload) 通信后端 NCCL / 自定义 NCCL 混合精度 FP16/BF16/FP8 FP16/BF16 社区生态 HuggingFace Accelerate 集成 PyTorch 原生,兼容性最好 Q15: 训练 70B 参数量的模型,如何设计并行策略和通信方案?请给出具体配置。
以 128 张 A100-80G 为例,推荐采用 3D 并行:TP=8(机内 NVLink),PP=2(跨机),DP=8(机间),配合 ZeRO-1 切分优化器状态。配置项 值 理由 张量并行 TP 8 一台机器 8 卡,NVLink 互联带宽 900 GB/s 流水线并行 PP 2 70B/8(TP)=8.75B/卡,2 级流水线后约 4.4B/卡,显存可控 数据并行 DP 8 128/(8×2)=8,提升训练吞吐量 ZeRO Stage 1 切分优化器状态,通信量不增加 梯度累积 4-8 步 减少 DP 通信频率 序列长度 4096 平衡显存和通信
C_{\text{TP}} = 8 \times b \times 4096 \times 8192 \times 2\text{B} \approx 4\text{ GB/层}
$$
C_{\text{PP}} = 2 \times 1 \times b \times 4096 \times 8192 \times 2\text{B} \approx 0.5\text{ GB}
$$
C_{\text{DP}} = \frac{2 \times 70\text{B} \times 2\text{B}}{8 \times 2} = 17.5\text{ GB}
$$总结
核心知识点回顾
知识点 核心公式/结论 关键理解 集合通信原语 8 种基础原语 All-Reduce = RS + AG 是最核心的组合 数据并行 $C = 2\Phi$ Ring 方式消除单点瓶颈,负载均衡 张量并行 $C = 8bsh$/层 每层 4 次 All-Reduce,适合机内高带宽 流水线并行 $C = 2(N-1)bsh$ P2P 通信,量小但有气泡问题 ZeRO-1/2 $C = 2\Phi$ 通信量与 DP 相同,显存分别省 4×/8× ZeRO-3 $C = 3\Phi$ +50% 通信换取 N× 显存节省 3D 并行 TP+PP+DP 组合 TP 机内 → PP 跨层 → DP 机间 通信优化 6 大手段 重叠、累积、压缩、融合、分层、拓扑 思维导图结构
大模型训练通信开销
├── 集合通信原语
│ ├── 1→N:Broadcast / Scatter
│ ├── N→1:Gather / Reduce
│ └── N→N:All-Gather / Reduce-Scatter / All-Reduce / All-to-All
├── 数据并行 (单卡 2Φ)
│ ├── 朴素 All-Reduce:单点瓶颈
│ └── Ring-All-Reduce:均衡负载 = RS + AG
├── 张量并行 (单层 8bsh)
│ ├── FFN:A 列切 + B 行切 → 4bsh
│ ├── Attention → 4bsh
│ └── 适合机内 NVLink
├── 流水线并行 (2(N-1)bsh)
│ ├── 按层切分,P2P 传递
│ ├── GPipe micro-batch 减少气泡
│ └── 适合跨机通信
├── ZeRO 优化
│ ├── ZeRO-1:切分优化器 → 2Φ
│ ├── ZeRO-2:+切分梯度 → 2Φ
│ └── ZeRO-3:+切分参数 → 3Φ
├── 3D 并行
│ ├── TP 机内 (NVLink)
│ ├── PP 跨层 (P2P)
│ └── DP 机间 (All-Reduce)
└── 通信优化
├── 重叠 / 累积 / 压缩
└── 融合 / 分层 / 拓扑参考文献