标签 分布式训练 下的文章

引言

在深度学习模型日益庞大的今天,单机训练已难以满足效率需求。如何高效利用多设备(如多 GPU 或昇腾 NPU)进行分布式训练,成为工业界的核心挑战。

而 MindSpore提供了一种革命性的解决方案:自动并行(Auto Parallel)—— 开发者只需关注模型逻辑,框架自动完成数据/模型/流水线并行策略的生成与优化。配合其 动静统一的执行模式,既保留了动态图的调试灵活性,又具备静态图的高性能推理能力。

本文将带你深入这两个核心特性,并通过一个实际案例演示如何在多设备上轻松实现分布式训练。

一、动静统一:PyNative 与 Graph 模式的无缝切换

1.1 什么是动静统一?

  • PyNative 模式:类似 PyTorch,逐行执行,便于调试(支持 print、断点等)。
  • Graph 模式:将整个网络编译为计算图,执行效率高,适合部署。

MindSpore 允许你在同一个项目中自由切换两种模式:

import mindspore as ms

# 默认是 Graph 模式
ms.set_context(mode=ms.GRAPH_MODE)

# 切换到 PyNative 模式(用于调试)
ms.set_context(mode=ms.PYNATIVE_MODE)

1.2 调试技巧:先 PyNative,后 Graph

推荐开发流程:

  1. 在 PyNative 模式下编写和调试模型;
  2. 确认无误后,切换到 Graph 模式进行训练或推理,获得更高性能。
💡 注意:Graph 模式对控制流(如 if/for)有语法限制,但 MindSpore 提供了 @ms.jit和 ops.depend等机制来兼容复杂逻辑。

二、自动并行:让分布式训练“零门槛”

传统分布式训练需要手动设计数据切分、梯度同步、通信策略(如 AllReduce),代码复杂且易错。而 MindSpore 的 自动并行技术通过 策略搜索 + 图编译优化,自动生成最优并行方案。

2.1 启用自动并行的三步走

  1. 配置设备环境(如 8 卡 Ascend 或 GPU);
  2. 设置并行上下文;
  3. 使用 Model高阶 API 或手动构建训练流程。

2.2 实战:ResNet50 在 ImageNet 上的自动并行训练

以下是一个简化版的自动并行训练脚本(适用于 Ascend 910 或多 GPU):

import mindspore as ms
from mindspore import nn, Model
from mindspore.communication import init, get_rank, get_group_size
from mindspore.nn.optim import Momentum
from src.dataset import create_dataset  # 假设你有 ImageNet 数据加载器
from src.network import resnet50        # 自定义 ResNet50 网络

# 1. 初始化分布式环境
init()  # 自动检测 backend(HCCL for Ascend, NCCL for GPU)
rank_id = get_rank()
device_num = get_group_size()

# 2. 设置自动并行模式
ms.set_auto_parallel_context(
    device_num=device_num,
    parallel_mode=ms.ParallelMode.AUTO_PARALLEL,
    gradients_mean=True
)

# 3. 构建数据集(自动按 rank 切分)
dataset = create_dataset(
    dataset_path="/path/to/imagenet",
    do_train=True,
    batch_size=32,
    device_num=device_num,
    rank=rank_id
)

# 4. 定义网络与损失
network = resnet50(class_num=1000)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = Momentum(
    network.trainable_params(),
    learning_rate=0.01,
    momentum=0.9
)

# 5. 使用 Model 高阶 API(自动处理并行逻辑)
model = Model(network, loss_fn=loss_fn, optimizer=optimizer)

# 6. 开始训练
model.train(epoch=90, train_dataset=dataset, dataset_sink_mode=True)
✅ 关键点:你不需要写任何通信代码!MindSpore 会根据硬件拓扑和模型结构,自动选择数据并行、模型并行或混合并行策略。

2.3 性能对比:自动 vs 手动并行

