跳过内容

NVIDIA TensorRT 模型优化器

The NVIDIA TensorRT Model Optimizer 是一个旨在优化模型以便在 NVIDIA GPU 上进行推理的库。它包含用于大型语言模型 (LLM)、视觉语言模型 (VLM) 和扩散模型的训练后量化 (PTQ) 和量化感知训练 (QAT) 工具。

我们建议使用以下方式安装该库

pip install nvidia-modelopt

使用 PTQ 量化 HuggingFace 模型

您可以使用 TensorRT Model Optimizer 仓库中提供的示例脚本来量化 HuggingFace 模型。LLM PTQ 的主要脚本通常位于 examples/llm_ptq 目录中。

下面是一个使用 modelopt 的 PTQ API 量化模型的示例

import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM

# Load the model from HuggingFace
model = AutoModelForCausalLM.from_pretrained("<path_or_model_id>")

# Select the quantization config, for example, FP8
config = mtq.FP8_DEFAULT_CFG

# Define a forward loop function for calibration
def forward_loop(model):
    for data in calib_set:
        model(data)

# PTQ with in-place replacement of quantized modules
model = mtq.quantize(model, config, forward_loop)

模型量化后,您可以使用 export API 将其导出为量化检查点

import torch
from modelopt.torch.export import export_hf_checkpoint

with torch.inference_mode():
    export_hf_checkpoint(
        model,  # The quantized model.
        export_dir,  # The directory where the exported files will be stored.
    )

然后可以将量化检查点与 vLLM 一起部署。例如,以下代码展示了如何使用 vLLM 部署 nvidia/Llama-3.1-8B-Instruct-FP8,这是从 meta-llama/Llama-3.1-8B-Instruct 衍生的 FP8 量化检查点

from vllm import LLM, SamplingParams

def main():

    model_id = "nvidia/Llama-3.1-8B-Instruct-FP8"
    # Ensure you specify quantization='modelopt' when loading the modelopt checkpoint
    llm = LLM(model=model_id, quantization="modelopt", trust_remote_code=True)

    sampling_params = SamplingParams(temperature=0.8, top_p=0.9)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

if __name__ == "__main__":
    main()