跳到内容

将自定义JAX模型作为插件集成

本指南将引导您完成为TPU推理实现基本JAX模型的步骤。

1. 引入您的模型代码

本指南假定您的模型是为JAX编写的。

2. 使您的代码与vLLM兼容

为确保与vLLM兼容,您的模型必须满足以下要求:

初始化代码

模型中的所有vLLM模块在其构造函数中都必须包含一个vllm_config参数。这将包含所有vllm相关的配置以及模型配置。

初始化代码应如下所示:

class LlamaForCausalLM(nnx.Module):

    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
                 mesh: Mesh) -> None:
        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng_key)
        self.mesh = mesh

        self.model = LlamaModel(
            vllm_config=vllm_config,
            rng=self.rng,
            mesh=mesh,
        )

计算代码

模型的正向传播应在__call__方法中,该方法至少必须包含以下参数:

def __call__(
    self,
    kv_caches: List[jax.Array],
    input_ids: jax.Array,
    attention_metadata: AttentionMetadata,
) -> Tuple[List[jax.Array], jax.Array]:

有关参考,请查看我们的Llama实现

3. 实现权重加载逻辑

您现在需要在您的*ForCausalLM类中实现load_weights方法。此方法应从HuggingFace的检查点文件(或兼容的本地检查点)加载权重,并将其分配给您模型中相应的层。

4. 注册您的模型

TPU推理依赖于模型注册表来确定如何运行每个模型。预注册架构的列表可以在这里找到。

如果您的模型不在该列表中,您必须将其注册到TPU推理。您可以使用插件(类似于vLLM的插件)来加载外部模型,而无需修改TPU推理的代码库。

您的插件结构应如下所示:

├── setup.py
├── your_code
│   ├── your_code.py
│   └── __init__.py

setup.py构建脚本应遵循与vLLM插件相同的指导

要注册模型,请在your_code/__init__.py中使用以下代码:

from tpu_inference.logger import init_logger
from tpu_inference.models.common.model_loader import register_model

logger = init_logger(__name__)

def register():
    from .your_code import YourModelForCausalLM
    register_model("YourModelForCausalLM", YourModelForCausalLM)

5. 安装并运行您的模型

确保您在与vllm/tpu推理相同的Python环境中pip install .您的模型。然后运行您的模型:

HF_TOKEN=token TPU_BACKEND_TYPE=jax \
  python -m vllm.entrypoints.cli.main serve \
  /path/to/hf_compatible/weights/ \
  --max-model-len=1024 \
  --tensor-parallel-size 8 \
  --max-num-batched-tokens 1024 \
  --max-num-seqs=1 \