在华为内部测试中,ResNet50 在 8×Ascend 910 上:

  • 手动数据并行:吞吐 ~8500 images/sec
  • MindSpore 自动并行:吞吐 ~9200 images/sec(自动融合通信与计算)

这得益于其 图算融合与 通信算子自动插入技术。

三、为什么选择 MindSpore 的自动并行?

特性传统框架(如 PyTorch DDP)MindSpore Auto Parallel
编程复杂度高(需手动管理进程、同步)极低(一行配置)
并行策略仅支持数据并行支持数据/模型/流水线/混合并行
硬件适配依赖 NCCL原生优化昇腾,也支持 GPU/CPU
扩展性难以扩展到千卡已验证万卡集群训练

结语

MindSpore 不仅仅是一个“另一个深度学习框架”,它代表了一种 以编译器为中心、软硬协同的新范式。通过 自动并行和 动静统一,它大幅降低了大规模 AI 开发的门槛,尤其适合需要高性能、高可扩展性的工业场景。

vLLM 是一款专为大语言模型推理加速而设计的框架,实现了 KV 缓存内存几乎零浪费,解决了内存管理瓶颈问题。

更多 vLLM 中文文档及教程可访问 →https://vllm.hyper.ai/

*在线运行 vLLM 入门教程:零基础分步指南

源码 examples/offline_inference/rlhf_utils.py

import torch


def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):

    """
    vLLM 提供 `StatelessProcessGroup` 来创建进程组,
    无需考虑 torch.distributed 中的全局进程组。
    建议先创建 `StatelessProcessGroup`,然后初始化
    外部(训练进程)与 vLLM 工作进程之间的数据平面通信(NCCL)。
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup
    pg = StatelessProcessGroup.create(host=master_address,
                                      port=master_port,
                                      rank=rank,
                                      world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


class WorkerExtension:

    """
    vLLM 工作进程的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        """
        检查权重是否已更新为 0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated


class ColocateWorkerExtension:

    """
    vLLM 工作进程在协同部署场景下的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        handles = ipc_handles[self.device_uuid]
        device_id = self.device.index
        weights = []
        for name, handle in handles.items():
            func, args = handle
            list_args = list(args)
            # the key is to change device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            # 关键是将设备 ID 改为当前设备 ID,
            # 以防两个进程有不同的 CUDA_VISIBLE_DEVICES
            list_args[6] = device_id
            tensor = func(*list_args)
            weights.append((name, tensor))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):

        """
        检查权重是否已更新为0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated

vLLM 是一款专为大语言模型推理加速而设计的框架,实现了 KV 缓存内存几乎零浪费,解决了内存管理瓶颈问题。

更多 vLLM 中文文档及教程可访问 →https://vllm.hyper.ai/

*在线运行 vLLM 入门教程:零基础分步指南

源码 examples/offline_inference/rlhf_utils.py

import torch


def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):

    """
    vLLM 提供 `StatelessProcessGroup` 来创建进程组,
    无需考虑 torch.distributed 中的全局进程组。
    建议先创建 `StatelessProcessGroup`,然后初始化
    外部(训练进程)与 vLLM 工作进程之间的数据平面通信(NCCL)。
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup
    pg = StatelessProcessGroup.create(host=master_address,
                                      port=master_port,
                                      rank=rank,
                                      world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


class WorkerExtension:

    """
    vLLM 工作进程的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        """
        检查权重是否已更新为 0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated


class ColocateWorkerExtension:

    """
    vLLM 工作进程在协同部署场景下的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        handles = ipc_handles[self.device_uuid]
        device_id = self.device.index
        weights = []
        for name, handle in handles.items():
            func, args = handle
            list_args = list(args)
            # the key is to change device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            # 关键是将设备 ID 改为当前设备 ID,
            # 以防两个进程有不同的 CUDA_VISIBLE_DEVICES
            list_args[6] = device_id
            tensor = func(*list_args)
            weights.append((name, tensor))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):

        """
        检查权重是否已更新为0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated