跳到内容

性能分析

目前有三种方法可以分析您的工作负载

使用 examples/tpu_profiling.py

vLLM TPU 性能分析脚本

该脚本是一个用于分析 TPU VM 上 vLLM 引擎性能的实用工具。它使用 JAX 分析器来捕获详细的性能追踪数据。

性能分析结果可以使用 TensorBoard(需安装 tensorboard-plugin-profile 包)或 Perfetto UI 等工具进行可视化。

如何使用

先决条件

您必须安装 TensorBoard 性能分析插件才能可视化结果

pip install tensorboard-plugin-profile

基本命令

该脚本通过命令行运行,需指定工作负载参数以及任何必要的 vLLM 引擎参数。

python3 examples/tpu_profiling.py --model <your-model-name> [OPTIONS]

核心参数

  • --model: (必填)要分析的模型名称或路径。
  • --input-len: 每个请求的输入提示词 token 长度。
  • --output-len: 每个请求生成的 token 数量。
  • --batch-size: 请求数量。
  • --profile-result-dir: JAX 分析器输出的保存目录。
  • 该脚本还接受所有标准的 vLLM EngineArgs(例如 --tensor-parallel-size, --dtype)。

示例

1. 分析预填充 (Prefill) 操作: 若要分析具有长输入提示词(例如 1024 个 token)的单个请求,请将 --input-len 设置得较高,并将 --batch-size 设置为 1。

python3 examples/tpu_profiling.py \
  --model google/gemma-2b \
  --input-len 1024 \
  --output-len 1 \
  --batch-size 1

2. 分析解码 (Decoding) 操作: 若要分析大批量的单 token 解码步骤,请将 --input-len--output-len 设置为 1,并使用较大的 --batch-size

python3 examples/tpu_profiling.py \
  --model google/gemma-2b \
  --input-len 1 \
  --output-len 1 \
  --batch-size 256

使用 PHASED_PROFILING_DIR

如果您设置了以下环境变量

PHASED_PROFILING_DIR=<DESIRED PROFILING OUTPUT DIR>

我们将自动在工作负载的三个阶段捕获配置文件(前提是程序运行到了这些阶段):1. 预填充密集型(给定批次中,预填充 token 与总调度 token 的商 >= 0.9) 2. 解码密集型(给定批次中,预填充 token 与总调度 token 的商 <= 0.2) 3. 混合型(给定批次中,预填充 token 与总调度 token 的商介于 0.4 和 0.6 之间)

为了辅助您的分析,我们还会记录已分析批次的批次组成情况。

使用 USE_JAX_PROFILER_SERVER

如果您设置了以下环境变量

USE_JAX_PROFILER_SERVER=True

您可以改为手动决定何时捕获配置文件以及捕获时长,如果您的工作负载(例如端到端基准测试)规模很大,且对整个工作负载进行分析(即使用上述方法)会生成巨大的追踪文件,那么这种方法非常有用。

此外,您可以设置所需的分析端口(默认为 9999

JAX_PROFILER_SERVER_PORT=XXXX

要使用此方法,您可以执行以下操作

  1. 运行常规的 vllm serveoffline_inference 命令(确保设置 USE_JAX_PROFILER_SERVER=True
  2. 运行您的基准测试命令 (python benchmark_serving.py...)
  3. 预热完成后,在基准测试运行时,启动一个新的 tensorboard 实例,并将 logdir 设置为您希望保存配置文件的输出位置(例如 tensorboard --logdir=profiles/llama3-mmlu/
  4. 打开 tensorboard 实例并导航到 profile 页面(例如 https://:6006/#profile
  5. 点击 Capture Profile,并在 Profile Service URL(s) or TPU name 框中输入 localhost:XXXX,其中 XXXX 是您的 JAX_PROFILER_SERVER_PORT(默认为 9999

  6. 输入所需的持续时间(以毫秒 ms 为单位)