JAX 模型开发指南¶
tpu-inference 提供了一个灵活的框架,用于在 Flax NNX 中实现基于 Transformer 的架构。
集成新模型类型所需的组件包括:- 定义模型架构并实现任何自定义层 - 实现权重加载逻辑 - (可选) 添加量化支持 - 将新模型注册到 tpu-inference 中
代码组织¶
在开始模型开发之前,熟悉代码组织会很有帮助。
tpu_inference
├── layers
│ ├── jax # Provide pre-implemented building blocks for tpu-inference models.
│ │ ├── attention_interface.py # Core interfaces used for applying attention.
│ │ ├── base.py
│ │ ├── layers.py
│ │ ├── transformer_block.py
│ │ ├── sharding.py
│ │ ├── rope.py
│ │ ├── glossary.md
│ │ ├── attention
│ │ │ ├── attention.py # Pre-implemented attention layer.
│ │ │ └── deepseek_v3_attention.py
│ │ └── moe
│ │ ├── moe.py
│ │ └── deepseek_v3_moe.py
│ └── common # Functionalities shared between torchax and jax implementations.
└── models
├── common
│ └── model_loader.py
└── jax # Contains model files for each type of model family.
├── deepseek_v3.py
├── llama3.py
├── qwen3.py
└── utils
- 新 Jax 模型类型的注册应在
tpu_inference/models/common/model_loader.py中执行。 - 新的 Jax 模型定义应添加到
tpu_inference/models/jax。 - 常用的层(例如,嵌入层、前馈层)可以从
tpu_inference/layers/jax导入。 - 特定于模型的层实现应添加到
tpu_inference/layers/<layer_type>/<model_type>_<layer_type>.py(例如,attention/deepseek_v3_attention.py,moe/deepseek_v3_moe.py)。 - 自定义 (Qwix) 量化配置(yaml 文件)应存储在
tpu_inference/models/jax/utils/quantization/configs。
模型实现¶
实现新模型需要创建一个专用模型文件(例如,deepseek_v3.py),其中包含以下组件:- 定义架构的模型类。- 前向传播实现和 logits 计算。- 权重加载逻辑,用于将 HuggingFace 权重导入模型定义。
定义模型架构¶
模型文件旨在包含定义基于 Transformer 的架构所需的所有信息。每个模型文件都包含一个具有以下构造函数接口的模型类:
class NewModel(nnx.Module):
def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
mesh: jax.sharding.Mesh)
构造函数应设置架构配置(例如,num_layers、hidden_size)并初始化模型层。可以使用 flax NNX 从头开始定义层(例如,Llama3),或者可以利用 tpu-inference 来导入或扩展常用的层类型(例如,Embedder、RMSNorm、MoE、Attention、DenseFFW、TransformerBlock)。
实现前向传播¶
前向传播包含将模型构造函数中定义的层进行拼接的逻辑,并应使用以下接口:
def __call__(
self,
kv_caches: List[jax.Array],
input_ids: jax.Array,
attention_metadata: AttentionMetadata,
*args,
) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]
此接口的关键假设是上下文由模型外部管理(模型负责在自注意力之后更新 KV 缓存张量是例外),这与 vLLM 中的情况一致。(有关 AttentionMetadata 如何准备的更多详细信息,请参阅 vLLM 的 Block 调度和管理设计 和 tpu_jax_runner.py)。预期返回的输出包含更新的 KV 缓存、最终层隐藏状态以及(可选)辅助的最终隐藏状态残差(用于投机解码)。
除了前向传播逻辑之外,每个模型都需要实现一个使用以下接口生成 logits 的方法:def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
权重加载¶
开源模型的权重在命名和参数形状方面并不 universally standard。因此,有必要实现按模型加载权重的逻辑,以正确地将开源权重导入到相应的模型参数中。为此,每个模型都必须实现一个具有以下接口的 load_weights 方法:def load_weights(self, rng: PRNGKey)
权重加载逻辑通常由几个类别的步骤组成:- 将 HuggingFace 权重加载到迭代器中(参见 model_weights_generator)- 定义加载的权重名称与实现权重名称之间的映射。- 定义要应用于加载参数的张量变换的映射。(这些变换可以包括转置或重塑加载的张量)。- 执行特定于模型的加载逻辑(例如,拆分加载的权重张量并加载到多个参数中)。- (可选)支持加载预量化模型。
有关如何实现权重加载的示例,请参考 deepseek_v3.py 或 llama4.py。
量化支持¶
许多大型 LLM,如 DeepSeek-V3,使用量化来减少硬件需求和提高性能。tpu-inference 代码库使用 Qwix 来加载预量化模型和/或对加载的模型权重应用额外的量化设置。在 tpu-inference 中,对于预量化检查点如何生成没有假设(因此您可以自由选择流行的工具),只要结果以 HuggingFace Safetensor 格式保存并遵循以下指南。有关如何在 tpu-inference 上使用 Qwix 进行推理运行的更多详细信息,请参阅 通用 readme。
请注意,您可能需要在此处更新 TPU 上支持的量化类型列表:这里。如果 HuggingFace 量化配置中的 quant_method 不是支持的类型之一,vLLM 将触发验证错误。HuggingFace 量化配置。
为了演示,在本节中,我们将引用 deepseek_v3.py 来获取实现细节。
加载预量化检查点并应用量化规则¶
要正确加载预量化检查点,需要执行以下步骤:- 使用 Qwix 配置定义量化设置,该配置可以作为 yaml 文件(例如,int8_default.yaml)公开,或者在代码中设置。开源模型的量化设置通常在其各自的 HuggingFace 量化配置中发布(例如,DeepSeek-R1)。(有关支持的 Qwix 量化选项的更多信息,请参阅 Qwix 文档)。- 在 Qwix 配置中将 use_abstract_model 设置为 True,以便在加载权重之前对 NNX 模型图进行量化。- 如果预量化模型包含反量化标度,请更新权重加载逻辑以存储它们。如果加载模型的权重需要应用变换,请确保反量化标度也进行相应变换。标度维度可以通过 HuggingFace 配置中的 weight_block_size 来确定,并在 权重加载逻辑中设置。标度维度也可以与 Safetensor 文件 进行交叉引用。
反之,如果检查点未预量化,则不需要自定义模型加载代码,应在 Qwix 配置中将 use_abstract_model 设置为 False。
请注意,Qwix 量化设置是事实上的标准,将覆盖加载权位使用的数据类型(即使提供了预量化权重)。
模型注册¶
一旦实现了新的模型类型,就必须将其添加到 model_loader.py 的模型注册表中。
警告
根据 vLLM 的验证流程,模型必须注册为一个受支持的 HuggingFace 模型名称(有关更多详细信息,请参阅 此处)。
要将外部 Jax NNX 模型实现集成到 tpu-inference 中,请参阅 专用文档。