跳到内容

JAX 模型开发指南

tpu-inference 提供了一个灵活的框架,用于在 Flax NNX 中实现基于 Transformer 的架构。

集成新模型类型所需的组件包括:- 定义模型架构并实现任何自定义层 - 实现权重加载逻辑 - (可选) 添加量化支持 - 将新模型注册到 tpu-inference 中

代码组织

在开始模型开发之前,熟悉代码组织会很有帮助。

tpu_inference
├── layers
   ├── jax # Provide pre-implemented building blocks for tpu-inference models.
       ├── attention_interface.py # Core interfaces used for applying attention.
       ├── base.py
       ├── layers.py
       ├── transformer_block.py
       ├── sharding.py
       ├── rope.py
       ├── glossary.md
       ├── attention
           ├── attention.py # Pre-implemented attention layer.
           └── deepseek_v3_attention.py
       └── moe
            ├── moe.py
            └── deepseek_v3_moe.py
   └── common # Functionalities shared between torchax and jax implementations.
└── models
   ├── common
      └── model_loader.py
   └── jax  # Contains model files for each type of model family.
       ├── deepseek_v3.py
       ├── llama3.py
       ├── qwen3.py
       └── utils

模型实现

实现新模型需要创建一个专用模型文件(例如,deepseek_v3.py),其中包含以下组件:- 定义架构的模型类。- 前向传播实现和 logits 计算。- 权重加载逻辑,用于将 HuggingFace 权重导入模型定义。

定义模型架构

模型文件旨在包含定义基于 Transformer 的架构所需的所有信息。每个模型文件都包含一个具有以下构造函数接口的模型类:

class NewModel(nnx.Module):
  def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
               mesh: jax.sharding.Mesh)

构造函数应设置架构配置(例如,num_layers、hidden_size)并初始化模型层。可以使用 flax NNX 从头开始定义层(例如,Llama3),或者可以利用 tpu-inference 来导入或扩展常用的层类型(例如,EmbedderRMSNormMoEAttentionDenseFFWTransformerBlock)。

实现前向传播

前向传播包含将模型构造函数中定义的层进行拼接的逻辑,并应使用以下接口:

def __call__(
   self,
   kv_caches: List[jax.Array],
   input_ids: jax.Array,
   attention_metadata: AttentionMetadata,
   *args,
) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]

此接口的关键假设是上下文由模型外部管理(模型负责在自注意力之后更新 KV 缓存张量是例外),这与 vLLM 中的情况一致。(有关 AttentionMetadata 如何准备的更多详细信息,请参阅 vLLM 的 Block 调度和管理设计tpu_jax_runner.py)。预期返回的输出包含更新的 KV 缓存、最终层隐藏状态以及(可选)辅助的最终隐藏状态残差(用于投机解码)。

除了前向传播逻辑之外,每个模型都需要实现一个使用以下接口生成 logits 的方法:def compute_logits(self, hidden_states: jax.Array) -> jax.Array:

权重加载

开源模型的权重在命名和参数形状方面并不 universally standard。因此,有必要实现按模型加载权重的逻辑,以正确地将开源权重导入到相应的模型参数中。为此,每个模型都必须实现一个具有以下接口的 load_weights 方法:def load_weights(self, rng: PRNGKey)

权重加载逻辑通常由几个类别的步骤组成:- 将 HuggingFace 权重加载到迭代器中(参见 model_weights_generator)- 定义加载的权重名称与实现权重名称之间的映射。- 定义要应用于加载参数的张量变换的映射。(这些变换可以包括转置或重塑加载的张量)。- 执行特定于模型的加载逻辑(例如,拆分加载的权重张量并加载到多个参数中)。- (可选)支持加载预量化模型。

有关如何实现权重加载的示例,请参考 deepseek_v3.pyllama4.py

量化支持

许多大型 LLM,如 DeepSeek-V3,使用量化来减少硬件需求和提高性能。tpu-inference 代码库使用 Qwix 来加载预量化模型和/或对加载的模型权重应用额外的量化设置。在 tpu-inference 中,对于预量化检查点如何生成没有假设(因此您可以自由选择流行的工具),只要结果以 HuggingFace Safetensor 格式保存并遵循以下指南。有关如何在 tpu-inference 上使用 Qwix 进行推理运行的更多详细信息,请参阅 通用 readme

请注意,您可能需要在此处更新 TPU 上支持的量化类型列表:这里。如果 HuggingFace 量化配置中的 quant_method 不是支持的类型之一,vLLM 将触发验证错误。HuggingFace 量化配置

为了演示,在本节中,我们将引用 deepseek_v3.py 来获取实现细节。

加载预量化检查点并应用量化规则

要正确加载预量化检查点,需要执行以下步骤:- 使用 Qwix 配置定义量化设置,该配置可以作为 yaml 文件(例如,int8_default.yaml)公开,或者在代码中设置。开源模型的量化设置通常在其各自的 HuggingFace 量化配置中发布(例如,DeepSeek-R1)。(有关支持的 Qwix 量化选项的更多信息,请参阅 Qwix 文档)。- 在 Qwix 配置中将 use_abstract_model 设置为 True,以便在加载权重之前对 NNX 模型图进行量化。- 如果预量化模型包含反量化标度,请更新权重加载逻辑以存储它们。如果加载模型的权重需要应用变换,请确保反量化标度也进行相应变换。标度维度可以通过 HuggingFace 配置中的 weight_block_size 来确定,并在 权重加载逻辑中设置。标度维度也可以与 Safetensor 文件 进行交叉引用。

反之,如果检查点未预量化,则不需要自定义模型加载代码,应在 Qwix 配置中将 use_abstract_model 设置为 False

请注意,Qwix 量化设置是事实上的标准,将覆盖加载权位使用的数据类型(即使提供了预量化权重)。

模型注册

一旦实现了新的模型类型,就必须将其添加到 model_loader.py 的模型注册表中。

警告

根据 vLLM 的验证流程,模型必须注册为一个受支持的 HuggingFace 模型名称(有关更多详细信息,请参阅 此处)。

要将外部 Jax NNX 模型实现集成到 tpu-inference 中,请参阅 专用文档