如何调试 vLLM-torch.compile 集成¶
TL;DR
- 使用 tlparse 获取 torch.compile 日志。在提交 bug 报告和/或寻求支持时,请包含这些日志。
- vLLM-torch.compile 集成包含多个组件。vLLM 提供了 flags 来关闭每个组件。
| 在线 Flag | 离线 Flag | 结果 |
|---|---|---|
| --enforce-eager | enforce_eager=True | 关闭 torch.compile 和 CUDAGraphs |
| -cc.mode=0 | mode=CompilationMode.NONE | 仅关闭 torch.compile |
| -cc.cudagraph_mode=NONE | compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE) | 仅关闭 CUDAGraphs |
| -cc.backend=eager | compilation_config=CompilationConfig(backend='eager') | 关闭 TorchInductor |
vLLM-torch.compile 概览¶
为了提高性能,vLLM 利用 torch.compile 和 CUDAGraphs 来加速。torch.compile 为 PyTorch 代码生成优化内核,而 CUDAGraphs 则消除了开销。特别地,vLLM-compile 不是 torch.compile,而是一个使用内部 PyTorch Compile API 构建的自定义编译器。
- 给定一个模型,我们通过 TorchDynamo 进行一次完整的图捕获,该捕获是动态的(相对于 batch size,即 token 数量)。
- vLLM 然后可选地分割和/或特化此图,然后使用 TorchInductor 将每个图编译成一个编译后的产物。此步骤可能使用 vLLM 的自定义 Inductor 传递来进一步优化图。
- 编译后的产物会被保存在 vLLM 的 compile 缓存中,以便将来加载。
- vLLM 应用 CUDAGraphs 来减少 CPU 开销。
在所有四个步骤中都可能出现问题。当出现问题时,请尽量隔离出现问题的子系统——这将允许您关闭最少量的组件,以在最小化性能影响的同时保持可靠性目标,并且在您打开 bug 报告时也有助于我们(vLLM)。
有关设计的更多详细信息,请参阅以下资源:
使用 tlparse¶
使用 tlparse 获取 torch.compile 日志。这些日志显示了编译过程的所有阶段,包括 torch.compile 生成的融合内核。如果可能,我们建议将这些日志或其中的一部分与任何 bug 报告一起发送——它们非常有帮助。
安装 tlparse
用法(离线推理)
用法(服务)
TORCH_TRACE=~/trace_dir vllm serve
# ctrl-c out of the server
tlparse ~/trace_dir/<the_first_log_file>
tlparse 命令会输出一些 HTML 文件(可能输出到例如 ./tl_out/index.html)。打开它即可查看日志。它看起来会像这样:
关闭 vLLM-torch.compile 集成¶
传递 --enforce-eager 参数以关闭 vLLM-torch.compile 集成,并完全以 eager 模式运行。这包括关闭 CUDAGraphs。
要仅关闭 torch.compile,请将 mode = NONE 传递给 compilation config。(-cc 是 --compilation_config 的缩写)
# Offline
from vllm.config.compilation import CompilationConfig, CompilationMode
LLM(model, compilation_config=CompilationConfig(mode=CompilationMode.NONE))
要仅关闭 CUDAGraphs,请传递 cudagraph_mode = NONE。
# Offline
from vllm.config.compilation import CompilationConfig, CUDAGraphMode
LLM(model, compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE))
调试 TorchDynamo¶
vLLM 要求模型代码能够通过 TorchDynamo(torch.compile 的前端)捕获成一个完整的图。TorchDynamo 不支持 Python 的所有功能。如果在 fullgraph 模式下遇到不支持的特性,它会报错(有时被称为图中断)。
如果您遇到图中断,请在 pytorch/pytorch 上提交一个 issue,以便 PyTorch 开发人员可以优先处理。然后,请尽力重写代码以避免图中断。有关更多信息,请参阅此Dynamo 指南。
调试动态形状完整图捕获¶
vLLM 要求模型的 forward pass 能够捕获成一个对 batch size(即 token 数量)动态的完整图。它(默认情况下)将此单个图编译成一个产物,并为所有 batch size 使用此产物。
如果您的代码无法用动态形状捕获,您可能会遇到静默的错误、明显的错误或 CUDA 非法内存访问。例如,以下代码无法捕获成单个图:
这个问题很容易诊断。使用 tlparse 并点击 compilation_metrics:它会告诉您关于 batch size 的符号约束。如果存在任何限制 batch size 的约束,那么我们就遇到了问题。
为避免此问题,请执行以下任一操作:
- 避免基于 token 数量进行分支。
- 将分支逻辑封装到自定义算子中。TorchDynamo 不会跟踪进入自定义算子。
调试约束冲突和动态形状 guard 问题¶
动态形状 guard 是 Dynamo guard 的一个特定类别。它们是 torch.compile 应用于动态维度(例如 seq_len)的约束,以确保编译后的产物保持有效。这些 guard 通常在框架代码、自定义传递或用户代码基于动态形状值进行分支时出现。
示例
这会创建一个 guard x > 10 或 x <= 10,具体取决于跟踪的路径。
vLLM 的假设: vLLM 假设所有由 torch.compile 添加的 guard 都可以安全地删除,并且不会将编译后的图约束为特定的输入形状。当此假设被违反时,可能会导致用户需要调试的问题。一些表明此假设被违反的副作用是运行时错误或 ConstraintViolationErrors。
如果动态形状被约束为单个值,则会抛出 ConstraintViolationErrors。如果您遇到约束冲突错误或怀疑动态形状 guard 被错误地添加,您可以使用更严格的动态形状模式来帮助调试问题。
# Online - using unbacked mode
vllm serve meta-llama/Llama-3.2-1B -cc.dynamic_shapes_config.type=unbacked
# Online - using backed_size_oblivious mode
vllm serve meta-llama/Llama-3.2-1B -cc.dynamic_shapes_config.type=backed_size_oblivious
# Offline - using unbacked mode
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
LLM(model, compilation_config=CompilationConfig(
dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.UNBACKED)
))
# Offline - using backed_size_oblivious mode
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
LLM(model, compilation_config=CompilationConfig(
dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS)
))
这些模式更严格,减少或消除了对动态形状 guard 的需求,这有助于隔离问题。
unbacked:使用 unbacked symints,它们不允许 guard,从而更容易识别 guard 被错误添加的位置。backed_size_oblivious:使用一种对 guarding 更严格的模式。
有关动态形状模式的更多详细信息,请参阅 动态形状和 vLLM guard 删除。
打印 guards¶
要查看编译过程中添加的所有 guard,您可以使用 TORCH_LOGS=+dynamic。
在日志中查找 [guard added] 以查看 guard 的添加位置。这可以帮助您识别哪些操作导致 guard 被错误添加。
调试 TorchInductor¶
TorchInductor 接收一个捕获的图,然后将其编译成一些 Python 代码,这些代码可能会调用 1+ 个 triton 内核。在罕见(但不幸)的情况下,它可能会生成一个不正确的 triton 内核。这可能表现为静默的错误、CUDA 非法内存访问或明显的错误。
要调试 TorchInductor 是否有问题,您可以通过将 backend='eager' 传递给 compilation config 来禁用它。
如果 Inductor 有问题,请向 PyTorch 提交 bug。如果您想尝试一下,可以调试 Inductor 输出代码中的 triton 内核(您可以通过使用 tlparse 找到它们)。
您还可以使用 TORCH_LOGS=output_code <command> 来打印 Inductor 输出代码。
可编辑的 TorchInductor 代码¶
您可以通过设置 VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked 或传递 -cc.compile_cache_save_format=unpacked 来编辑运行的 TorchInductor 代码。默认是 binary,这意味着它不可编辑。
这是一种有用的技术:您可以在输出代码中设置断点(例如 torch.distributed.breakpoint())和打印语句。
调试 vLLM-compile 缓存¶
vLLM 为 torch.compile 产物构建了自己的缓存。其理念是,产物可以被编译一次,然后在将来重复使用。这是对 torch.compile 的编译器缓存 的一层封装。
虽然 torch.compile 的编译器缓存非常稳定,但 vLLM 的编译器缓存不幸的是并不总是正确的。您可以通过设置 VLLM_DISABLE_COMPILE_CACHE=1 来禁用它。
您也可以手动删除此缓存。
- 使用
rm -rf ~/.cache/vllm删除 vLLM 的编译缓存(查看日志以了解位置是否已更改)。 - 使用
rm -rf /tmp/torchinductor_$(whoami)删除 torch.compile 的内置缓存。
vLLM 的缓存是将缓存键映射到编译后产物的。vLLM 通过组合多个因素(例如,配置 flags 和模型名称)来计算缓存键。如果 vLLM 的编译缓存不正确,这通常意味着缺少某个因素。请参阅 此示例,了解 vLLM 如何计算缓存键的一部分。
调试 CUDAGraphs¶
CUDAGraphs 是一项允许您执行以下操作的功能:
- 将启动 1+ 个 CUDA 内核的可调用对象捕获到 CUDAGraph 中。
- 重放 CUDAGraph。
捕获的 CUDAGraph 包含捕获过程中使用的所有内存。CUDAGraph 的重放会读写完全相同的内存区域。
这导致了一些限制:
- 为了在新数据上使用 CUDAGraphs,您需要将数据复制到 CUDAGraph 正在读取的缓冲区中。
- CUDAGraphs 只捕获 CUDA 内核,它们不捕获在 CPU 上完成的工作。
vLLM 使用原始的 CUDAGraphs API,这在不正确使用时是不安全的。
要仅关闭 CUDAGraphs,请传递 cudagraph_mode = NONE。


