AMD QUARK#
量化可以有效地减少内存和带宽使用,加速计算并提高吞吐量,同时将精度损失降至最低。 vLLM 可以利用 Quark 这个灵活而强大的量化工具包,生成高性能的量化模型,在 AMD GPU 上运行。 Quark 专门支持量化大型语言模型,包括权重、激活和 kv-cache 量化,以及前沿的量化算法,如 AWQ、GPTQ、Rotation 和 SmoothQuant。
Quark 安装#
在量化模型之前,您需要安装 Quark。 最新版本的 Quark 可以通过 pip 安装
pip install amd-quark
您可以参考 Quark 安装指南 了解更多安装详情。
量化过程#
安装 Quark 后,我们将使用一个示例来说明如何使用 Quark。
Quark 量化过程可以列为以下 5 个步骤
加载模型
准备校准数据加载器
设置量化配置
量化模型并导出
在 vLLM 中评估
1. 加载模型#
Quark 使用 Transformers 来获取模型和分词器。
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "meta-llama/Llama-2-70b-chat-hf"
MAX_SEQ_LEN = 512
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, model_max_length=MAX_SEQ_LEN)
tokenizer.pad_token = tokenizer.eos_token
2. 准备校准数据加载器#
Quark 使用 PyTorch Dataloader 加载校准数据。 有关如何有效使用校准数据集的更多详细信息,请参阅 添加校准数据集。
from datasets import load_dataset
from torch.utils.data import DataLoader
BATCH_SIZE = 1
NUM_CALIBRATION_DATA = 512
# Load the dataset and get calibration data.
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
text_data = dataset["text"][:NUM_CALIBRATION_DATA]
tokenized_outputs = tokenizer(text_data, return_tensors="pt",
padding=True, truncation=True, max_length=MAX_SEQ_LEN)
calib_dataloader = DataLoader(tokenized_outputs['input_ids'],
batch_size=BATCH_SIZE, drop_last=True)
3. 设置量化配置#
我们需要设置量化配置,您可以查看 quark 配置指南 了解更多详情。 在这里,我们在权重、激活、kv-cache 上使用 FP8 张量量化,量化算法为 AutoSmoothQuant。
注意
请注意,量化算法需要 JSON 配置文件,该配置文件位于 Quark Pytorch 示例 的 examples/torch/language_modeling/llm_ptq/models
目录下。 例如,Llama 的 AutoSmoothQuant 配置文件为 examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json
。
from quark.torch.quantization import (Config, QuantizationConfig,
FP8E4M3PerTensorSpec,
load_quant_algo_config_from_file)
# Define fp8/per-tensor/static spec.
FP8_PER_TENSOR_SPEC = FP8E4M3PerTensorSpec(observer_method="min_max",
is_dynamic=False).to_quantization_spec()
# Define global quantization config, input tensors and weight apply FP8_PER_TENSOR_SPEC.
global_quant_config = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC,
weight=FP8_PER_TENSOR_SPEC)
# Define quantization config for kv-cache layers, output tensors apply FP8_PER_TENSOR_SPEC.
KV_CACHE_SPEC = FP8_PER_TENSOR_SPEC
kv_cache_layer_names_for_llama = ["*k_proj", "*v_proj"]
kv_cache_quant_config = {name :
QuantizationConfig(input_tensors=global_quant_config.input_tensors,
weight=global_quant_config.weight,
output_tensors=KV_CACHE_SPEC)
for name in kv_cache_layer_names_for_llama}
layer_quant_config = kv_cache_quant_config.copy()
# Define algorithm config by config file.
LLAMA_AUTOSMOOTHQUANT_CONFIG_FILE =
'examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json'
algo_config = load_quant_algo_config_from_file(LLAMA_AUTOSMOOTHQUANT_CONFIG_FILE)
EXCLUDE_LAYERS = ["lm_head"]
quant_config = Config(
global_quant_config=global_quant_config,
layer_quant_config=layer_quant_config,
kv_cache_quant_config=kv_cache_quant_config,
exclude=EXCLUDE_LAYERS,
algo_config=algo_config)
4. 量化模型并导出#
然后我们可以应用量化。 量化后,我们需要先冻结量化模型,然后再导出。 请注意,我们需要以 HuggingFace safetensors
格式导出模型,您可以参考 HuggingFace 格式导出 了解更多导出格式详情。
import torch
from quark.torch import ModelQuantizer, ModelExporter
from quark.torch.export import ExporterConfig, JsonExporterConfig
# Apply quantization.
quantizer = ModelQuantizer(quant_config)
quant_model = quantizer.quantize_model(model, calib_dataloader)
# Freeze quantized model to export.
freezed_model = quantizer.freeze(model)
# Define export config.
LLAMA_KV_CACHE_GROUP = ["*k_proj", "*v_proj"]
export_config = ExporterConfig(json_export_config=JsonExporterConfig())
export_config.json_export_config.kv_cache_group = LLAMA_KV_CACHE_GROUP
EXPORT_DIR = MODEL_ID.split("/")[1] + "-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant"
exporter = ModelExporter(config=export_config, export_dir=EXPORT_DIR)
with torch.no_grad():
exporter.export_safetensors_model(freezed_model,
quant_config=quant_config, tokenizer=tokenizer)
5. 在 vLLM 中评估#
现在,您可以直接通过 LLM 入口点加载和运行 Quark 量化模型
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant",
kv_cache_dtype='fp8',quantization='quark')
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)
或者,您可以使用 lm_eval
评估准确性
$ lm_eval --model vllm \
--model_args pretrained=Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant,kv_cache_dtype='fp8',quantization='quark' \
--tasks gsm8k
Quark 量化脚本#
除了上面的 Python API 示例外,Quark 还提供了一个 量化脚本,可以更方便地量化大型语言模型。 它支持使用各种不同的量化方案和优化算法来量化模型。 它可以导出量化模型并动态运行评估任务。 使用该脚本,上面的示例可以是
python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \
--output_dir /path/to/output \
--quant_scheme w_fp8_a_fp8 \
--kv_cache_dtype fp8 \
--quant_algo autosmoothquant \
--num_calib_data 512 \
--model_export hf_format \
--tasks gsm8k