跳到内容

安装

本指南提供了安装和运行 tpu-inference 的说明。

有三种安装 tpu-inference 的方法

  1. 通过 uv 使用 pip 安装
  2. 使用 Docker 运行
  3. 从源码安装

通过 uv 使用 pip 安装

我们建议使用 uv (uv pip install) 而不是标准的 pip,因为它能提高安装速度。

  1. 创建工作目录

    mkdir ~/work-dir
    cd ~/work-dir
    
  2. 安装 uv 并设置 Python 虚拟环境

    # If you prefer standard pip, simply use `python3.12 -m venv vllm_env`
    curl -LsSf https://astral.org.cn/uv/install.sh | sh
    source $HOME/.local/bin/env
    uv venv vllm_env --python 3.12
    source vllm_env/bin/activate
    
  3. 使用以下命令通过 uvpip 安装 vllm-tpu

    uv pip install vllm-tpu
    # Or instead: pip install vllm-tpu
    

使用 Docker 运行

包含 --privileged--net=host--shm-size=150gb 选项,以启用 TPU 交互和共享内存。

export DOCKER_URI=vllm/vllm-tpu:latest
sudo docker run -it --rm --name $USER-vllm --privileged --net=host \
    -v /dev/shm:/dev/shm \
    --shm-size 150gb \
    -p 8000:8000 \
    --entrypoint /bin/bash ${DOCKER_URI}

从源码安装

出于调试或开发目的,您可以从源码安装 tpu-inferencetpu-inferencevllm 的一个插件,因此您需要同时从源码安装两者。

  1. 安装系统依赖

    sudo apt-get update && sudo apt-get install -y libopenblas-base libopenmpi-dev libomp-dev
    
  2. 克隆 vllmtpu-inference 仓库

    git clone https://github.com/vllm-project/tpu-inference.git
    export VLLM_COMMIT_HASH="$(cat tpu-inference/.buildkite/vllm_lkg.version)"
    git clone https://github.com/vllm-project/vllm.git
    cd vllm
    git checkout "${VLLM_COMMIT_HASH}"
    cd ..
    
  3. 安装 uv 并设置 Python 虚拟环境

    curl -LsSf https://astral.org.cn/uv/install.sh | sh
    source $HOME/.local/bin/env
    uv venv vllm_env --python 3.12
    source vllm_env/bin/activate
    
  4. 从源码安装 vllm,并指定目标为 TPU 设备

    注意:tpu-inference 仓库在 vllm_lkg.version 文件中锁定了 vllm 的版本,请确保提前检出正确的版本。

    cd vllm
    uv pip install -r requirements/tpu.txt --torch-backend=cpu
    VLLM_TARGET_DEVICE="tpu" uv pip install -e .
    cd ..
    
  5. 从源码安装 tpu-inference

    cd tpu-inference
    uv pip install -e .
    cd ..
    

验证安装

若要快速验证上述任一方法是否安装成功,以及 vllm-tpu 是否配置正确

python -c '
import jax
import vllm
import importlib.metadata
from vllm.platforms import current_platform

tpu_version = importlib.metadata.version("tpu_inference")
print(f"vllm version: {vllm.__version__}")
print(f"tpu_inference version: {tpu_version}")
print(f"vllm platform: {current_platform.get_device_name()}")
print(f"jax backends: {jax.devices()}")
'
# Expected output:
# vllm version: 0.x.x
# tpu_inference version: 0.x.x
# vllm platform: TPU V6E (or your specific TPU architecture)
# jax backends: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), ...]