torchax 在 vLLM 中是如何使用的¶
作者:Siyuan Liu, Hongmin Fan, Han Qi
最后更新:2025 年 9 月 26 日
什么是 torchax¶
torchax 是一个提供 JAX 和 PyTorch 互操作性的库。这意味着,您现在可以在同一个进程中运行 PyTorch 和 JAX,并且可以在 JAX 支持的所有硬件上运行(包括 NVidia GPU 和 Google TPU)。
它可以被看作是:* JAX 的 PyTorch 前端,或者,* PyTorch 的 JAX 后端。
使用 torchax,您可以
- 通过少至 2 行代码的修改,在 TPU 上运行 PyTorch 代码。
- 从 PyTorch 函数中调用 JAX 函数,并传递 jax.Array。
- 从 JAX 函数中调用 PyTorch 函数,并传递 torch.Tensor。
- 使用 JAX 的功能,如 jax.grad、optax 和 GSPMD 来训练 PyTorch 模型。
- 使用 PyTorch 模型作为特征提取器,并将其与 JAX 模型一起使用。
其工作原理是拥有一个 torch.Tensor 的子类,该子类包含一个 jax.Array,并实现了该张量应该支持的所有 torch 运算符。
有关 torchax 工作原理的更多详细信息,请参阅此页面。
TPU JAX worker¶
vllm Worker 通过 init_device、determine_available_memory、execute_model、compile_or_warm_up_model、profile 等通用方法与 vllm 的 LLM Engine 进行交互。
每个 worker 都有一个 runner - 主要用于后端特定实现。主要包括以下内容
- 模型初始化 & 权重加载
- 确定 KV 缓存块的数量(代码指针)
- 捕获计算图:使用不同的输入形状运行模型,以运行所有可能的计算图,避免在服务期间进行编译。
- 根据调度器输出执行模型(预处理模型输入,运行模型,为每个请求生成采样 token)
Jax TPU worker 是在 tpu_common 中引入的一个新的 Worker 实现,它负责调用使用 Jax 或 Torch 实现的模型。
Jax worker 与 torch 模型之间的交互¶
当 Jax worker 实例化时,它使用 get_model 函数来获取一个代表模型的 callable。该 callable 是一个纯函数(没有状态),因此权重、KV 缓存以及输入都将作为输入传递给该函数。
当 worker 运行模型时,它会调用模型函数。模型函数接受 Jax Arrays 作为输入。
纯函数¶
JAX 的转换和编译仅设计用于功能上纯粹的 Python 函数:所有输入数据都通过函数参数传递,所有结果都通过函数结果输出。纯函数在输入相同时总是返回相同的结果。
如果在计算中使用了某些数组,但它们不是函数输入,它们将在计算图中被内联为常量。
PyTorch 的 forward 函数不将模型权重作为输入参数 -> 需要使用 torch.func.functional_call。
正如我们在上面的官方文档中看到的,functional_call 允许将权重作为输入传递(而不是从模型对象的属性中读取权重)。
import torch
import torch.nn as nn
from torch.func import functional_call, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
def compute_loss(params, x, t): # params is the weights as a dict of Tensos
y = functional_call(model, params, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
KV 缓存¶
通过 functional call,我们可以将权重作为函数输入传递,现在,我们仍然需要处理 KV 缓存。KV 缓存是我们的 model_fn 的输入/输出。然而,传统上 vllm 上游将其作为模型属性写入。
我们可以在调用模型之前手动将 KV 缓存放在那里,并返回它们。
看看下面的说明性示例
def f(cache, x):
cache += x
cache = jnp.zeros(3)
x = jnp.ones(3)
jax.jit(f)(cache, x)
# prints 0
print(cache)
上面的代码模拟了对 cache 变量的就地更新,我们可以看到这种更改并未传播到 jax.jit 区域之外。
def functional_f(cache, x):
cache += x
return cache
cache = jnp.zeros(3)
x = jnp.ones(3)
updated = jax.jit(functional_f)(cache, x)
cache = updated #<-- write back the update
# prints 1
print(cache)
诀窍在于让 jax.jit 中的函数返回更新的 KV 缓存,然后我们重新分配修改。
使用上述技术的代码位于:tpu_inference/models/torchax/torchax_wrapper.py#L55-L83 如下所示
心智模型¶
torchax 就是 JAX。
torchax 的工作原理是提供一个 PyTorch 前端;因此,每个 PyTorch 运算符最终都成为作用于 JAX 数组的 JAX 函数。因此,我们这里采取的方法是:1. 使用处理 JAX 模型相同的 JAX worker。2. 使用 torchax 使 torch.nn.Module 对 worker 看起来像一个 JAX 模型。3. 对 KV 缓存和 Attention kernel(如 RaggedPagedAttention)使用基于 JAX 的方法。
模型执行伪代码
inputs : jax.Array = prepare_inputs() # shared by jax model and torchax model
inputs_torch : torch.Tensor = torch_view(inputs) # a torch.Tensor subclass that holds an jax.Array
outputs_torch : torch.Tensor = torch.func.functional_call(torch_model, weights, inputs_torch) # kv caches are handled in VllmModelWrapper
outputs = jax_view(outputs_torch)
# ...
# sampler logic implemented based on jax, shared by jax model and torchax model
Attention kernel 调用伪代码




