跳到内容

torch.compile 集成

在 vLLM 的 V1 架构中,torch.compile 默认启用,并且是该框架的关键组成部分。本文档提供了一个简单的演练示例,展示如何理解 torch.compile 的用法。

在整个示例中,我们将运行一个通用的 Llama 模型,并开启调试级别日志以显示所有详细信息。要使用的命令是 VLLM_LOGGING_LEVEL=DEBUG vllm serve meta-llama/Llama-3.2-1B

注意

有关 torch.compile 集成的更多信息和最新进展,请参阅这篇 博客文章

编译缓存

在非常详细的日志中,我们可以看到

INFO 03-07 03:06:55 [backends.py:409] Using cache directory: ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0 for vLLM's torch.compile

vLLM 将考虑所有可用因素,并决定一个目录来存储所有编译产物。这意味着,您可以直接在部署场景中复制整个 ~/.cache/vllm/torch_compile_cache 目录,以节省大量的编译时间,从而加快 vLLM 实例的启动时间。

考虑的因素包括

  • 所有相关的配置(请参阅 配置文件夹)中 compute_hash 函数
  • PyTorch 配置(请参阅 compiler_interface.py)中 compute_hash 函数
  • 模型的 forward 函数以及 forward 函数调用的相关函数(见下文)

考虑到所有这些因素,通常我们可以保证缓存是安全的,并且不会导致任何意外行为。因此,缓存默认是启用的。如果您想调试编译过程,或者怀疑缓存导致了某些问题,可以通过设置环境变量 VLLM_DISABLE_COMPILE_CACHE=1 来禁用它。

vLLM torch.compile 集成的一个独特之处在于,我们保证所有编译都在提供任何请求之前完成。没有请求会触发新的编译。否则,引擎就会在该请求上阻塞,响应时间会出现意外的峰值。

默认情况下,缓存将编译的产物保存为二进制文件。如果您想为了调试目的与生成的代码进行交互,请在编译配置中将字段 compile_cache_save_format 设置为 unpacked,或者省略该字段并设置环境变量 VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked

动态形状和 vllm 守卫移除

torch.compile 被设计为在需要时毫不犹豫地守护动态形状。这与 vLLM torch.compile 移除守卫的方法相矛盾,因为许多守卫可能是实质性的。

torch.compile 提供了两种动态形状:backedunbackedtorch.compile 会守护 backed 动态形状,并且不保证不会向其添加任何守卫。用户代码、dynamo、inductor 和 autograd 都可以添加守卫。此外,对于 0/1 特化,即使在这些范围内没有遇到分支,backed 符号也会无条件地特化为 0、1 或 >=2。

相反,unbacked 动态形状保证不会被守护,也不会进行 0/1 特化。然而,当遇到需要其值的分支并且没有定义显式的 unbacked 处理时,可能会抛出数据依赖错误。框架正在趋向于一个不会抛出 DDE 而是选择通用路径的状态。使用 unbacked 的一个缺点是由于性能错误或选择通用路径而错失了优化机会,并且使用了基于固定非示例输入的提示(这将很快通过 override_hint API 修复)。选择通用路径的一个示例是,在不能通过引入 clone 来符号化证明的情况下,假设输入不连续,在函数调用 contiguous() 和 reshape() 时。

backed_size_oblivious 是一个标志,它允许在定义了显式的 unbacked 处理的地方将 backed 符号视为 unbacked。在此模式下,框架代码中几乎避免了 0/1 特化,并且默认的 0/1 特化不会发生。然而,仍然不能保证 torch.compile 不会发生守护,尤其是由于用户代码或自定义 pass。backed_size_oblivious 在 PyTorch compile 中是实验性的,并且可能被弃用。尽管如此,它比 backed 更安全的选择,并且性能下降的可能性比 unbacked 更低。

配置动态形状

DynamicShapesConfig 允许您通过设置 type 字段来控制动态形状的行为。您可以在三种模式之间进行选择:BACKED(默认)、UNBACKEDBACKED_SIZE_OBLIVIOUS

离线推理示例(使用 LLM 类)

在使用 LLM 类进行离线推理时,您可以通过 compilation_config 参数配置动态形状

from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType

# Example: Using backed_size_oblivious (experimental, safer than backed)
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config=CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS
        )
    )
)

# Example: Using unbacked (strongest guarantee against guards)
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config=CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.UNBACKED
        )
    )
)

