跳到内容

自定义 Logits Processors

重要

某些 logits processor 的设计更改仍在进行中,API 可能在不久的将来发生变化。我们希望尽快稳定此 API 部分。

“自定义”logits processor 由 vLLM 用户编写,并在初始化时加载到 vLLM 中,而无需修改或重新编译 vLLM 源代码。它与内置 logits processor 相反。

本文档演示了如何编写、加载和使用自定义 logits processor。

Logits Processors 背景

logits processor 调整下一个 token 的概率分布,通常旨在将模型引导到所需的行为类型。

在 vLLM 中,logits processor 以 batch 粒度运行。在给定的引擎步中,logits processor 消耗模型输出的 (num_requests) x (vocab_size) 的原始 logits 张量。对于所有启用了 logits processor 的请求,logits processor 会对 logits 张量的相应行应用转换,而其他行保持不变。然后将转换后的 logits 张量传递给 softmax。

创建自定义 Logits Processor

自定义 logits processor 必须继承自 vllm.v1.sample.logits_processor.LogitsProcessor 并定义(至少)以下方法:

  • validate_params(cls, sampling_params: SamplingParams):

    • 如果 SamplingParams 包含 logits processor 使用的无效参数(尤其是自定义参数),则引发 ValueError
    • 当请求发送到入口点时,validate_params() 将验证 SamplingParams 并拒绝带有无效参数的请求。
    • 注意: 实现 validate_params() 以防止自定义 logits processor 的参数无效非常重要。否则,带有无效参数的请求可能导致自定义 logits processor 出现意外行为。
  • __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)

    • vllm_config:引擎配置数据结构。
    • device:硬件加速器设备信息。
    • is_pin_memory:指示是否可用 pin memory 来支持 logits processor 实现的标志。
  • apply(self, logits: torch.Tensor) -> torch.Tensor:

    • 消耗一个 (num_requests) x (vocab_size) 的 logits 张量 (logits)。
    • 以 batch 粒度应用 logits processor 转换。
    • 返回一个转换后的 (num_requests) x (vocab_size) logits 张量。
    • 您可以原地修改或非原地修改输入的 logits processor;原地修改更节省内存。
  • is_argmax_invariant(self) -> bool:

    • 如果 logits processor 是 argmax 不变的(对于给定的请求,从不改变最高 logits 值 token ID),则返回 True;如果 logits processor 可能修改 argmax,则返回 False
    • is_argmax_invariant() 在启动时评估一次;如果为 True,当所有请求都使用贪婪采样时,vLLM 将跳过应用此 logits processor。
  • update_state(self, batch_update: Optional["BatchUpdate"]) -> None:

    • 消耗一个 BatchUpdate 数据结构,该结构表示当前引擎步开始时的持久 batch 状态更改。
    • 使用 BatchUpdate 的成员来更新 logits processor 的内部状态。
    • 注意: batch update 数据结构可能为 None,表示 batch 构成没有变化。在这种情况下,LogitsProcessor 可能仍希望根据它在添加时保留的更新后的 output_token_ids 列表来更新其状态。

vLLM 引擎如何构建 BatchUpdate 数据结构

重要

某些 logits processor 的设计更改仍在进行中。我们预计将来您在实现 logits processor 时不需要考虑 batch 状态更改,本节的信息将变得无关紧要。

logits processor 的 update_state() 实现应假定模型运行程序更新持久 batch 状态的模型(在此处以 BatchUpdate 抽象的形式表示):

  1. 识别当前引擎步中完成的请求的索引。

  2. 识别当前步中引入的新请求。

  3. 使用 Add 操作,按照被替换请求的索引从小到大的顺序,用新请求替换尽可能多的已完成请求。

  4. 基于新请求和已完成请求的相对数量。

    1. 如果新请求的数量与已完成请求的数量相同,则继续下一步。

    2. 如果新请求多于已完成请求: 应用 Add 操作,用剩余未替换已完成请求的新请求扩展 batch。为这些新请求分配连续的索引,从 current_max_batch_index + 1 开始。

    3. 如果新请求少于已完成请求。

      • 对未被新请求替换的已完成请求应用 Remove 操作。这些移除的请求索引必然大于上一步中被替换的已完成请求的最大索引。Remove 操作可能会使 batch 处于非连续状态。

      • “压缩”batch 使其连续: 从最低索引的空槽(由 Remove 操作引起)开始,应用一个从当前 batch 中最高的非空槽到该空槽的单向移动。按照空槽目标索引递增和非空槽源索引递减的顺序继续进行其他单向移动操作,直到 batch 连续。

      • 收缩 batch: 压缩 batch 的一个副作用是将 Remove 操作产生的空槽分组在一个连续块中,位于 batch 数组的末尾。因此,压缩后,更新 BatchUpdate.batch_size 以反映非空槽的数量。

  5. 重新排序 batch 以提高效率。根据 attention 后端实现和 batch 的当前特性,可以应用零个或多个 Swap Move 操作来重新排序 batch。

注意事项

  • logits processor 的 update_state() 方法必须按以下顺序处理 batch 更新操作:移除 (removes)、添加 (adds)、移动 (moves)。

  • Add 操作的索引参数指的是 Add 操作发生时的索引,即在任何 Move 操作之前。

    • 例如:如果一个请求在索引 5 处被添加,然后与索引 3 交换,那么 BatchUpdate.added 中的 Add 操作将与索引 5 相关联,而不是 3。
    • 换句话说,可以假定 Move 操作是在 Adds 和 Removes 之后应用的。
  • 可以假定 Move 操作是按照它们在 BatchUpdate.moved 中出现的顺序应用的。

  • 如果没有新/已完成请求,也没有 batch 重排,那么 logits processor 的 batch 更新将为 None

将自定义参数传递给自定义 Logits Processor

与内置 logits processor 不同,自定义 logits processor 可能需要配置参数,这些参数并未硬编码到 SamplingParams 或 vLLM 服务器 REST API 中。为了解决这个问题,自定义 logits processor 可以利用 vLLM 的 自定义参数 支持,从用户那里接收配置设置(尽管您也可以自由设计一个利用 SamplingParams 中现有字段的自定义 logits processor)。

示例自定义 Logits Processor 实现

下面的示例实现了一个自定义 logits processor,它消耗一个 (num_requests) x (vocab_size) 的 logits 张量,并将除了一个 token (target_token) 之外的所有 token 都掩蔽为 float(-inf)。对于任何未指定 target_token 的请求,该 logits processor 将被禁用。为了确定 logits processor 是否启用以及哪个 token 将保持不被掩蔽,logits processor 会检查 SamplingParams.extra_args 中与每个请求关联的 target_token 自定义参数。

示例自定义 logits processor 定义
import torch
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (BatchUpdate,
                                            LogitsProcessor,
                                            MoveDirectionality)

class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: int | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(f"target_token value {target_token} is not int")

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                is_pin_memory: bool):
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Never impacts greedy sampling"""
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        if not batch_update:
            return

        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            self.validate_params(params)
            if params.extra_args and (target_token :=
                                    params.extra_args.get("target_token")):
                self.req_info[index] = target_token
            else: 
                self.req_info.pop(index, None)

        if self.req_info:
            # Process removed requests.
            for index in batch_update.removed:
                self.req_info.pop(index, None)

            # Process moved requests, unidirectional move (a->b) and swap
            # (a<->b)
            for adx, bdx, direct in batch_update.moved:
                a_val = self.req_info.pop(adx, None)
                b_val = self.req_info.pop(bdx, None)
                if a_val is not None:
                    self.req_info[bdx] = a_val
                if direct == MoveDirectionality.SWAP and b_val is not None:
                    self.req_info[adx] = b_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.req_info:
            return logits

        # Save target values before modification
        cols = torch.tensor(
            list(self.req_info.values()), dtype=torch.long, device=logits.device
        )
        rows = torch.tensor(
            list(self.req_info.keys()), dtype=torch.long, device=logits.device
        )
        values_to_keep = logits[rows, cols].clone()

        # Mask all but target tokens
        logits[rows] = float('-inf')
        logits[rows, cols] = values_to_keep

        return logits

在本文档的其余部分,我们将使用 DummyLogitsProcessor 作为自定义 logits processor 的示例。

DummyLogitsProcessor.update_state() 实现使用 `self.req_info` 字典来维护 batch 请求的“稀疏”表示:只有指定了 target_token 值的请求才会在字典中有一个键。update_state() 根据 Add、Remove 和 Move 操作对持久 batch 的响应,调整存储的请求索引和 target_token 值(分别是 `self.req_info` 中的键和值)。

封装现有的请求级别 Logits Processor

尽管 vLLM 引擎以 batch 粒度应用 logits processor,但有些用户可能希望使用“请求级别”的 logits processor 实现与 vLLM 结合使用——即,一个作用于单个请求的实现。如果您的 logits processor 是为 vLLM 版本 0 开发的,这尤其会如此,当时它需要是 Callable(如 此处 所述),并符合以下类型注解:

RequestLogitsProcessor = Union[

    # (output token ids, logits tensor) -> logits tensor
    Callable[[list[int], Tensor], Tensor],

    # (prompt token ids, output token ids, logits tensor) -> logits tensor
    Callable[[list[int], list[int], Tensor], Tensor],
]

虽然请求级别的 logits processor 在 vLLM 引擎中明确**不支持**,但 vLLM **提供**了一个便捷的方法来封装现有的 Callable 请求级别 logits processor,并创建一个与 vLLM 兼容的 batch 级别 logits processor。Callable 必须符合上述类型注解;如果您的请求级别 logits processor 具有不同的接口,那么为了封装它,您可能需要修改它或实现一个额外的封装层以符合上述接口规范。

您可以通过继承 AdapterLogitsProcessor 来封装请求级别的 logits processor,如下面的示例所示(在此示例中,DummyPerReqLogitsProcessor 是您需要封装的请求级别 logits processor 的一个占位符)。

  • 重写 AdapterLogitsProcessor.validate_params(cls,params) 来验证请求的采样参数。

  • 重写 AdapterLogitsProcessor.is_argmax_invariant(self) 来准确反映您的请求级别 logits processor 是否会影响哪个 token 具有最高值的 logit。

  • 重写 AdapterLogitsProcessor.new_req_logits_processor(self,params) 来从 SamplingParams 实例创建一个新的请求级别 logits processor 实例。

封装请求级别 Logits Processor 的示例
...

from vllm.v1.sample.logits_processor import (
    AdapterLogitsProcessor, # Wrapper base-class
    RequestLogitsProcessor, # Request-level logitsproc type annotation
)

...

# Stand-in for your request-level logits processor:
class DummyPerReqLogitsProcessor:
    """The request-level logits processor masks out all logits except the
    token id identified by `target_token`"""

    def __init__(self, target_token: int) -> None:
        """Specify `target_token`"""
        self.target_token = target_token

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        val_to_keep = logits[self.target_token].item()
        logits[:] = float("-inf")
        logits[self.target_token] = val_to_keep
        return logits

...

# Example of wrapping the request-level logits processor:
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of wrapping a fake request-level logit processor to create a
    batch-level logits processor"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(
                f"target_token value {target_token} is not int"
            )

    def is_argmax_invariant(self) -> bool:
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """This method returns a new request-level logits processor, customized
        to the `target_token` value associated with a particular request.

        Returns None if the logits processor should not be applied to the
        particular request. To use the logits processor the request must have
        a "target_token" custom argument with an integer value.

        Args:
        params: per-request sampling params

        Returns:
        `Callable` request logits processor, or None
        """
        target_token: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is None:
            return None
        return DummyPerReqLogitsProcessor(target_token)

注意

您的 new_req_logits_processor() 重写可以返回 None,以指示不应将封装的 logits processor 应用于当前请求。

一旦您创建了一个自定义子类(例如 WrappedPerReqLogitsProcessor)来封装您的请求级别 logits processor,您就可以通过下一节中描述的任何方法将其传递给 vLLM。

在 vLLM 中加载自定义 Logits Processor 的方法

Logits processor 在初始化时加载。重要的是,加载的 logits processor 集合在 vLLM 引擎完成加载后不能被修改,也不能为单个请求按需加载新的 logits processor。

本节详细介绍了让您的 logits processor 对 vLLM 可见并触发 vLLM 加载您的 logits processor 的各种方法。

方法 1:在初始化时将自定义 Logits Processor 的完全限定类名 (FQCN) 传递给 vLLM

此方法在离线和在线 vLLM 使用场景中都受支持。自定义 logits processor 的 FQCN(形式为 dotted.path.to.module:ClassName)可以作为参数传递给 LLMAsyncLLM Python 构造函数,或者作为 CLI 参数传递给 vllm serve,语法如下:

vllm serve ... --logits_processors <logits processor 1> <logits processor 2> ...

FQCN 的唯一要求是:

  1. Python 的 importlib.import_module() 必须能够解析 FQCN 的点状路径并将其作为模块加载。

  2. FQCN 的类名部分必须能够从加载的模块中导入。

  3. FQCN 指向的对象必须是 LogitsProcessor 的子类。

请参阅下面的示例。

在 Python 中将自定义 logits processor FQCN 传递给 LLM
# Pass in FQCN
llm = LLM(
    model="facebook/opt-125m",
    logits_processors=["your.module.path:DummyLogitsProcessor"],
)
在 Python 中将自定义 logits processor FQCN 传递给 AsyncLLM
# Pass in FQCN
engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              logits_processors=["your.module.path:DummyLogitsProcessor"])
async_llm = AsyncLLM.from_engine_args(engine_args)
通过 CLI 将自定义 logits processor FQCN 传递给 vLLM 服务器
vllm serve facebook/opt-125m --logits_processors your.module.path:DummyLogitsProcessor

方法 2:自动检测安装在您的 Python 环境中作为入口点的自定义 Logits Processors

setuptools 可以使已安装的包成为其他 Python 程序的插件,通过称为“入口点”的元数据片段。

在初始化期间,vLLM 会自动扫描 vllm.logits_processors 入口点组,并加载它找到的所有已安装的 logits processor。

假设您开发了一个包含自定义 logits processor 的 Python 包。您可以通过为每个 logits processor 在您的 logits processor Python 包中添加一个唯一的入口点来将其暴露给 vLLM。下面的示例展示了如何向项目的 pyproject.toml 文件添加一个入口点:

将自定义 logits processor 作为 Python 入口点公开
[project.entry-points."vllm.logits_processors"]
dummy_logits_processor = "your.module.path:DummyLogitsProcessor"

一旦您的包被安装,每当 vLLM 初始化时,您的自定义 logits processor 都将被自动加载。如果您的 logits processor 是通过入口点公开的,您**不需要**在初始化时显式地将自定义 logits processor 传递给 LLMAsyncLLM 构造函数或 vLLM 服务器。

注意

vLLM 将**始终**加载**所有**通过 vllm.logits_processors 分组公开的 logits processor。

方法 3 (仅限离线):将 Python 类对象传递给 vLLM 构造函数

您可以将一个或多个自定义 logits processor 类对象传递给 LLMAsyncLLM 构造函数。此选项非常灵活,因为 logits processor 类可以是 (1) 在与 LLMAsyncLLM 实例化的相同 Python 源文件中本地定义的,或者 (2) 从 Python 包导入的。

在 Python 中将自定义 logits processor 类对象传递给 LLMAsyncLLM
# Import custom logits processor
from some.module import DummyLogitsProcessor

# ...or...

# Define custom logits processor locally
from vllm.v1.sample.logits_processor import LogitsProcessor

class DummyLogitsProcessor(LogitsProcessor):
    # See DummyLogitsProcessor implementation above
    ...

# Pass class object to LLM constructor
llm = LLM(
    model="facebook/opt-125m",
    logits_processors=[DummyLogitsProcessor],
)

# Pass class object to AsyncLLM constructor
engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              logits_processors=[DummyLogitsProcessor])
async_llm = AsyncLLM.from_engine_args(engine_args)

针对请求调用自定义 Logits Processor

自定义 logits processor 的设计决定了是否必须为给定请求启用/禁用 logits processor,以及必须提供哪些参数来配置 logits processor。

下面的示例展示了用户如何向 DummyLogitsProcessor 传递自定义参数 (target_token) 以 (1) 为特定请求启用 logits processor 并 (2) 控制 logits processor 的行为。

vLLM REST API:配置请求的自定义 logits processor
curl https://:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "Qwen/Qwen2.5-1.5B-Instruct",
        ...
        "vllm_xargs": {"target_token": 67}
    }'
OpenAI SDK:配置请求的自定义 logits processor
batch = await client.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    ...,
    extra_body={
        "vllm_xargs": {
            "target_token": 67
        }
    }
)
离线:为 LLM 请求配置自定义 logits processor
outputs_logitproc = llm.generate("your prompt", 
                                 SamplingParams(...,
                                    extra_args={"target_token": 67}))
离线:为 AsyncLLM 请求配置自定义 logits processor
async for out in engine.generate(request_id="your request id",
                                 prompt="your prompt",
                                 sampling_params=SamplingParams(...,
                                    extra_args={"target_token": 67})):

    # Process async request outputs
    ...

编写自定义 Logits Processors 的最佳实践

一旦 vLLM 在初始化期间加载了 logits processor,vLLM 将在每个引擎步中对该 logits processor 调用 update_state()apply()。这两个方法都作用于当前位于 vLLM 持久 batch 中的所有请求。因此,高效实现这些方法非常重要。

  • 考虑到 logits processor 以 batch 粒度运行,请编写高效的 apply()update_state() 实现。

    • 例如,您可能能够使用高效的向量化操作来实现 apply() 或在 update_state() 中更新内部状态向量。
    • 但是,如果您认为某个 logits processor 可能不经常使用,那么使用“稀疏”表示请求状态可能是合适的,即该类可以使用一个字典来表示请求配置,该字典仅存储启用 logits processor 的请求的元数据。
    • 注意: 封装的请求级别 logits processor 不需要实现 apply()update_state();默认的 AdapterLogitsProcessor.update_state() 实现维护请求状态的稀疏表示,其中 new_req_logits_processor() 返回 None 的请求在基类状态字典中不被表示。AdapterLogitsProcessor.apply() 的默认实现将请求级别 logits processor 顺序应用于输入 logits 的每一行,并组装输出 logits 张量。如果此 AdapterLogitsProcessor 默认实现的性能不足,则避免封装您的请求级别 logits processor,而是将其重新实现为具有优化后的 apply()update_state() 实现的 LogitsProcessor 子类,这些实现以 batch 粒度运行。
  • 由 logits processor 作者决定:

    1. 配置 logits processor 对该请求行为的每个请求属性。 您自定义 logits processor 的 update_state() 重写决定了如何将 SamplingParams 字段映射到 logits processor 状态。

      • 注意: 对于封装的请求级别 logits processor,new_req_logits_processor() 决定了如何使用 SamplingParams 字段来初始化请求级别 logits processor 实例。
    2. logits processor 在每个请求基础上启用或不启用的条件。 除非您的目的是让自定义 logits processor 始终作用于所有请求,否则您应该以这样一种方式编写您的 logits processor,即有可能为特定请求禁用 logits processor,例如通过将参数默认为 None 或传入特定的无操作参数值(例如 0.0)。对于禁用 logits processor 的请求,尽量节省计算和内存。

      • 注意: 对于封装的请求级别 logits processor,默认的 AdapterLogitsProcessor.update_state() 实现确保在 new_req_logits_processor() 为该请求返回 None 时禁用请求级别 logits processor。
    3. logits processor 在 batch 级别被短路的条件。 即使您已经定义了在请求级别禁用自定义 logits processor 的方法,也很难将其转化为计算节省,例如,如果您的 update_state()apply() 实现使用了在单个命令中作用于整个持久 batch 的高效向量化实现。例如,您不能仅仅因为一个请求禁用了 logits processor 就跳过 apply() 中的整个向量化操作。为了在没有运行请求使用自定义 logits processor 的边缘情况下节省计算,我们建议将 apply() 设计为在所有请求都禁用 logits processor 时返回未修改的输入张量。同样,请考虑在没有请求启用 logits processor 的情况下是否可以跳过 update_state() 中的步骤。

      • 此外,一个在 update_state() 中节省计算的简单方法是,当 batch_updateNone 时提前退出。

      • 注意: 对于封装的请求级别 logits processor,AdapterLogitsProcessor 基类默认实现了上述优化。

  • 确保 logits processor 的 update_state 方法丢弃已完成请求的信息(例如,被 Add 操作替换或被 Remove 操作处理的请求)。

    • 注意: 对于封装的请求级别 logits processor,AdapterLogitsProcessor 基类默认处理此问题。
  • 如果 logits processor 具有一致的行为,is_argmax_invariant() 可以硬编码为 TrueFalse。但是,argmax 不变性也可以通过编程方式确定(例如,如果您的 logits processor 是用户可定制的,以某种方式影响 logits processor 是否是 argmax 不变的)。因此,is_argmax_invariant() 不是一个类方法。