LoRA 量化推理

来源：examples/offline_inference/lora_with_quantization_inference.py

# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use LoRA with different quantization techniques
for offline inference.

Requires HuggingFace credentials for access.
"""

import gc
from typing import Optional

import torch
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
    lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build the fixed list of (prompt, sampling params, LoRA request) cases.

    The first case carries no LoRA request; the remaining three each carry a
    LoRARequest pointing at ``lora_path``. All three requests use the integer
    id 1 and differ only in their name.

    Args:
        lora_path: Adapter location handed to every LoRARequest.

    Returns:
        Four tuples, each with its own SamplingParams instance.
    """

    def greedy_params() -> SamplingParams:
        # Called once per case so that no two cases share a params object.
        return SamplingParams(
            temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
        )

    # (adapter name, prompt text) for the cases that go through the adapter
    lora_cases = (
        ("lora-test-1", "my name is"),
        ("lora-test-2", "The capital of USA is"),
        ("lora-test-3", "The capital of France is"),
    )

    return [
        # quantization only, no LoRA request attached
        ("My name is", greedy_params(), None),
        # quantization combined with the LoRA adapter
        *(
            (text, greedy_params(), LoRARequest(adapter_name, 1, lora_path))
            for adapter_name, text in lora_cases
        ),
    ]


def process_requests(
    engine: LLMEngine,
    test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
):
    """Submit the prompts one per engine step and print every finished result.

    Each loop iteration submits at most one pending prompt and then calls
    ``engine.step()`` once. The loop ends when no prompts are left and the
    engine reports no unfinished requests.

    Args:
        engine: Engine that receives the requests and is stepped.
        test_prompts: Pending (prompt, sampling params, LoRA request) cases.
            Entries are popped from the front, so the list is empty on return.
    """
    next_id = 0

    while True:
        if not test_prompts and not engine.has_unfinished_requests():
            break

        if test_prompts:
            # Take the oldest pending case; request ids are "0", "1", ...
            text, params, adapter = test_prompts.pop(0)
            engine.add_request(str(next_id), text, params, lora_request=adapter)
            next_id += 1

        for result in engine.step():
            if not result.finished:
                continue
            print("----------------------------------------------------")
            print(f"Prompt: {result.prompt}")
            print(f"Output: {result.outputs[0].text}")


def initialize_engine(
    model: str, quantization: str, lora_repo: Optional[str]
) -> LLMEngine:
    """Create an LLMEngine for ``model`` with LoRA support switched on.

    Args:
        model: Forwarded to EngineArgs as ``model``.
        quantization: Forwarded to EngineArgs as ``quantization``.
        lora_repo: Accepted for the caller's convenience; this function does
            not read it.

    Returns:
        The engine built by ``LLMEngine.from_engine_args``.
    """
    engine_options = {
        "model": model,
        "quantization": quantization,
        "enable_lora": True,
        "max_lora_rank": 64,
        "max_loras": 4,
    }
    return LLMEngine.from_engine_args(EngineArgs(**engine_options))


def main():
    """Run the sample prompts once for each quantization + LoRA scenario.

    For every scenario: build the engine, download the adapter snapshot,
    process the prompts, then drop the engine and release cached GPU memory
    before moving on.
    """

    scenarios = [
        # QLoRA (https://arxiv.org/abs/2305.14314)
        {
            "name": "qlora_inference_example",
            "model": "huggyllama/llama-7b",
            "quantization": "bitsandbytes",
            "lora_repo": "timdettmers/qlora-flan-7b",
        },
        {
            "name": "AWQ_inference_with_lora_example",
            "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
            "quantization": "awq",
            "lora_repo": "jashing/tinyllama-colorist-lora",
        },
        {
            "name": "GPTQ_inference_with_lora_example",
            "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
            "quantization": "gptq",
            "lora_repo": "jashing/tinyllama-colorist-lora",
        },
    ]

    for scenario in scenarios:
        label = scenario["name"]
        adapter_repo = scenario["lora_repo"]
        print(f"~~~~~~~~~~~~~~~~ Running: {label} ~~~~~~~~~~~~~~~~")

        # The engine is created first, then the adapter snapshot is fetched.
        engine = initialize_engine(
            scenario["model"], scenario["quantization"], adapter_repo
        )
        adapter_dir = snapshot_download(repo_id=adapter_repo)
        process_requests(engine, create_test_prompts(adapter_dir))

        # Drop the engine and free cached GPU memory before the next scenario
        del engine
        gc.collect()
        torch.cuda.empty_cache()


# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()