跳到内容

NCCL 引擎

NCCL 权重传输引擎使用 NCCL 广播操作将权重从训练器传输到推理工作节点。它支持多节点多 GPU 设置,其中训练器和推理引擎运行在不同的 GPU 上。

何时使用 NCCL

  • 训练和推理运行在不同的 GPU 上(可能跨节点)
  • 需要多个工作节点进行张量并行推理,且所有节点都需要更新后的权重
  • 需要通过 NVLink 或 InfiniBand 进行高带宽、低延迟的权重传输

工作原理

  1. 训练器和所有推理工作节点使用 StatelessProcessGroup(vLLM 独立于 torch.distributed 的组抽象)加入共享的 NCCL 进程组。
  2. 训练器将权重同时广播给所有工作节点。每个工作节点以增量方式接收并加载权重。
  3. 可选的打包张量广播通过双倍/三倍缓冲和 CUDA 流重叠,将多个小张量批处理为更大的缓冲区,以实现更高的吞吐量。此实现基于 NeMo-RL 的打包张量

初始化

NCCL 需要显式的进程组设置。训练器和推理工作节点必须商定主地址、端口和全局大小(world size)。

推理侧

from vllm.distributed.weight_transfer.base import WeightTransferInitRequest

# rank_offset accounts for the trainer occupying rank 0
llm.init_weight_transfer_engine(
    WeightTransferInitRequest(
        init_info=dict(
            master_address=master_address,
            master_port=master_port,
            rank_offset=1,
            world_size=world_size,  # trainer + all inference workers
        )
    )
)

训练侧

from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLWeightTransferEngine,
)

group = NCCLWeightTransferEngine.trainer_init(
    dict(
        master_address=master_address,
        master_port=master_port,
        world_size=world_size,
    )
)

注意

trainer_init 始终将训练器分配为 rank 0。推理工作节点从 rank_offset(通常为 1)开始。

发送权重

from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLTrainerSendWeightsArgs,
    NCCLWeightTransferEngine,
)

trainer_args = NCCLTrainerSendWeightsArgs(
    group=group,
    packed=True,  # use packed broadcasting for efficiency
)

NCCLWeightTransferEngine.trainer_send_weights(
    iterator=model.named_parameters(),
    trainer_args=trainer_args,
)

完整可配置字段列表请参阅 NCCLTrainerSendWeightsArgs

打包张量广播

packed=True 时,多个权重张量在广播前被打包成大的连续缓冲区。这减少了 NCCL 操作的数量,并使用带有专用 CUDA 流的双倍/三倍缓冲,以实现打包、广播和解包之间的重叠。

训练侧(NCCLTrainerSendWeightsArgs)和推理侧(NCCLWeightTransferUpdateInfo)必须使用匹配的 packed_buffer_size_bytespacked_num_buffers 值。

接收权重(推理侧)

推理侧通过调用 update_weights 触发权重接收。

from vllm.distributed.weight_transfer.base import WeightTransferUpdateRequest

llm.update_weights(
    WeightTransferUpdateRequest(
        update_info=dict(
            names=names,
            dtype_names=dtype_names,
            shapes=shapes,
            packed=True,
        )
    )
)

namesdtype_namesshapes 列表描述了每个参数。这些必须与训练器遍历其参数的顺序一致。

示例