量化 KV Cache¶
FP8 KV Cache 概述¶
高效的内存使用对于运行大语言模型至关重要。将 KV (Key-Value) Cache 量化为 FP8 格式可以显著减少其内存占用。这种优化使您能够在内存中存储更多的 token,从而提高吞吐量并支持更长的上下文窗口。
注意:当在 FP8 KV Cache 中使用 Flash Attention 3 后端时,Attention 操作也会在量化 (FP8) 域中执行。在此配置下,除了 Key 和 Value 外,Query 也会被量化为 FP8。
支持的 FP8 KV-Cache 量化方案¶
vLLM 支持两种主要的 FP8 KV-Cache 量化策略
- 张量级量化 (Per-tensor quantization)
对每个 Q、K 和 V 张量分别应用单个缩放比例。 (q/k/v_scale = [1]) - 注意力头级量化 (Per-attention-head quantization)
每个缩放比例对应一个注意力头:q_scale = [num_heads],k/v_scale = [num_kv_heads]。
注意
注意力头级量化目前仅在使用 Flash Attention 后端时可用,并且需要 llm-compressor 提供的校准路径。
缩放校准方法¶
您可以使用三种不同的方法来配置 vLLM 中量化缩放比例的计算方式
-
无校准(默认缩放)
所有量化缩放比例均设为1.0。
配置方式
-
随机 Token 校准(即时计算)
在预热期间,缩放比例会根据单批次随机 token 自动进行估计,随后固定下来。
配置方式
-
[推荐] 使用数据集进行校准 (通过 llm-compressor)
使用精心挑选的校准数据集来估计缩放比例,以获得最大精度。
这需要 llm-compressor 库。
请查看下方示例!
其他的 kv_cache_dtype 选项¶
kv_cache_dtype="auto": 使用模型的默认数据类型kv_cache_dtype="fp8_e4m3": 在 CUDA 11.8+ 和 ROCm (AMD GPU) 上支持kv_cache_dtype="fp8_e5m2": 在 CUDA 11.8+ 上支持
示例¶
1. 无校准 (kv_cache_dtype="fp8", calculate_kv_scales=False)¶
所有量化缩放比例均设为 1.0。
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
model="meta-llama/Llama-2-7b-chat-hf",
kv_cache_dtype="fp8",
calculate_kv_scales=False,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
2. 随机 Token 校准 (kv_cache_dtype="fp8", calculate_kv_scales=True)¶
在预热期间,缩放比例会根据单批次 token 自动进行估计。
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
model="meta-llama/Llama-2-7b-chat-hf",
kv_cache_dtype="fp8",
calculate_kv_scales=True,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
3. [推荐] 使用数据集进行校准 (配合 llm-compressor)¶
为了获得最高质量的量化效果,我们建议使用 llm-compressor 在数据集上进行校准。这可以启用诸如注意力头级量化等高级策略。
安装所需的包¶
示例:将 Llama Attention 和 KV Cache 量化为 FP8¶
"""
Quantize Llama attention + KV cache to FP8 (choose either 'tensor' or 'attn_head' strategy)
using llm-compressor one-shot calibration.
"""
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs
# -----------------------------
# Config
# -----------------------------
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
STRATEGY = "tensor" # or "attn_head"
NUM_CALIB_SAMPLES = 512 # Good starting value
MAX_SEQ_LEN = 2048
# -----------------------------
# Helpers
# -----------------------------
def process_and_tokenize(example, tokenizer: AutoTokenizer):
"""Convert chat messages to tokens."""
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
return tokenizer(
text,
padding=False,
max_length=MAX_SEQ_LEN,
truncation=True,
add_special_tokens=False,
)
def build_recipe(strategy: str) -> QuantizationModifier:
fp8_args = QuantizationArgs(num_bits=8, type="float", strategy=strategy)
return QuantizationModifier(
config_groups={
"attention": QuantizationScheme(
targets=["LlamaAttention"], # Quantize queries: q_scale
input_activations=fp8_args,
)
},
kv_cache_scheme=fp8_args, # Quantize KV cache: k/v_scale
)
# -----------------------------
# Main
# -----------------------------
def main():
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIB_SAMPLES}]")
ds = ds.shuffle(seed=42)
ds = ds.map(
lambda ex: process_and_tokenize(ex, tokenizer),
remove_columns=ds.column_names,
)
recipe = build_recipe(STRATEGY)
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQ_LEN,
num_calibration_samples=NUM_CALIB_SAMPLES,
)
save_dir = f"{MODEL_ID.rstrip('/').split('/')[-1]}-kvattn-fp8-{STRATEGY}"
model.save_pretrained(save_dir, save_compressed=True)
tokenizer.save_pretrained(save_dir)
if __name__ == "__main__":
main()
有关更详细和最新的示例,请参见 llm-compressor 官方示例。