跳到内容

以插件形式接入自定义 JAX 模型

本指南将引导您完成将基础 JAX 模型接入 TPU 推理的步骤。

1. 引入模型代码

本指南假设您的模型是基于 JAX 编写的。

2. 使代码兼容 vLLM

为了确保与 vLLM 的兼容性,您的模型必须满足以下要求:

初始化代码

模型内的所有 vLLM 模块必须在其构造函数中包含一个 vllm_config 参数。该参数包含所有与 vLLM 相关的配置以及模型配置。

初始化代码应如下所示:

class LlamaForCausalLM(nnx.Module):

    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
                 mesh: Mesh) -> None:
        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng_key)
        self.mesh = mesh

        self.model = LlamaModel(
            vllm_config=vllm_config,
            rng=self.rng,
            mesh=mesh,
        )

计算代码

模型的前向传播逻辑应位于 __call__ 中,且必须至少包含以下参数:

def __call__(
    self,
    kv_caches: List[jax.Array],
    input_ids: jax.Array,
    attention_metadata: AttentionMetadata,
) -> Tuple[List[jax.Array], jax.Array]:
    

作为参考,请查看我们的 Llama 实现

3. 实现权重加载逻辑

现在,您需要在 *ForCausalLM 类中实现 load_weights 方法。此方法应从 HuggingFace 检查点文件(或兼容的本地检查点)加载权重,并将其分配给模型中的相应层。

4. 注册您的模型

TPU 推理依赖于模型注册表来确定如何运行每个模型。预注册架构列表可在此处找到。

如果您的模型不在列表中,则必须将其注册到 TPU 推理中。您可以使用插件(类似于 vLLM 插件)加载外部模型,而无需修改 TPU 推理的代码库。

请按以下结构组织您的插件:

├── setup.py
├── your_code
   ├── your_code.py
   └── __init__.py

setup.py 构建脚本应遵循与 vLLM 插件相同的指南

要注册模型,请在 your_code/__init__.py 中使用以下代码:

from tpu_inference.logger import init_logger
from tpu_inference.models.common.model_loader import register_model

logger = init_logger(__name__)

def register():
    from .your_code import YourModelForCausalLM
    register_model("YourModelForCausalLM", YourModelForCausalLM)

5. 安装并运行您的模型

确保您在与 vllm/tpu 推理相同的 Python 环境中执行 pip install . 安装您的模型。然后,运行您的模型:

HF_TOKEN=token TPU_BACKEND_TYPE=jax \
  python -m vllm.entrypoints.cli.main serve \
  /path/to/hf_compatible/weights/ \
  --max-model-len=1024 \
  --tensor-parallel-size 8 \
  --max-num-batched-tokens 1024 \
  --max-num-seqs=1 \