以插件形式接入自定义 JAX 模型¶
本指南将引导您完成将基础 JAX 模型接入 TPU 推理的步骤。
1. 引入模型代码¶
本指南假设您的模型是基于 JAX 编写的。
2. 使代码兼容 vLLM¶
为了确保与 vLLM 的兼容性,您的模型必须满足以下要求:
初始化代码
模型内的所有 vLLM 模块必须在其构造函数中包含一个 vllm_config 参数。该参数包含所有与 vLLM 相关的配置以及模型配置。
初始化代码应如下所示:
class LlamaForCausalLM(nnx.Module):
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
mesh: Mesh) -> None:
self.vllm_config = vllm_config
self.rng = nnx.Rngs(rng_key)
self.mesh = mesh
self.model = LlamaModel(
vllm_config=vllm_config,
rng=self.rng,
mesh=mesh,
)
计算代码
模型的前向传播逻辑应位于 __call__ 中,且必须至少包含以下参数:
def __call__(
self,
kv_caches: List[jax.Array],
input_ids: jax.Array,
attention_metadata: AttentionMetadata,
) -> Tuple[List[jax.Array], jax.Array]:
…
作为参考,请查看我们的 Llama 实现。
3. 实现权重加载逻辑¶
现在,您需要在 *ForCausalLM 类中实现 load_weights 方法。此方法应从 HuggingFace 检查点文件(或兼容的本地检查点)加载权重,并将其分配给模型中的相应层。
4. 注册您的模型¶
TPU 推理依赖于模型注册表来确定如何运行每个模型。预注册架构列表可在此处找到。
如果您的模型不在列表中,则必须将其注册到 TPU 推理中。您可以使用插件(类似于 vLLM 插件)加载外部模型,而无需修改 TPU 推理的代码库。
请按以下结构组织您的插件:
setup.py 构建脚本应遵循与 vLLM 插件相同的指南。
要注册模型,请在 your_code/__init__.py 中使用以下代码:
from tpu_inference.logger import init_logger
from tpu_inference.models.common.model_loader import register_model
logger = init_logger(__name__)
def register():
from .your_code import YourModelForCausalLM
register_model("YourModelForCausalLM", YourModelForCausalLM)
5. 安装并运行您的模型¶
确保您在与 vllm/tpu 推理相同的 Python 环境中执行 pip install . 安装您的模型。然后,运行您的模型: