
KV Load Failure Recovery Test

Source: https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/kv_load_failure_recovery

This example builds on the disaggregated-prefill-v1 example in examples/offline_inference.

It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and produces successful, consistent output.

Files

  • prefill_example.py – Runs the prefill stage and saves the KV data (identical to the one in disaggregated-prefill-v1).
  • decode_example.py – Runs the decode stage. Accepts:
    • --simulate-failure: simulates a KV load failure using a custom connector.
    • --async-load: enables asynchronous KV loading mode.
  • load_recovery_example_connector.py – Defines LoadRecoveryExampleConnector, a subclass of ExampleConnector that simulates missing or corrupted external KV blocks by failing the block load for the first decode request.
  • run.sh – Orchestrates the test: runs the prefill stage followed by three decode stages:

    1. Regular decode (baseline).
    2. Decode with a simulated synchronous KV load failure.
    3. Decode with a simulated asynchronous KV load failure.

    Finally, it compares the baseline and recovered outputs to verify correctness; a minimal sketch of that check follows below.
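
run.sh implements this check with cmp and diff; as a rough Python equivalent (illustrative only, with difflib standing in for the shell tools and file names taken from the scripts below):

```py
import difflib
import sys
from pathlib import Path


def check(baseline: str, recovered: str) -> None:
    """Print a unified diff and exit with an error if the outputs differ."""
    base = Path(baseline).read_text(encoding="utf-8").splitlines(keepends=True)
    rec = Path(recovered).read_text(encoding="utf-8").splitlines(keepends=True)
    if base != rec:
        sys.stdout.writelines(difflib.unified_diff(base, rec, baseline, recovered))
        sys.exit(1)


check("decode_output.txt", "sync_decode_recovered_output.txt")
check("decode_output.txt", "async_decode_recovered_output.txt")
print("Outputs match: recovery succeeded.")
```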

How it works

  • The test dynamically loads LoadRecoveryExampleConnector via KVTransferConfig.kv_connector_module_path, which makes it possible to simulate load failures without modifying the original connector (see the sketch after this list).
  • The failure-simulating decode stages are expected to trigger vLLM's recovery logic and therefore produce the same output as the baseline decode.
  • If recovery fails, the script prints a unified diff of the mismatched outputs and exits with an error.
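
Concretely, the failure-simulating decode run builds its KVTransferConfig as below; this mirrors the failure branch of decode_example.py shown in full later. kv_connector names the class, while kv_connector_module_path names the importable module it lives in, so vLLM can load a connector that is not built in:

```py
from vllm.config import KVTransferConfig

ktc = KVTransferConfig(
    kv_connector="LoadRecoveryExampleConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "shared_storage_path": "local_storage",  # where prefill stored the KV data
        "async_load": False,  # True for the async failure stage
    },
    # Module on the Python path that defines the connector class above.
    kv_connector_module_path="load_recovery_example_connector",
)
```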

Usage

```bash
./run.sh
```

Example materials

decode_example.py

```py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    """Read prompts from prefill_output.txt"""
    prompts = []
    try:
        with open("prefill_output.txt") as f:
            for line in f:
                prompts.append(line.strip())
        print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
        return prompts
    except FileNotFoundError:
        print("Error: prefill_output.txt file not found")
        exit(-1)


def main():
    prompts = read_prompts()
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--simulate-failure", action="store_true", help="Simulate KV load failure."
    )
    parser.add_argument(
        "--async-load", action="store_true", help="Simulate async KV load"
    )
    args = parser.parse_args()

    if args.simulate_failure:
        ktc = KVTransferConfig(
            kv_connector="LoadRecoveryExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
                "async_load": args.async_load,
            },
            kv_connector_module_path="load_recovery_example_connector",
        )
        out_file = (
            "async_decode_recovered_output.txt"
            if args.async_load
            else "sync_decode_recovered_output.txt"
        )
    else:
        ktc = KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
            },
        )
        out_file = "decode_output.txt"

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        max_num_batched_tokens=64,
        max_num_seqs=16,
        kv_transfer_config=ktc,
    )

    outputs = llm.generate(prompts, sampling_params)

    sep_str = "-" * 30
    with open(out_file, "w", encoding="utf-8") as f:
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
            print(out_str)
            print(sep_str)
            f.write(out_str)
            f.write(sep_str)


if __name__ == "__main__":
    main()
```

load_recovery_example_connector.py

```py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
    ExampleConnector,
    ExampleConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)


@dataclass
class LoadRecoveryExampleConnectorMetadata(ExampleConnectorMetadata):
    req_to_block_ids: dict[str, set[int]] = field(default_factory=dict)

    @classmethod
    def from_base(cls, base: ExampleConnectorMetadata):
        return cls(requests=base.requests)


class LoadRecoveryExampleConnector(ExampleConnector):
    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
            "async_load", False
        )
        self._invalid_block_ids: set = None
        self._seen_requests: set = set()
        self._req_to_block_ids: dict[str, list[int]] = dict()

    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        assert isinstance(connector_metadata, LoadRecoveryExampleConnectorMetadata)
        index, failed_request = next(
            (
                (i, x)
                for i, x in enumerate(connector_metadata.requests)
                if not x.is_store
            ),
            (None, None),
        )
        if index is not None:
            del connector_metadata.requests[index]
            self._invalid_block_ids = set(
                (
                    failed_request.slot_mapping[:: self._block_size] // self._block_size
                ).tolist()
            )
            logger.info(
                "Simulating failure to load all KV blocks for the "
                "first load request. Total blocks: %d",
                len(self._invalid_block_ids),
            )
        super().bind_connector_metadata(connector_metadata)

    def clear_connector_metadata(self) -> None:
        self._invalid_block_ids = None
        super().clear_connector_metadata()

    def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
        if self._async_load and forward_context.attn_metadata is None:
            # Bypass sanity check in super().start_load_kv
            forward_context.attn_metadata = "None"

        super().start_load_kv(forward_context, **kwargs)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        if self._async_load:
            meta = self._get_connector_metadata()
            assert isinstance(meta, LoadRecoveryExampleConnectorMetadata)
            if meta.req_to_block_ids:
                return None, set(meta.req_to_block_ids)

        return None, None

    def get_block_ids_with_load_errors(self) -> set[int]:
        return self._invalid_block_ids

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        if request.request_id in self._seen_requests:
            return 0, False

        self._seen_requests.add(request.request_id)

        num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
        return num_tokens, self._async_load and num_tokens > 0

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        If blocks were allocated, add to _requests_need_load,
        such that we load the KVs in the next forward pass.
        """
        super().update_state_after_alloc(request, blocks, num_external_tokens)

        if num_external_tokens > 0:
            self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]

    def build_connector_meta(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> KVConnectorMetadata:
        if not self._async_load:
            base = super().build_connector_meta(scheduler_output)
            meta = LoadRecoveryExampleConnectorMetadata.from_base(base)
        else:
            meta = LoadRecoveryExampleConnectorMetadata()
            if self._requests_need_load:
                for req_id, request in self._requests_need_load.items():
                    meta.add_request(
                        token_ids=request.prompt_token_ids,
                        block_ids=self._req_to_block_ids[req_id],
                        block_size=self._block_size,
                        is_store=False,
                        mm_hashes=[],
                    )
                # Clear state
                self._requests_need_load.clear()
        meta.req_to_block_ids = self._req_to_block_ids
        self._req_to_block_ids = dict()
        return meta
```
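
One line worth unpacking in bind_connector_metadata is the block-id computation: slot_mapping[::block_size] takes the first slot of each KV block, and integer division by block_size converts each slot index into its block id. A standalone sketch with made-up values:

```py
import torch

# Hypothetical slot mapping: a request whose tokens occupy physical blocks
# 2, 4, and 5 with block_size = 4 (slots 8-11, 16-19, 20-23).
slot_mapping = torch.tensor([8, 9, 10, 11, 16, 17, 18, 19, 20, 21, 22, 23])
block_size = 4

# One slot per block, then slot index -> block id.
invalid_block_ids = set((slot_mapping[::block_size] // block_size).tolist())
print(invalid_block_ids)  # {2, 4, 5}
```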

prefill_example.py

```py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    context = "Hi " * 1000
    context2 = "Hey " * 500
    return [
        context + "Hello, my name is",
        context + "The capital of France is",
        context2 + "Your name is",
        context2 + "The capital of China is",
    ]


def main():
    prompts = read_prompts()

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        kv_transfer_config=KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        ),
    )  # , max_model_len=2048, max_num_batched_tokens=2048)

    # 1ST generation (prefill instance)
    outputs = llm.generate(
        prompts,
        sampling_params,
    )

    new_prompts = []
    print("-" * 30)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        new_prompts.append(prompt + generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 30)

    # Write new_prompts to prefill_output.txt
    with open("prefill_output.txt", "w") as f:
        for prompt in new_prompts:
            f.write(prompt + "\n")
    print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")


if __name__ == "__main__":
    main()
```

run.sh

```sh
#!/bin/bash

# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"

# Cleanup
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"

# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load

# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
    echo "❌ Outputs differ: sync recovery failed."
    diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
    exit 1
fi

if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
    echo "❌ Outputs differ: async recovery failed."
    diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
    exit 1
fi

echo "✅ Outputs match: recovery succeeded."
```