Base Class and Custom Engines
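The weight transfer system is built around an abstract base class that defines the contract between vLLM's worker infrastructure and the transport backends. You can implement a custom backend by subclassing WeightTransferEngine and registering it with WeightTransferEngineFactory.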
WeightTransferEngine
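WeightTransferEngine is a generic abstract class parameterized by two dataclass types:

- TInitInfo (extends WeightTransferInitInfo): backend-specific initialization parameters.
- TUpdateInfo (extends WeightTransferUpdateInfo): backend-specific metadata for a weight update.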
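The sketch below is a simplified, self-contained stand-in that shows how the two type parameters bind an engine to its info dataclasses. BaseInitInfo, BaseUpdateInfo, MiniEngine, TcpInitInfo, TcpUpdateInfo, TcpEngine, and engine_type_args are all hypothetical names. They only mirror the pattern described above and are not vLLM's actual source.

```python
from abc import ABC
from dataclasses import dataclass, field
from typing import Generic, TypeVar, get_args


@dataclass
class BaseInitInfo:
    """Hypothetical stand-in for WeightTransferInitInfo."""


@dataclass
class BaseUpdateInfo:
    """Hypothetical stand-in for WeightTransferUpdateInfo."""

    is_checkpoint_format: bool = True


# Each type parameter is bound to its base dataclass, so a type checker
# rejects engines parameterized with unrelated types.
TInit = TypeVar("TInit", bound=BaseInitInfo)
TUpdate = TypeVar("TUpdate", bound=BaseUpdateInfo)


class MiniEngine(ABC, Generic[TInit, TUpdate]):
    # Concrete engines record their dataclasses so that untyped input
    # can be converted into them at runtime.
    init_info_cls: type[TInit]
    update_info_cls: type[TUpdate]


@dataclass
class TcpInitInfo(BaseInitInfo):
    host: str
    port: int


@dataclass
class TcpUpdateInfo(BaseUpdateInfo):
    # The base class already defines a field with a default, so fields
    # added here also need defaults (standard dataclass ordering rule).
    names: list[str] = field(default_factory=list)


class TcpEngine(MiniEngine[TcpInitInfo, TcpUpdateInfo]):
    init_info_cls = TcpInitInfo
    update_info_cls = TcpUpdateInfo


def engine_type_args(engine_cls: type) -> tuple:
    # Recover (TInit, TUpdate) from the parameterized base class.
    return get_args(engine_cls.__orig_bases__[0])
```

Note the comment on TcpUpdateInfo: when a base dataclass declares a field with a default, any non-default field added in a subclass raises TypeError at class creation, which is worth keeping in mind when extending info classes.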
Abstract Methods
Subclasses must implement the following four methods:
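| Method | Side | Description |
|---|---|---|
| init_transfer_engine(init_info) | Inference | Initializes the communication channel on each inference worker |
| receive_weights(update_info, load_weights) | Inference | Receives weights and calls load_weights incrementally |
| shutdown() | Inference | Cleans up resources |
| trainer_send_weights(iterator, trainer_args) | Trainer | Static method that sends weights from the training process |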
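The four-method contract can be mimicked with abc.abstractmethod, which makes Python refuse to instantiate any subclass that leaves a method unimplemented. EngineContract, Incomplete, Complete, and can_instantiate are hypothetical names, and the signatures are simplified (plain Any instead of tensors); this is an illustration of the pattern, not vLLM's base class.

```python
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterator
from typing import Any

LoadFn = Callable[[list[tuple[str, Any]]], None]


class EngineContract(ABC):
    """Hypothetical sketch of the four-method contract."""

    @abstractmethod
    def init_transfer_engine(self, init_info: Any) -> None:
        """Inference side: open the channel on each worker."""

    @abstractmethod
    def receive_weights(self, update_info: Any, load_weights: LoadFn) -> None:
        """Inference side: receive weights, calling load_weights incrementally."""

    @abstractmethod
    def shutdown(self) -> None:
        """Inference side: release resources."""

    # abstractmethod must be the innermost decorator when combined with staticmethod.
    @staticmethod
    @abstractmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, Any]], trainer_args: dict[str, Any]
    ) -> None:
        """Trainer side: push weights from the training process."""


class Incomplete(EngineContract):
    # trainer_send_weights is missing, so this class is still abstract.
    def init_transfer_engine(self, init_info: Any) -> None: ...
    def receive_weights(self, update_info: Any, load_weights: LoadFn) -> None: ...
    def shutdown(self) -> None: ...


class Complete(Incomplete):
    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, Any]], trainer_args: dict[str, Any]
    ) -> None:
        for _name, _tensor in iterator:
            pass  # a real backend would transmit each tensor here


def can_instantiate(cls: type) -> bool:
    try:
        cls()
    except TypeError:
        return False
    return True
```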
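Request Classes
Request classes at the API layer carry plain dicts, which keeps serialization backend-agnostic. The engine's parse_init_info and parse_update_info methods convert these dicts into typed dataclasses.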
```python
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)

# Init request (dict is converted to backend-specific TInitInfo)
init_request = WeightTransferInitRequest(
    init_info={"master_address": "10.0.0.1", "master_port": 29500, ...}
)

# Update request (dict is converted to backend-specific TUpdateInfo)
update_request = WeightTransferUpdateRequest(
    update_info={"names": [...], "dtype_names": [...], "shapes": [...]}
)
```
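One plausible way to turn such a dict into a typed dataclass is to check its keys against the dataclass fields and pass it as keyword arguments. The helper parse_info and ExampleInitInfo below are hypothetical; vLLM's parse_init_info and parse_update_info may validate differently.

```python
from dataclasses import dataclass, fields
from typing import Any, TypeVar

T = TypeVar("T")


@dataclass
class ExampleInitInfo:
    # Hypothetical backend-specific init parameters.
    master_address: str
    master_port: int


def parse_info(info_cls: type[T], raw: dict[str, Any]) -> T:
    """Convert a plain dict into an instance of the dataclass info_cls."""
    allowed = {f.name for f in fields(info_cls)}
    unknown = set(raw) - allowed
    if unknown:
        # Reject typos early instead of silently dropping keys.
        raise ValueError(f"unknown fields for {info_cls.__name__}: {sorted(unknown)}")
    # Missing required fields surface as a TypeError from the constructor.
    return info_cls(**raw)
```

For example, parse_info(ExampleInitInfo, {"master_address": "10.0.0.1", "master_port": 29500}) returns an ExampleInitInfo whose attributes can then be used with full type information.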
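WeightTransferUpdateInfo
The base WeightTransferUpdateInfo contains an is_checkpoint_format flag.
When is_checkpoint_format=True (the default), vLLM applies layer-wise weight processing (repacking, renaming, etc.) to the received weights before loading them. If the trainer has already converted the weights into the kernel format the model expects, set it to False.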
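The effect of the flag can be pictured as a branch on the receive path. In this sketch, UpdateInfo and load_received are hypothetical, and process_checkpoint_weights is a placeholder for vLLM's layer-wise processing: it only strips a "model." prefix so the step is visible, whereas the real processing also covers repacking and other transformations.

```python
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

Weights = list[tuple[str, Any]]


@dataclass
class UpdateInfo:
    # Mirrors the documented flag; True (the default) means checkpoint format.
    is_checkpoint_format: bool = True


def process_checkpoint_weights(weights: Weights) -> Weights:
    # Hypothetical stand-in for layer-wise processing (repacking, renaming, ...).
    return [(name.removeprefix("model."), value) for name, value in weights]


def load_received(
    info: UpdateInfo, weights: Weights, load_weights: Callable[[Weights], None]
) -> None:
    if info.is_checkpoint_format:
        weights = process_checkpoint_weights(weights)
    # With is_checkpoint_format=False the trainer already sent kernel-format
    # weights, so they are loaded unchanged.
    load_weights(weights)
```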
Implementing a Custom Engine
To create a custom weight transfer backend:
1. Define the Info Dataclasses
```python
from dataclasses import dataclass

from vllm.distributed.weight_transfer.base import (
    WeightTransferEngine,
    WeightTransferInitInfo,
    WeightTransferUpdateInfo,
)


@dataclass
class MyInitInfo(WeightTransferInitInfo):
    endpoint: str
    token: str


@dataclass
class MyUpdateInfo(WeightTransferUpdateInfo):
    names: list[str]
    dtype_names: list[str]
    shapes: list[list[int]]
    # Add custom fields as needed
```
2. Implement the Engine
```python
from collections.abc import Callable, Iterator
from typing import Any

import torch


class MyWeightTransferEngine(WeightTransferEngine[MyInitInfo, MyUpdateInfo]):
    init_info_cls = MyInitInfo
    update_info_cls = MyUpdateInfo

    def init_transfer_engine(self, init_info: MyInitInfo) -> None:
        # Set up connection to trainer using init_info.endpoint, etc.
        ...

    def receive_weights(
        self,
        update_info: MyUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        # Receive each weight and call load_weights incrementally
        for name, dtype_name, shape in zip(
            update_info.names, update_info.dtype_names, update_info.shapes
        ):
            # dtype_names must be torch attribute names such as "bfloat16"
            dtype = getattr(torch, dtype_name)
            weight = self._fetch_weight(name, shape, dtype)
            load_weights([(name, weight)])

    def _fetch_weight(
        self, name: str, shape: list[int], dtype: torch.dtype
    ) -> torch.Tensor:
        # Your own transport helper (not part of the base-class contract):
        # receive one tensor with the given name, shape, and dtype.
        raise NotImplementedError

    def shutdown(self) -> None:
        # Clean up resources
        ...

    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        trainer_args: dict[str, Any],
    ) -> None:
        # Send weights from the trainer process
        for name, tensor in iterator:
            # Send tensor via custom transport
            ...
```
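Important
The load_weights callable passed to receive_weights should be called incrementally (with one or a few weights at a time) rather than after accumulating all weights. This avoids GPU out-of-memory (OOM) errors with large models.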
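If one call per weight is too chatty for your transport, a middle ground is to group weights into small batches under a fixed size budget, which bounds how much received data is held before load_weights runs. batched_by_bytes and the budget are hypothetical; sizes are passed as plain byte counts so the sketch does not depend on torch (for a real tensor t, the size is t.numel() * t.element_size()).

```python
from collections.abc import Iterable, Iterator
from typing import Any


def batched_by_bytes(
    weights: Iterable[tuple[str, Any, int]], max_bytes: int
) -> Iterator[list[tuple[str, Any]]]:
    """Yield (name, value) batches whose summed size stays within max_bytes.

    Each input item is (name, value, size_in_bytes). A single weight larger
    than max_bytes is still yielded, alone in its own batch.
    """
    batch: list[tuple[str, Any]] = []
    used = 0
    for name, value, size in weights:
        # Flush before this weight would push the batch over budget.
        if batch and used + size > max_bytes:
            yield batch
            batch, used = [], 0
        batch.append((name, value))
        used += size
    if batch:
        yield batch
```

Inside receive_weights, each yielded batch would be passed straight to load_weights(batch), so peak buffered data stays near max_bytes (or the size of the largest single weight, whichever is greater).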
3. Register with the Factory
```python
from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory

# Option 1: Lazy loading (recommended for built-in engines)
WeightTransferEngineFactory.register_engine(
    "my_backend",
    "my_package.my_module",
    "MyWeightTransferEngine",
)

# Option 2: Direct class registration
WeightTransferEngineFactory.register_engine(
    "my_backend",
    MyWeightTransferEngine,
)
```
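Once registered, users can select your backend with WeightTransferConfig(backend="my_backend").
WeightTransferEngineFactory
The factory uses a registry pattern with lazy loading. The built-in engines (nccl and ipc) are registered at import time, but their modules are loaded only when that backend is actually requested. This avoids importing heavy dependencies (such as the NCCL communication library) when they are not needed.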
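A minimal sketch of the registry-with-lazy-loading pattern, assuming each registration stores either a class or a (module path, class name) pair that is resolved with importlib only on first lookup. LazyRegistry is hypothetical and is not WeightTransferEngineFactory's implementation; the example registers the standard library's fractions.Fraction lazily only to show the deferred import.

```python
import importlib
from typing import Any, Optional


class LazyRegistry:
    """Hypothetical name -> class registry with deferred imports."""

    def __init__(self) -> None:
        # Each value is either a class or a (module_path, class_name) pair.
        self._entries: dict[str, Any] = {}

    def register(self, name: str, target: Any, class_name: Optional[str] = None) -> None:
        if isinstance(target, str):
            if class_name is None:
                raise ValueError("class_name is required for lazy registration")
            # Only strings are stored; nothing is imported yet.
            self._entries[name] = (target, class_name)
        else:
            self._entries[name] = target

    def get(self, name: str) -> type:
        try:
            entry = self._entries[name]
        except KeyError:
            raise ValueError(f"unknown backend {name!r}") from None
        if isinstance(entry, tuple):
            module_path, class_name = entry
            # The import happens here, on the first request only.
            entry = getattr(importlib.import_module(module_path), class_name)
            self._entries[name] = entry  # cache the resolved class
        return entry
```

Usage: reg.register("frac", "fractions", "Fraction") records two strings, and the fractions module is imported the first time reg.get("frac") is called.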