# Generate outputs
prompts = ["Hello, my name is", "The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)

在线服务示例(使用 vllm serve)

在使用 vllm serve 进行在线服务时,您可以通过 --compilation-config 标志配置动态形状

# Example: Using unbacked
vllm serve meta-llama/Llama-3.2-1B \
  --compilation-config '{"dynamic_shapes_config": {"type": "unbacked"}}'


# Alternative: Using dot notation (simpler for single values)
vllm serve meta-llama/Llama-3.2-1B -cc.dynamic_shapes_config.type=unbacked

选择合适的模式

  • BACKED(默认):当您愿意为了最大性能而接受潜在的不安全守卫移除时使用。守卫可能会被不安全地添加然后忽略。

  • UNBACKED:当您需要最强的反守卫保证时使用。这是最保守的选择,但可能会错过一些优化机会。

  • BACKED_SIZE_OBLIVIOUS:当您希望在避免守卫和性能之间取得平衡时使用。这种实验模式比 BACKED 更安全,但仍不如 UNBACKED 保守。

Python 代码编译

在非常详细的日志中,我们可以看到

日志
DEBUG 03-07 03:06:52 [decorators.py:203] Start compiling function <code object forward at 0x7f08acf40c90, file "xxx/vllm/model_executor/models/llama.py", line 339>

DEBUG 03-07 03:06:54 [backends.py:370] Traced files (to be considered for compilation cache):
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/_dynamo/polyfills/builtins.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/nn/modules/container.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/nn/modules/module.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/attention/layer.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/distributed/communication_op.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/distributed/parallel_state.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/custom_op.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/activation.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/layernorm.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/linear.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/rotary_embedding.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/vocab_parallel_embedding.py
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/models/llama.py

DEBUG 03-07 03:07:07 [backends.py:462] Computation graph saved to ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py
DEBUG 03-07 03:07:07 [wrapper.py:105] Dynamo transformed code saved to ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/transformed_code.py

这是关于 Python 代码编译,即 Dynamo 的图捕获。它尝试跟踪代码 xxx/vllm/model_executor/models/llama.py:339 中的函数,这是我们编译的模型中的 forward 函数。在 forward 传递期间,Dynamo 还会调用和内联其他函数,如日志所示,包括来自 xxx/torch/nn/modules/module.py 的一些 PyTorch 函数(由 PyTorch nn.Module 使用,因为模块属性访问会触发函数调用),以及来自 vLLM 的一些通信/注意力/激活函数。所有跟踪的文件都将在我们决定使用的缓存目录时被考虑。这样,上述文件中的任何代码更改都将触发编译缓存未命中,从而导致重新编译。

Dynamo 编译的结果是一个存储在 ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/transformed_code.py 中的新函数。通常,这个函数会将张量从模块解包,然后传递给跟踪的计算图。计算图存储在 ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py 中。

计算图处理

计算图为每个张量都带有形状注解。输入是输入 ID、位置 ID、模型中的权重和缓冲区,输出是最终的隐藏状态。请注意,LM 头投影和采样操作不包含在图中。

计算图的大部分输入都具有静态形状,因为它们是模型权重和缓冲区,在模型生命周期内不会改变。只有输入 ID 和位置 ID 具有符号形状,即形状可以从批次到批次改变。然而,它们将共享相同的符号形状。也就是说,计算图唯一改变的大小是批次大小(当前 forward 传递中处理的 token 数量)。

注意力操作很复杂,它需要与具有复杂形状的 kv 缓存进行交互。幸运的是,注意力操作的输出与注意力操作的输入查询共享相同的形状。因此,我们将整个注意力操作封装到一个 PyTorch 自定义 op torch.ops.vllm.unified_attention_with_output 中,这样 Dynamo 就不会尝试检查任何内部操作。这样,尽管注意力操作很复杂,我们仍然可以从 Dynamo 的角度将模型的计算图捕获为一个完整的图。

计算图由 splitting_ops(通常是注意力操作)进一步分割成多个部分。因此,在 ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py 文件中,我们可以看到许多子模块,每个子模块都是分割后的图的一部分。

  • 注意力操作本身就是一个子模块。
  • 计算图的一部分,从一个注意力操作到下一个注意力操作,就是一个子模块。

每个子模块都可以通过其索引来识别,并将单独处理。

计算图编译

在非常详细的日志中,我们还可以看到

DEBUG 03-07 03:52:37 [backends.py:134] store the 0-th graph for shape None from inductor via handle ('fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py')
DEBUG 03-07 03:52:39 [backends.py:134] store the 1-th graph for shape None from inductor via handle ('f7fmlodmf3h3by5iiu2c4zarwoxbg4eytwr3ujdd2jphl4pospfd', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/ly/clyfzxldfsj7ehaluis2mca2omqka4r7mgcedlf6xfjh645nw6k2.py')
...
DEBUG 03-07 03:52:45 [backends.py:134] store the 15-th graph for shape None from inductor via handle ('f7fmlodmf3h3by5iiu2c4zarwoxbg4eytwr3ujdd2jphl4pospfd', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/ly/clyfzxldfsj7ehaluis2mca2omqka4r7mgcedlf6xfjh645nw6k2.py')
DEBUG 03-07 03:52:45 [backends.py:134] store the 16-th graph for shape None from inductor via handle ('fvj3ccoi7m34f3dnr4itmu55mmun44l5xymwhrjlwisylsk7q6jy', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/tf/ctfftkglj7b4lcttq5cymx6cew372uoauupqn6ldsvpiucavqcjc.py')

这意味着第一个计算图片段(符号形状为 None)由 Inductor 编译(使用键 fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw)。编译后的内核存储在 ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py。您可以打开该文件查看 Inductor 最终运行的代码。

再补充一点:可以看到第 1 个图和第 15 个图具有相同的键,而第 0 个图和第 16 个图不同。这是预期的,因为我们按注意力 op 分割图,得到 3 个唯一的子图:

  • 注意力之前的第一个层
  • 每个中间层,从一个注意力操作到下一个注意力操作
  • 注意力之后的最后一个层

如果我们已经有了缓存目录(例如,第二次运行相同的代码),我们将看到以下日志:

DEBUG 03-07 04:00:45 [backends.py:86] Directly load the 0-th graph for shape None from inductor via handle ('fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py')

这次,Inductor 编译被完全绕过,我们将从磁盘加载上次编译获得的产物。

上面的示例仅使用 Inductor 为通用形状(即符号形状)进行编译。我们也可以使用 Inductor 为某些特定形状进行编译,例如:

vllm serve meta-llama/Llama-3.2-1B \
  --compilation_config '{"compile_sizes": [1, 2, 4, 8]}'

然后它还将为批次大小为 1、2、4、8 的特定大小编译一个内核。此时,计算图中的所有形状都是静态且已知的,我们将启用自动调优以获得最大性能。首次运行时可能很慢,但下次运行时,我们可以直接绕过调优并运行已调优的内核。

当所有形状都已知时,torch.compile 可以比较不同的配置,并且经常能找到更好的配置来运行内核。例如,我们可以看到以下日志:

日志
AUTOTUNE mm(8x2048, 2048x3072)
  triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
  triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  mm 0.0160 ms 81.6%
  triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
  triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2
  triton_mm_7 0.0203 ms 64.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_2 0.0208 ms 62.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_11 0.0215 ms 60.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.0428 seconds and 7.5727 seconds precompiling

这意味着,对于形状为 8x2048x3072 的矩阵乘法,torch.compile 尝试了具有各种配置的 triton 模板,并且比默认代码(分派到 cublas 库)快得多。

不幸的是,由于自动调优需要很长时间(从几秒到几分钟,取决于模型大小和批次大小),即使它可以缓存以备后用,为了用户友好性,默认情况下我们将其关闭。如果您想要最大性能,建议尝试通过编译特定形状来启用它。

Cudagraph 捕获

vLLM 的 V1 架构使用与分块编译相匹配的分块 cudagraph。如上所述,完整的计算图被分割,我们只为注意力操作之间的图块捕获 cudagraph(包括任何注意力操作之前的第一个图,以及所有注意力操作之后的最后一个图)。这是基于一个常见的观察:注意力之间的计算通常是 token-wise 的,并且易于为 cudagraph 处理;而注意力操作对于 cudagraph 兼容性来说是非平凡的。因此,通过在 Eager 模式下运行注意力操作,而在其他操作中使用 cudagraph,我们保持了注意力操作的灵活性。

分块 cudagraph 还具有细粒度的内存管理。目的是仅将注意力内核从 cudagraph 中排除,同时将所有其他模块和内存分配操作保留在 cudagraph 中。这就是为什么 V1 中的注意力操作将输出张量作为注意力输入的缘由。

cudagraphs 由编译器后端捕获和管理,并在批次大小与已捕获的 cudagraph 匹配时重放。模型调用者(模型运行器)只需确保正确管理输入缓冲区。所有中间缓冲区均由编译器后端自动管理。

默认情况下,vLLM 将尝试确定一组用于捕获 cudagraph 的大小。您也可以使用配置 cudagraph_capture_sizes 来覆盖它。

vllm serve meta-llama/Llama-3.2-1B \
  --compilation-config '{"cudagraph_capture_sizes": [1, 2, 4, 8]}'

然后,它将仅为指定的大小捕获 cudagraph。这对于对 cudagraph 捕获进行细粒度控制可能很有用。

完整的 Cudagraph 捕获

如果使用与 cudagraph 兼容的注意力后端,则可以将注意力包含在 cudagraph 中。这可以在某些情况下(如小型模型或 MOE 的解码速度)提高性能。有关更多详细信息,请参阅 CUDA Graphs