Base Class and Custom Engines

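The weight transfer system is built on an abstract base class that defines the contract between vLLM's worker infrastructure and the transfer backend. To implement a custom backend, subclass WeightTransferEngine and register it with WeightTransferEngineFactory.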

WeightTransferEngine

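WeightTransferEngine is a generic abstract class parameterized by two dataclass types: TInitInfo (a subclass of WeightTransferInitInfo) and TUpdateInfo (a subclass of WeightTransferUpdateInfo).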

Abstract Methods

Subclasses must implement the following four methods:

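| Method | Side | Description |
| --- | --- | --- |
| init_transfer_engine(init_info) | Inference | Initializes the communication channel on each inference worker |
| receive_weights(update_info, load_weights) | Inference | Receives weights and calls load_weights incrementally |
| shutdown() | Inference | Cleans up resources |
| trainer_send_weights(iterator, trainer_args) | Trainer | Static method that sends weights from the trainer process |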
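
The contract in the table above can be pictured with a small standard-library sketch. This is a simplified stand-in, not the vLLM source: EngineContract, InitInfoBase, UpdateInfoBase, Incomplete, and can_instantiate are hypothetical names, and weights are modelled as (name, payload) tuples instead of torch tensors.

```python
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from typing import Any, Generic, TypeVar


@dataclass
class InitInfoBase:
    """Hypothetical stand-in for WeightTransferInitInfo."""


@dataclass
class UpdateInfoBase:
    """Hypothetical stand-in for WeightTransferUpdateInfo."""


TInitInfo = TypeVar("TInitInfo", bound=InitInfoBase)
TUpdateInfo = TypeVar("TUpdateInfo", bound=UpdateInfoBase)

# A weight is modelled as (name, payload); the real API uses torch.Tensor.
Weight = tuple[str, Any]


class EngineContract(ABC, Generic[TInitInfo, TUpdateInfo]):
    """Simplified illustration of the four-method contract."""

    @abstractmethod
    def init_transfer_engine(self, init_info: TInitInfo) -> None: ...

    @abstractmethod
    def receive_weights(
        self,
        update_info: TUpdateInfo,
        load_weights: Callable[[list[Weight]], None],
    ) -> None: ...

    @abstractmethod
    def shutdown(self) -> None: ...

    @staticmethod
    @abstractmethod
    def trainer_send_weights(
        iterator: Iterator[Weight], trainer_args: dict[str, Any]
    ) -> None: ...


class Incomplete(EngineContract[InitInfoBase, UpdateInfoBase]):
    """Implements only one of the four methods, so it stays abstract."""

    def init_transfer_engine(self, init_info: InitInfoBase) -> None:
        pass


def can_instantiate(cls: type) -> bool:
    """Return False when Python refuses to instantiate an abstract class."""
    try:
        cls()
        return True
    except TypeError:
        return False
```

In this sketch a subclass that leaves any of the four methods unimplemented cannot be instantiated, which is the usual way an abstract base class enforces such a contract.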

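Request Classes

The API-level request classes use plain dictionaries, which keeps serialization backend-agnostic. The engine's parse_init_info and parse_update_info methods convert these dictionaries into typed dataclasses.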
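
As a rough illustration of the dictionary-to-dataclass step, the sketch below converts a plain dict into a typed dataclass and rejects unknown or missing keys. It does not reproduce how parse_init_info is implemented in vLLM; ExampleInitInfo and parse_info are hypothetical names, and the strict key validation is an assumption made for the example.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class ExampleInitInfo:
    """Hypothetical typed init info for an illustrative backend."""

    master_address: str
    master_port: int


def parse_info(info_cls: type, raw: dict[str, Any]):
    """Build info_cls from a plain dict, rejecting unknown or missing keys."""
    known = {f.name for f in fields(info_cls)}
    unknown = set(raw) - known
    if unknown:
        raise ValueError(
            f"unexpected keys for {info_cls.__name__}: {sorted(unknown)}"
        )
    try:
        return info_cls(**raw)
    except TypeError as exc:  # raised when a required field is missing
        raise ValueError(f"invalid {info_cls.__name__}: {exc}") from exc


init_info = parse_info(
    ExampleInitInfo, {"master_address": "10.0.0.1", "master_port": 29500}
)
```

Keeping the wire format as a dict means the API layer never needs to import a backend's dataclasses; only the selected engine knows which fields it expects.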

from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)

# Init request (dict is converted to backend-specific TInitInfo)
init_request = WeightTransferInitRequest(
    init_info={"master_address": "10.0.0.1", "master_port": 29500, ...}
)

# Update request (dict is converted to backend-specific TUpdateInfo)
update_request = WeightTransferUpdateRequest(
    update_info={"names": [...], "dtype_names": [...], "shapes": [...]}
)

WeightTransferUpdateInfo

The base WeightTransferUpdateInfo contains an is_checkpoint_format flag.

@dataclass
class WeightTransferUpdateInfo(ABC):
    is_checkpoint_format: bool = True

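When is_checkpoint_format=True (the default), vLLM applies layer-wise weight processing (repacking, renaming, etc.) to the received weights before loading them. Set it to False if the trainer has already converted the weights to the kernel format the model expects.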

Implementing a Custom Engine

To create a custom weight transfer backend:

1. Define the info dataclasses

from dataclasses import dataclass
from vllm.distributed.weight_transfer.base import (
    WeightTransferEngine,
    WeightTransferInitInfo,
    WeightTransferUpdateInfo,
)

@dataclass
class MyInitInfo(WeightTransferInitInfo):
    endpoint: str
    token: str

@dataclass
class MyUpdateInfo(WeightTransferUpdateInfo):
    names: list[str]
    dtype_names: list[str]
    shapes: list[list[int]]
    # Add custom fields as needed

2. Implement the engine

from collections.abc import Callable, Iterator
from typing import Any
import torch

class MyWeightTransferEngine(WeightTransferEngine[MyInitInfo, MyUpdateInfo]):
    init_info_cls = MyInitInfo
    update_info_cls = MyUpdateInfo

    def init_transfer_engine(self, init_info: MyInitInfo) -> None:
        # Set up connection to trainer using init_info.endpoint, etc.
        ...

    def receive_weights(
        self,
        update_info: MyUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        # Receive each weight and call load_weights incrementally
        for name, dtype_name, shape in zip(
            update_info.names, update_info.dtype_names, update_info.shapes
        ):
            dtype = getattr(torch, dtype_name)
            # _fetch_weight is a placeholder for your transport's receive logic
            weight = self._fetch_weight(name, shape, dtype)
            load_weights([(name, weight)])

    def shutdown(self) -> None:
        # Clean up resources
        ...

    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        trainer_args: dict[str, Any],
    ) -> None:
        # Send weights from the trainer process
        for name, tensor in iterator:
            # Send tensor via custom transport
            ...

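Important

Call the load_weights callable passed to receive_weights incrementally (one weight or a few weights at a time) rather than accumulating all weights first. This avoids GPU out-of-memory (OOM) errors with large models.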
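
A minimal sketch of the incremental pattern, using only the standard library. The names fake_transport and receive_incrementally are hypothetical, bytes objects stand in for tensors, and the batch size of 2 is an arbitrary choice for the example.

```python
from collections.abc import Callable, Iterable, Iterator
from itertools import islice

Weight = tuple[str, bytes]  # stand-in for (name, torch.Tensor)


def fake_transport(count: int) -> Iterator[Weight]:
    """Yield weights one at a time, as a streaming transport would."""
    for i in range(count):
        yield (f"layer.{i}.weight", bytes(4))


def receive_incrementally(
    incoming: Iterable[Weight],
    load_weights: Callable[[list[Weight]], None],
    batch_size: int = 2,
) -> None:
    """Hand over at most batch_size weights per load_weights call."""
    stream = iter(incoming)
    # islice pulls only the next few items, so no more than one batch
    # is held by this function at any time.
    while batch := list(islice(stream, batch_size)):
        load_weights(batch)


batch_sizes: list[int] = []
receive_incrementally(fake_transport(5), lambda b: batch_sizes.append(len(b)))
```

Here five weights arrive in batches of 2, 2, and 1. Because each batch goes out of scope once load_weights returns, the receiver holds at most one batch at a time instead of the whole model.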

3. Register with the factory

from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory

# Option 1: Lazy loading (recommended for built-in engines)
WeightTransferEngineFactory.register_engine(
    "my_backend",
    "my_package.my_module",
    "MyWeightTransferEngine",
)

# Option 2: Direct class registration
WeightTransferEngineFactory.register_engine(
    "my_backend",
    MyWeightTransferEngine,
)

Once registered, users can select your backend with WeightTransferConfig(backend="my_backend").

WeightTransferEngineFactory

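The factory uses a registry pattern with lazy loading. The built-in engines (nccl and ipc) are registered at import time, but their modules are loaded only when the backend is actually requested. This avoids importing heavy dependencies (such as the NCCL communication library) when they are not needed.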

from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory

# Create an engine from config
engine = WeightTransferEngineFactory.create_engine(
    config=weight_transfer_config,
    parallel_config=parallel_config,
)
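
The lazy registry pattern the factory is described as using can be sketched roughly as follows. This is not the WeightTransferEngineFactory implementation: LazyRegistry, register, and resolve are hypothetical names, rejecting duplicate names is a design choice of the sketch, and standard-library classes stand in for engine classes.

```python
import importlib
from typing import Optional, Union


class LazyRegistry:
    """Maps a name to a class, or to a (module path, class name) pair."""

    def __init__(self) -> None:
        self._entries: dict[str, Union[type, tuple[str, str]]] = {}

    def register(
        self,
        name: str,
        target: Union[type, str],
        class_name: Optional[str] = None,
    ) -> None:
        if name in self._entries:
            raise ValueError(f"backend {name!r} is already registered")
        if isinstance(target, str):
            if class_name is None:
                raise ValueError("class_name is required with a module path")
            # Store only strings; nothing is imported at registration time.
            self._entries[name] = (target, class_name)
        else:
            self._entries[name] = target

    def resolve(self, name: str) -> type:
        entry = self._entries[name]
        if isinstance(entry, tuple):
            module_path, class_name = entry
            # The import happens on first use of this backend.
            module = importlib.import_module(module_path)
            entry = getattr(module, class_name)
            self._entries[name] = entry  # cache the resolved class
        return entry


registry = LazyRegistry()
registry.register("ordered", "collections", "OrderedDict")  # lazy form
registry.register("plain", dict)  # direct class form
```

Registering by module path costs only two strings until resolve is called, which is why this form suits backends whose modules pull in heavy dependencies.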