AMD QUARK#

量化可以有效地减少内存和带宽使用,加速计算并提高吞吐量,同时将精度损失降至最低。 vLLM 可以利用 Quark 这个灵活而强大的量化工具包,生成高性能的量化模型,在 AMD GPU 上运行。 Quark 专门支持量化大型语言模型,包括权重、激活和 kv-cache 量化,以及前沿的量化算法,如 AWQ、GPTQ、Rotation 和 SmoothQuant。

Quark 安装#

在量化模型之前,您需要安装 Quark。 最新版本的 Quark 可以通过 pip 安装

pip install amd-quark

您可以参考 Quark 安装指南 了解更多安装详情。

量化过程#

安装 Quark 后,我们将使用一个示例来说明如何使用 Quark。
Quark 量化过程可以列为以下 5 个步骤

  1. 加载模型

  2. 准备校准数据加载器

  3. 设置量化配置

  4. 量化模型并导出

  5. 在 vLLM 中评估

1. 加载模型#

Quark 使用 Transformers 来获取模型和分词器。

from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "meta-llama/Llama-2-70b-chat-hf"
MAX_SEQ_LEN = 512

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto",
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, model_max_length=MAX_SEQ_LEN)
tokenizer.pad_token = tokenizer.eos_token

2. 准备校准数据加载器#

Quark 使用 PyTorch Dataloader 加载校准数据。 有关如何有效使用校准数据集的更多详细信息,请参阅 添加校准数据集

from datasets import load_dataset
from torch.utils.data import DataLoader

BATCH_SIZE = 1
NUM_CALIBRATION_DATA = 512

# Load the dataset and get calibration data.
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
text_data = dataset["text"][:NUM_CALIBRATION_DATA]

tokenized_outputs = tokenizer(text_data, return_tensors="pt",
    padding=True, truncation=True, max_length=MAX_SEQ_LEN)
calib_dataloader = DataLoader(tokenized_outputs['input_ids'],
    batch_size=BATCH_SIZE, drop_last=True)

3. 设置量化配置#

我们需要设置量化配置,您可以查看 quark 配置指南 了解更多详情。 在这里,我们在权重、激活、kv-cache 上使用 FP8 张量量化,量化算法为 AutoSmoothQuant。

注意

请注意,量化算法需要 JSON 配置文件,该配置文件位于 Quark Pytorch 示例examples/torch/language_modeling/llm_ptq/models 目录下。 例如,Llama 的 AutoSmoothQuant 配置文件为 examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json

from quark.torch.quantization import (Config, QuantizationConfig,
                                     FP8E4M3PerTensorSpec,
                                     load_quant_algo_config_from_file)

# Define fp8/per-tensor/static spec.
FP8_PER_TENSOR_SPEC = FP8E4M3PerTensorSpec(observer_method="min_max",
    is_dynamic=False).to_quantization_spec()

# Define global quantization config, input tensors and weight apply FP8_PER_TENSOR_SPEC.
global_quant_config = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC,
    weight=FP8_PER_TENSOR_SPEC)

# Define quantization config for kv-cache layers, output tensors apply FP8_PER_TENSOR_SPEC.
KV_CACHE_SPEC = FP8_PER_TENSOR_SPEC
kv_cache_layer_names_for_llama = ["*k_proj", "*v_proj"]
kv_cache_quant_config = {name :
    QuantizationConfig(input_tensors=global_quant_config.input_tensors,
                       weight=global_quant_config.weight,
                       output_tensors=KV_CACHE_SPEC)
    for name in kv_cache_layer_names_for_llama}
layer_quant_config = kv_cache_quant_config.copy()

# Define algorithm config by config file.
LLAMA_AUTOSMOOTHQUANT_CONFIG_FILE =
    'examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json'
algo_config = load_quant_algo_config_from_file(LLAMA_AUTOSMOOTHQUANT_CONFIG_FILE)

EXCLUDE_LAYERS = ["lm_head"]
quant_config = Config(
    global_quant_config=global_quant_config,
    layer_quant_config=layer_quant_config,
    kv_cache_quant_config=kv_cache_quant_config,
    exclude=EXCLUDE_LAYERS,
    algo_config=algo_config)

4. 量化模型并导出#

然后我们可以应用量化。 量化后,我们需要先冻结量化模型,然后再导出。 请注意,我们需要以 HuggingFace safetensors 格式导出模型,您可以参考 HuggingFace 格式导出 了解更多导出格式详情。

import torch
from quark.torch import ModelQuantizer, ModelExporter
from quark.torch.export import ExporterConfig, JsonExporterConfig

# Apply quantization.
quantizer = ModelQuantizer(quant_config)
quant_model = quantizer.quantize_model(model, calib_dataloader)

# Freeze quantized model to export.
freezed_model = quantizer.freeze(model)

# Define export config.
LLAMA_KV_CACHE_GROUP = ["*k_proj", "*v_proj"]
export_config = ExporterConfig(json_export_config=JsonExporterConfig())
export_config.json_export_config.kv_cache_group = LLAMA_KV_CACHE_GROUP

EXPORT_DIR = MODEL_ID.split("/")[1] + "-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant"
exporter = ModelExporter(config=export_config, export_dir=EXPORT_DIR)
with torch.no_grad():
    exporter.export_safetensors_model(freezed_model,
        quant_config=quant_config, tokenizer=tokenizer)

5. 在 vLLM 中评估#

现在,您可以直接通过 LLM 入口点加载和运行 Quark 量化模型

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant",
          kv_cache_dtype='fp8',quantization='quark')
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt:    {prompt!r}")
    print(f"Output:    {generated_text!r}")
    print("-" * 60)

或者,您可以使用 lm_eval 评估准确性

$ lm_eval --model vllm \
  --model_args pretrained=Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant,kv_cache_dtype='fp8',quantization='quark' \
  --tasks gsm8k

Quark 量化脚本#

除了上面的 Python API 示例外,Quark 还提供了一个 量化脚本,可以更方便地量化大型语言模型。 它支持使用各种不同的量化方案和优化算法来量化模型。 它可以导出量化模型并动态运行评估任务。 使用该脚本,上面的示例可以是

python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \
                          --output_dir /path/to/output \
                          --quant_scheme w_fp8_a_fp8 \
                          --kv_cache_dtype fp8 \
                          --quant_algo autosmoothquant \
                          --num_calib_data 512 \
                          --model_export hf_format \
                          --tasks gsm8k