IO Processor 插件¶
IO Processor 插件是一项允许对池化(pooling)模型的模型输入和输出进行预处理和后处理的功能。其核心思想是允许用户向 vLLM 传递自定义输入,该输入会被转换为一个或多个模型提示词(prompts)并馈送至模型的 encode 方法。此类插件的一个潜在应用场景是使用 vLLM 生成多模态数据。例如,用户向 vLLM 馈送一张图像并获取一张输出图像。
在使用 IO Processor 插件执行推理时,提示词类型由插件定义,最终的请求输出亦然。vLLM 不会对输入/输出数据执行任何验证,确保馈送给模型并返回给用户的数据正确无误是插件的责任。目前,这些插件仅支持池化模型,并可通过 LLM 和 AsyncLLM 中的 encode 方法触发,或在在线服务模式下通过 /pooling 端点触发。
编写 IO Processor 插件¶
IO Processor 插件实现了 IOProcessor 接口
IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
"""Abstract interface for pre/post-processing of engine I/O."""
def __init__(self, vllm_config: VllmConfig, renderer: BaseRenderer):
super().__init__()
self.vllm_config = vllm_config
def parse_data(self, data: object) -> IOProcessorInput:
raise NotImplementedError
def merge_sampling_params(
self,
params: SamplingParams | None = None,
) -> SamplingParams:
return params or SamplingParams()
def merge_pooling_params(
self,
params: PoolingParams | None = None,
) -> PoolingParams:
return params or PoolingParams(task="plugin")
@abstractmethod
def pre_process(
self,
prompt: IOProcessorInput,
request_id: str | None = None,
**kwargs,
) -> PromptType | Sequence[PromptType]:
raise NotImplementedError
async def pre_process_async(
self,
prompt: IOProcessorInput,
request_id: str | None = None,
**kwargs,
) -> PromptType | Sequence[PromptType]:
return self.pre_process(prompt, request_id, **kwargs)
@abstractmethod
def post_process(
self,
model_output: Sequence[PoolingRequestOutput],
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
raise NotImplementedError
async def post_process_async(
self,
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
# We cannot guarantee outputs are returned in the same order they were
# fed to vLLM.
# Let's sort them by id before post_processing
sorted_output = sorted(
[(i, item) async for i, item in model_output], key=lambda output: output[0]
)
collected_output = [output[1] for output in sorted_output]
return self.post_process(collected_output, request_id=request_id, **kwargs)
parse_data 方法用于验证用户数据并将其转换为 pre_process* 方法所预期的输入。merge_sampling_params 和 merge_pooling_params 方法将输入 SamplingParams 或 PoolingParams(如果有)与默认参数合并。pre_process* 方法接收验证后的插件输入,以生成用于常规推理的 vLLM 模型提示词。post_process* 方法接收 PoolingRequestOutput 对象作为输入,并生成自定义的插件输出。
一个使用 PrithviGeospatialMAE 模型生成 geotiff 图像的插件实现示例可在此处找到。请同时参考我们的在线( examples/pooling/plugin/prithvi_geospatial_mae_online.py)和离线( examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)推理示例。
使用 IO Processor 插件¶
IO Processor 插件在引擎启动时加载,指定加载插件名称的方法有两种:
- 通过 vLLM 的
EngineArgs:在用于初始化AsyncLLM的EngineArgs中设置io_processor_plugin参数。同样,在离线模式下将io_processor_plugin参数传递给LLM,或在服务模式下传递--io-processor-plugin参数,也能实现相同的效果。 - 通过模型 HF 配置:在模型配置文件 (config.json) 中添加一个
io_processor_plugin字段。
该顺序同时也决定了方法的优先级。即,通过 EngineArgs 设置的插件名称将覆盖模型 HF 配置 (config.json) 中指定的任何插件名称。