将自定义JAX模型作为插件集成¶
本指南将引导您完成为TPU推理实现基本JAX模型的步骤。
1. 引入您的模型代码¶
本指南假定您的模型是为JAX编写的。
2. 使您的代码与vLLM兼容¶
为确保与vLLM兼容,您的模型必须满足以下要求:
初始化代码
模型中的所有vLLM模块在其构造函数中都必须包含一个vllm_config参数。这将包含所有vllm相关的配置以及模型配置。
初始化代码应如下所示:
class LlamaForCausalLM(nnx.Module):
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
mesh: Mesh) -> None:
self.vllm_config = vllm_config
self.rng = nnx.Rngs(rng_key)
self.mesh = mesh
self.model = LlamaModel(
vllm_config=vllm_config,
rng=self.rng,
mesh=mesh,
)
计算代码
模型的正向传播应在__call__方法中,该方法至少必须包含以下参数:
def __call__(
self,
kv_caches: List[jax.Array],
input_ids: jax.Array,
attention_metadata: AttentionMetadata,
) -> Tuple[List[jax.Array], jax.Array]:
…
有关参考,请查看我们的Llama实现。
3. 实现权重加载逻辑¶
您现在需要在您的*ForCausalLM类中实现load_weights方法。此方法应从HuggingFace的检查点文件(或兼容的本地检查点)加载权重,并将其分配给您模型中相应的层。
4. 注册您的模型¶
TPU推理依赖于模型注册表来确定如何运行每个模型。预注册架构的列表可以在这里找到。
如果您的模型不在该列表中,您必须将其注册到TPU推理。您可以使用插件(类似于vLLM的插件)来加载外部模型,而无需修改TPU推理的代码库。
您的插件结构应如下所示:
setup.py构建脚本应遵循与vLLM插件相同的指导。
要注册模型,请在your_code/__init__.py中使用以下代码:
from tpu_inference.logger import init_logger
from tpu_inference.models.common.model_loader import register_model
logger = init_logger(__name__)
def register():
from .your_code import YourModelForCausalLM
register_model("YourModelForCausalLM", YourModelForCausalLM)
5. 安装并运行您的模型¶
确保您在与vllm/tpu推理相同的Python环境中pip install .您的模型。然后运行您的模型: