# KV Load Failure Recovery Test
Source: https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/kv_load_failure_recovery.
This example builds on the disaggregated-prefill-v1 example in `examples/offline_inference`.

It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and produces successful, consistent output.
## Files
- `prefill_example.py` – Runs the prefill stage and saves KV data (identical to the version in disaggregated-prefill-v1).
- `decode_example.py` – Runs the decode stage. Accepts:
    - `--simulate-failure`: simulates a KV load failure using a custom connector.
    - `--async-load`: enables asynchronous KV loading mode.
- `load_recovery_example_connector.py` – Defines `LoadRecoveryExampleConnector`, a subclass of `ExampleConnector` that simulates missing or corrupted external KV blocks by failing to load the blocks of the first decode request.
- `run.sh` – Orchestrates the test: runs the prefill stage, then three decode stages:
    1. Normal decode (baseline).
    2. Decode with a simulated synchronous KV load failure.
    3. Decode with a simulated asynchronous KV load failure.

  Finally, it compares the baseline and recovered outputs to verify correctness (a rough standalone equivalent of this check is sketched below).
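The comparison itself is done with `cmp` and `diff -u` in `run.sh` (shown in full under Example Materials). As a rough standalone equivalent, assuming the output file names produced by `decode_example.py`, a hypothetical Python check could look like this:

```py
# Hypothetical re-implementation of run.sh's cmp/diff check (not part of the example).
import difflib
import sys


def compare_outputs(baseline_path: str, recovered_path: str) -> bool:
    """Return True if the two output files match; otherwise print a unified diff."""
    with open(baseline_path, encoding="utf-8") as f:
        baseline = f.readlines()
    with open(recovered_path, encoding="utf-8") as f:
        recovered = f.readlines()
    if baseline == recovered:
        return True
    # Mirror `diff -u` from run.sh: show exactly where the outputs diverge.
    sys.stdout.writelines(
        difflib.unified_diff(baseline, recovered, baseline_path, recovered_path)
    )
    return False


if __name__ == "__main__":
    ok = compare_outputs("decode_output.txt", "sync_decode_recovered_output.txt")
    sys.exit(0 if ok else 1)
```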
## How It Works
- The test dynamically loads `LoadRecoveryExampleConnector` via `KVTransferConfig.kv_connector_module_path`, so the simulated load failure can be injected without modifying the original connector (see the config sketch after this list).
- The decode stages that are expected to fail trigger vLLM's recovery logic, which should yield output identical to the baseline decode.
- If recovery fails, the script prints a unified diff of the mismatched outputs and exits with an error.
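For reference, here is a minimal sketch of the failure-injection configuration (the full version appears in `decode_example.py` below); `kv_connector_module_path` tells vLLM which module to import the custom connector class from:

```py
from vllm.config import KVTransferConfig

# Minimal sketch of the config used by decode_example.py: vLLM imports
# LoadRecoveryExampleConnector from load_recovery_example_connector.py,
# so the built-in ExampleConnector stays untouched.
ktc = KVTransferConfig(
    kv_connector="LoadRecoveryExampleConnector",
    kv_connector_module_path="load_recovery_example_connector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "shared_storage_path": "local_storage",
        "async_load": True,  # False exercises the synchronous load path
    },
)
```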
## Usage
```bash
./run.sh
```
## Example Materials
### decode_example.py
```py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    """Read prompts from prefill_output.txt"""
    prompts = []
    try:
        with open("prefill_output.txt") as f:
            for line in f:
                prompts.append(line.strip())
        print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
        return prompts
    except FileNotFoundError:
        print("Error: prefill_output.txt file not found")
        exit(-1)


def main():
    prompts = read_prompts()
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--simulate-failure", action="store_true", help="Simulate KV load failure."
    )
    parser.add_argument(
        "--async-load", action="store_true", help="Simulate async KV load"
    )
    args = parser.parse_args()

    if args.simulate_failure:
        ktc = KVTransferConfig(
            kv_connector="LoadRecoveryExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
                "async_load": args.async_load,
            },
            kv_connector_module_path="load_recovery_example_connector",
        )
        out_file = (
            "async_decode_recovered_output.txt"
            if args.async_load
            else "sync_decode_recovered_output.txt"
        )
    else:
        ktc = KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
            },
        )
        out_file = "decode_output.txt"

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        max_num_batched_tokens=64,
        max_num_seqs=16,
        kv_transfer_config=ktc,
    )

    outputs = llm.generate(prompts, sampling_params)

    sep_str = "-" * 30
    with open(out_file, "w", encoding="utf-8") as f:
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
            print(out_str)
            print(sep_str)
            f.write(out_str)
            f.write(sep_str)


if __name__ == "__main__":
    main()
```
### load_recovery_example_connector.py
```py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
    ExampleConnector,
    ExampleConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)


@dataclass
class LoadRecoveryExampleConnectorMetadata(ExampleConnectorMetadata):
    req_to_block_ids: dict[str, set[int]] = field(default_factory=dict)

    @classmethod
    def from_base(cls, base: ExampleConnectorMetadata):
        return cls(requests=base.requests)


class LoadRecoveryExampleConnector(ExampleConnector):
    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
            "async_load", False
        )
        self._invalid_block_ids: set = None
        self._seen_requests: set = set()
        self._req_to_block_ids: dict[str, list[int]] = dict()

    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        assert isinstance(connector_metadata, LoadRecoveryExampleConnectorMetadata)
        index, failed_request = next(
            (
                (i, x)
                for i, x in enumerate(connector_metadata.requests)
                if not x.is_store
            ),
            (None, None),
        )
        if index is not None:
            del connector_metadata.requests[index]
            self._invalid_block_ids = set(
                (
                    failed_request.slot_mapping[:: self._block_size] // self._block_size
                ).tolist()
            )
            logger.info(
                "Simulating failure to load all KV blocks for the "
                "first load request. Total blocks: %d",
                len(self._invalid_block_ids),
            )
        super().bind_connector_metadata(connector_metadata)

    def clear_connector_metadata(self) -> None:
        self._invalid_block_ids = None
        super().clear_connector_metadata()

    def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
        if self._async_load and forward_context.attn_metadata is None:
            # Bypass sanity check in super().start_load_kv
            forward_context.attn_metadata = "None"

        super().start_load_kv(forward_context, **kwargs)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        if self._async_load:
            meta = self._get_connector_metadata()
            assert isinstance(meta, LoadRecoveryExampleConnectorMetadata)
            if meta.req_to_block_ids:
                return None, set(meta.req_to_block_ids)

        return None, None

    def get_block_ids_with_load_errors(self) -> set[int]:
        return self._invalid_block_ids

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        if request.request_id in self._seen_requests:
            return 0, False

        self._seen_requests.add(request.request_id)

        num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
        return num_tokens, self._async_load and num_tokens > 0

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        If blocks were allocated, add to _requests_need_load,
        such that we load the KVs in the next forward pass.
        """
        super().update_state_after_alloc(request, blocks, num_external_tokens)

        if num_external_tokens > 0:
            self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]

    def build_connector_meta(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> KVConnectorMetadata:
        if not self._async_load:
            base = super().build_connector_meta(scheduler_output)
            meta = LoadRecoveryExampleConnectorMetadata.from_base(base)
        else:
            meta = LoadRecoveryExampleConnectorMetadata()
            if self._requests_need_load:
                for req_id, request in self._requests_need_load.items():
                    meta.add_request(
                        token_ids=request.prompt_token_ids,
                        block_ids=self._req_to_block_ids[req_id],
                        block_size=self._block_size,
                        is_store=False,
                        mm_hashes=[],
                    )
                # Clear state
                self._requests_need_load.clear()

        meta.req_to_block_ids = self._req_to_block_ids
        self._req_to_block_ids = dict()

        return meta
```
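The block-ID arithmetic in `bind_connector_metadata` is compact; here is a small standalone sketch (with an assumed `block_size` of 4 and a made-up `slot_mapping`) of how taking every `block_size`-th slot index and integer-dividing by `block_size` recovers the IDs of the blocks a request occupies:

```py
import torch

# Assumed for illustration; vLLM derives the real block size from the cache config.
block_size = 4

# Hypothetical slot_mapping: the per-token KV slot index for one request.
# Tokens 0-3 occupy block 2 (slots 8-11); tokens 4-7 occupy block 5 (slots 20-23).
slot_mapping = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23])

# One slot per block (every block_size-th entry), divided by block_size,
# yields the block IDs — the same expression the connector uses to report
# all of a request's blocks as invalid.
invalid_block_ids = set((slot_mapping[::block_size] // block_size).tolist())
print(invalid_block_ids)  # {2, 5}
```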
### prefill_example.py
```py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    context = "Hi " * 1000
    context2 = "Hey " * 500
    return [
        context + "Hello, my name is",
        context + "The capital of France is",
        context2 + "Your name is",
        context2 + "The capital of China is",
    ]


def main():
    prompts = read_prompts()

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        kv_transfer_config=KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        ),
    )  # , max_model_len=2048, max_num_batched_tokens=2048

    # 1ST generation (prefill instance)
    outputs = llm.generate(
        prompts,
        sampling_params,
    )

    new_prompts = []
    print("-" * 30)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        new_prompts.append(prompt + generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 30)

    # Write new_prompts to prefill_output.txt
    with open("prefill_output.txt", "w") as f:
        for prompt in new_prompts:
            f.write(prompt + "\n")
    print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")


if __name__ == "__main__":
    main()
```
### run.sh
```sh
#!/bin/bash

# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"

# Cleanup
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"

# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load

# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
    echo "❌ Outputs differ: sync recovery failed."
    diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
    exit 1
fi

if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
    echo "❌ Outputs differ: async recovery failed."
    diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
    exit 1
fi

echo "✅ Outputs match: recovery successful."
```