跳到内容

性能分析

目前有三种方法可以对您的工作负载进行性能分析:

使用 examples/tpu_profiling.py

vLLM TPU 性能分析脚本

此脚本是一个实用工具,用于分析 vLLM 引擎在 TPU VM 上的性能。它使用 JAX profiler 来捕获详细的性能跟踪。

可以使用 TensorBoard(配合 tensorboard-plugin-profile 包)或 Perfetto UI 等工具来可视化性能分析结果。

如何使用

先决条件

您必须安装 TensorBoard profile 插件才能可视化结果。

pip install tensorboard-plugin-profile

基本命令

该脚本从命令行运行,指定工作负载参数和任何必需的 vLLM 引擎参数。

python3 examples/tpu_profiling.py --model <your-model-name> [OPTIONS]

关键参数

  • --model: (必需) 要进行性能分析的模型名称或路径。
  • --input-len: 每个请求的输入提示 token 长度。
  • --output-len: 每个请求要生成的 token 数量。
  • --batch-size: 请求的数量。
  • --profile-result-dir: JAX profiler 输出将保存的目录。
  • 该脚本还接受所有标准的 vLLM EngineArgs(例如,--tensor-parallel-size, --dtype)。

示例

1. 分析 Prefill 操作: 要分析具有长输入提示(例如,1024 个 token)的单个请求,请将 --input-len 设置得高一些,并将 --batch-size 设置为 1。

python3 examples/tpu_profiling.py \
  --model google/gemma-2b \
  --input-len 1024 \
  --output-len 1 \
  --batch-size 1

2. 分析 Decoding 操作: 要分析大量单 token 解码步骤的批处理,请将 --input-len--output-len 设置为 1,并使用较大的 --batch-size

python3 examples/tpu_profiling.py \
  --model google/gemma-2b \
  --input-len 1 \
  --output-len 1 \
  --batch-size 256

使用 PHASED_PROFILING_DIR

如果您设置了以下环境变量

PHASED_PROFILING_DIR=<DESIRED PROFILING OUTPUT DIR>

我们将自动在工作负载的三个阶段捕获性能分析(假设它们会出现):1. Prefill 密集型(给定批次的 prefill / 总计划 token 的商值 => 0.9)2. Decode 密集型(给定批次的 prefill / 总计划 token 的商值 <= 0.2)3. 混合型(给定批次的 prefill / 总计划 token 的商值在 0.4 到 0.6 之间)。

为了便于您的分析,我们还将记录被分析批次的批次构成。

使用 USE_JAX_PROFILER_SERVER

如果您设置了以下环境变量

USE_JAX_PROFILER_SERVER=True

您可以改为手动决定何时捕获性能分析以及捕获多长时间,这对于您的工作负载(例如,E2E 基准测试)可能很有帮助,因为它很大,并且对整个工作负载进行性能分析(即使用上述方法)会生成一个巨大的跟踪文件。

您还可以设置所需的性能分析端口(默认为 9999)。

JAX_PROFILER_SERVER_PORT=XXXX

要使用此方法,您可以执行以下操作:

  1. 运行您典型的 vllm serveoffline_inference 命令(确保设置 USE_JAX_PROFILER_SERVER=True)。
  2. 运行您的基准测试命令(python benchmark_serving.py...)。
  3. 预热完成后且基准测试正在运行时,启动一个新的 tensorboard 实例,并将您的 logdir 设置为您所需的性能分析输出位置(例如,tensorboard --logdir=profiles/llama3-mmlu/)。
  4. 打开 tensorboard 实例并导航到 profile 页面(例如,https://:6006/#profile)。
  5. 点击 Capture Profile,然后在 Profile Service URL(s) or TPU name 框中,输入 localhost:XXXX,其中 XXXX 是您的 JAX_PROFILER_SERVER_PORT(默认为 9999)。

  6. 输入所需的时间(以毫秒为单位)。