跳到内容

自定义 Logits 处理器

重要

部分 logits 处理器设计变更仍在进行中,API 可能在不久的将来发生变化。我们希望尽快稳定这部分 API。

“自定义”logits 处理器由 vLLM 用户编写,并在初始化时加载到 vLLM 中,无需修改或重新编译 vLLM 源代码。它与内置的 logits 处理器相对。

本文档介绍了如何编写、加载和使用自定义 logits 处理器。

Logits 处理器背景

Logits 处理器用于调整下一个 token 的概率分布,通常目的是引导模型实现期望的行为。

在 vLLM 中,logits 处理器以批次(batch)为粒度进行操作。在引擎的每个步骤中,logits 处理器会消耗模型输出的 (num_requests) x (vocab_size) 原始 logits 张量。对于所有启用了该 logits 处理器的请求,处理器会对 logits 张量的相应行进行转换,而不修改其他行。转换后的 logits 张量随后被传递给 softmax 函数。

创建自定义 Logits 处理器

自定义 logits 处理器必须继承 vllm.v1.sample.logits_processor.LogitsProcessor 并(至少)定义以下方法:

  • validate_params(cls, sampling_params: SamplingParams):

    • 如果 SamplingParams 包含 logits 处理器使用的非法参数(特别是自定义参数),则抛出 ValueError
    • 当请求发送到入口点时,validate_params() 将验证 SamplingParams 并拒绝带有非法参数的请求。
    • 注意:实现 validate_params() 以防止自定义 logits 处理器接收非法参数非常重要。否则,带有非法参数的请求可能会导致自定义 logits 处理器出现意外行为。
  • __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)

    • vllm_config:引擎配置数据结构
    • device:硬件加速器设备信息
    • is_pin_memory:指示是否可以使用锁页内存(pin memory)来支持 logits 处理器实现的标志
  • apply(self, logits: torch.Tensor) -> torch.Tensor:

    • 消耗一个 (num_requests) x (vocab_size) 的 logits 张量 (logits)
    • 以批次粒度应用 logits 处理器转换
    • 返回一个转换后的 (num_requests) x (vocab_size) logits 张量
    • 您可以原地(in-place)或非原地修改输入 logits;原地修改更节省内存
  • is_argmax_invariant(self) -> bool:

    • 如果 logits 处理器是 argmax 不变的(即对于给定请求,绝不会改变概率最高的 token ID),则返回 True;如果 logits 处理器可能会修改 argmax,则返回 False
    • is_argmax_invariant() 在启动时评估一次;如果为 True,当所有请求都使用贪婪采样时,vLLM 将跳过在给定步骤中应用此 logits 处理器
  • update_state(self, batch_update: Optional["BatchUpdate"]) -> None:

    • 在当前引擎步骤开始时,消耗一个表示持续批次状态变更的 BatchUpdate 数据结构
    • 使用 BatchUpdate 成员来更新 logits 处理器的内部状态
    • 注意:批次更新数据结构可能为 None,表示批次组成没有变化。在这种情况下,LogitsProcessor 可能仍希望根据其保留的更新后的 output_token_ids 列表来更新其状态。

vLLM 引擎如何构建 BatchUpdate 数据结构

重要

一些 logits 处理器的设计变更仍在进行中。我们预计未来您无需在实现 logits 处理器时考虑批次状态变更,本节信息将变得无关紧要。

Logits 处理器 update_state() 的实现应假定模型运行器更新持续批次状态的模型如下(此处用 BatchUpdate 抽象表示):

  1. 识别当前引擎步骤中已完成的请求索引

  2. 识别当前步骤中引入的新请求

  3. 使用添加(Add)操作替换尽可能多的已完成请求,按被替换请求的索引从小到大的顺序进行

  4. 基于新请求和已完成请求的相对数量:

    1. 如果新请求和已完成请求的数量相同,则进入下一步

    2. 如果新请求多于已完成请求:应用添加操作,使用未替换已完成请求的剩余新请求扩展批次。为这些新请求分配连续索引,从 current_max_batch_index + 1 开始

    3. 如果新请求少于已完成请求:

      • 对未被新请求替换的已完成请求应用移除(Remove)操作。这些被移除的请求索引必然大于被替换的已完成请求的最大索引。移除操作可能导致批次处于非连续状态

      • “压缩”批次使其连续:从索引最低的空槽(由移除操作导致)开始,应用单向移动(Unidirectional Move),将批次中当前最高的非空槽填入空槽。按空槽目标索引递增和非空槽源索引递减的顺序继续进行额外的单向移动操作,直到批次连续

      • 收缩批次:压缩批次的副作用是,由移除操作产生的空槽被归并在批次数组末尾的一个连续块中。因此,在压缩后,更新 BatchUpdate.batch_size 以反映非空槽的数量

  5. 为了提高效率,对批次进行重排序。根据注意力后端实现和批次的当前特征,可能会应用零次或多次交换移动(Swap Move)操作来重排序批次

注意事项

  • Logits 处理器 update_state() 方法必须按以下顺序处理批次更新操作:移除、添加、移动

  • 添加操作的索引参数是指执行添加时的索引,即在任何移动操作之前

    • 示例:如果请求在索引 5 被添加,然后与索引 3 交换,则 BatchUpdate.added 中的添加操作将与索引 5 相关联,而不是 3
    • 换句话说,可以假定移动操作是在添加和移除之后执行的
  • 可以假定移动操作按其在 BatchUpdate.moved 中出现的顺序执行

  • 如果没有新请求/已完成请求且没有批次重排序,则 logits 处理器的批次更新将为 None

向自定义 Logits 处理器传递自定义参数

与内置的 logits 处理器不同,自定义 logits 处理器可能需要非硬编码在 SamplingParams 或 vLLM 服务器 REST API 中的配置参数。为了解决这个问题,自定义 logits 处理器可以利用 vLLM 自定义参数支持来接收用户的配置设置(尽管您也可以自由设计一个利用 SamplingParams 中已有字段的自定义 logits 处理器。)

自定义 Logits 处理器实现示例

下面的人为示例实现了一个自定义 logits 处理器,它消耗一个 (num\_requests) \times (vocab\_size) 的 logits 张量,并使用 float(-inf) 屏蔽掉除一个 (target_token) 之外的所有 token。该 logits 处理器对于未指定 target_token 的任何请求将被禁用。为了确定 logits 处理器是否启用以及留下哪个 token 不被屏蔽,logits 处理器会检查每个请求关联的 SamplingParams.extra_args 中的 target_token 自定义参数。

自定义 logits 处理器定义示例
import torch
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (BatchUpdate,
                                            LogitsProcessor,
                                            MoveDirectionality)

class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: int | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(f"target_token value {target_token} is not int")

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                is_pin_memory: bool):
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Never impacts greedy sampling"""
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        if not batch_update:
            return

        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            self.validate_params(params)
            if params.extra_args and (target_token :=
                                    params.extra_args.get("target_token")):
                self.req_info[index] = target_token
            else: 
                self.req_info.pop(index, None)

        if self.req_info:
            # Process removed requests.
            for index in batch_update.removed:
                self.req_info.pop(index, None)

            # Process moved requests, unidirectional move (a->b) and swap
            # (a<->b)
            for adx, bdx, direct in batch_update.moved:
                a_val = self.req_info.pop(adx, None)
                b_val = self.req_info.pop(bdx, None)
                if a_val is not None:
                    self.req_info[bdx] = a_val
                if direct == MoveDirectionality.SWAP and b_val is not None:
                    self.req_info[adx] = b_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.req_info:
            return logits

        # Save target values before modification
        cols = torch.tensor(
            list(self.req_info.values()), dtype=torch.long, device=logits.device
        )
        rows = torch.tensor(
            list(self.req_info.keys()), dtype=torch.long, device=logits.device
        )
        values_to_keep = logits[rows, cols].clone()

        # Mask all but target tokens
        logits[rows] = float('-inf')
        logits[rows, cols] = values_to_keep

        return logits

在本文档的其余部分,我们将使用 DummyLogitsProcessor 作为自定义 logits 处理器的示例。

DummyLogitsProcessor.update_state() 的实现会在 self.req_info 字典中维护批次请求的“稀疏”表示:只有指定了 target_token 值的请求才会出现在字典中。update_state() 会根据针对持续批次的添加、移除和移动操作,调整存储的请求索引和 target_token 值(分别是 self.req_info 中的键和值)。

封装现有的请求级 Logits 处理器

尽管 vLLM 引擎以批次粒度应用 logits 处理器,但有些用户可能希望在 vLLM 中使用“请求级”的 logits 处理器实现——即对单个请求进行操作的实现。如果您的 logits 处理器是为 vLLM 版本 0 开发的(需要它是 Callable,如 此处 所述),这将特别适用,该版本符合以下类型注解:

RequestLogitsProcessor = Union[

    # (output token ids, logits tensor) -> logits tensor
    Callable[[list[int], Tensor], Tensor],

    # (prompt token ids, output token ids, logits tensor) -> logits tensor
    Callable[[list[int], list[int], Tensor], Tensor],
]

虽然 vLLM 引擎明确支持请求级 logits 处理器,但 vLLM 确实提供了一种方便的过程来封装现有的 Callable 请求级 logits 处理器,并创建一个与 vLLM 兼容的批次级 logits 处理器。该 Callable 必须符合上述类型注解;如果您的请求级 logits 处理器具有不同的接口,为了封装它,您可能需要对其进行修改或实现额外的封装层以符合上述接口规范。

您可以通过继承 AdapterLogitsProcessor 来封装请求级 logits 处理器,如下例所示(在此示例中,DummyPerReqLogitsProcessor 是您需要封装的请求级 logits 处理器替代品)。

  • 重写 AdapterLogitsProcessor.validate_params(cls,params) 以验证请求的采样参数。

  • 重写 AdapterLogitsProcessor.is_argmax_invariant(self) 以准确反映您的请求级 logits 处理器是否会影响哪个 token 具有最高值的 logit。

  • 重写 AdapterLogitsProcessor.new_req_logits_processor(self,params) 以从 SamplingParams 实例创建一个新的请求级 logits 处理器实例。

封装请求级 Logits 处理器的示例
...

from vllm.v1.sample.logits_processor import (
    AdapterLogitsProcessor, # Wrapper base-class
    RequestLogitsProcessor, # Request-level logitsproc type annotation
)

...

# Stand-in for your request-level logits processor:
class DummyPerReqLogitsProcessor:
    """The request-level logits processor masks out all logits except the
    token id identified by `target_token`"""

    def __init__(self, target_token: int) -> None:
        """Specify `target_token`"""
        self.target_token = target_token

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        val_to_keep = logits[self.target_token].item()
        logits[:] = float("-inf")
        logits[self.target_token] = val_to_keep
        return logits

...

# Example of wrapping the request-level logits processor:
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of wrapping a fake request-level logit processor to create a
    batch-level logits processor"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(
                f"target_token value {target_token} is not int"
            )

    def is_argmax_invariant(self) -> bool:
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """This method returns a new request-level logits processor, customized
        to the `target_token` value associated with a particular request.

        Returns None if the logits processor should not be applied to the
        particular request. To use the logits processor the request must have
        a "target_token" custom argument with an integer value.

        Args:
        params: per-request sampling params

        Returns:
        `Callable` request logits processor, or None
        """
        target_token: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is None:
            return None
        return DummyPerReqLogitsProcessor(target_token)

注意

您的 new_req_logits_processor() 重写可以返回 None,以指示不应对相关请求应用封装的 logits 处理器。

一旦您创建了一个自定义子类(例如 WrappedPerReqLogitsProcessor)来封装您的请求级 logits 处理器,您就可以通过下一节中描述的任何方法将该自定义子类传递给 vLLM。

在 vLLM 中加载自定义 Logits 处理器的方法

Logits 处理器在初始化时加载。关键点在于,加载的 logits 处理器集在 vLLM 引擎完成加载后无法修改,并且无法为单个请求按需加载新的 logits 处理器。

本节详细介绍了让 vLLM 识别并触发 vLLM 加载您的 logits 处理器的不同方法。

方法 1:在初始化时将自定义 Logits 处理器的全限定类名 (FQCN) 传递给 vLLM

此方法在离线和在线 vLLM 使用场景中均受支持。自定义 logits 处理器的 FQCN(格式为 dotted.path.to.module:ClassName)可以作为参数传递给 LLMAsyncLLM Python 构造函数,或者作为 CLI 参数传递给 vllm serve(语法如下):

vllm serve ... --logits_processors <logits processor 1> <logits processor 2> ...

对 FQCN 的唯一要求是:

  1. Python 的 importlib.import_module() 必须能够解析 FQCN 的点路径部分并将其加载为模块

  2. 类名部分必须能够从加载的模块中导入

  3. FQCN 指向的对象必须是 LogitsProcessor 的子类

请参阅以下示例

在 Python 中将自定义 logits 处理器 FQCN 传递给 LLM
# Pass in FQCN
llm = LLM(
    model="facebook/opt-125m",
    logits_processors=["your.module.path:DummyLogitsProcessor"],
)
在 Python 中将自定义 logits 处理器 FQCN 传递给 AsyncLLM
# Pass in FQCN
engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              logits_processors=["your.module.path:DummyLogitsProcessor"])
async_llm = AsyncLLM.from_engine_args(engine_args)
通过 CLI 将自定义 logits 处理器 FQCN 传递给 vLLM 服务器
vllm serve facebook/opt-125m --logits_processors your.module.path:DummyLogitsProcessor

方法 2:自动检测安装在 Python 环境中作为入口点的自定义 Logits 处理器

setuptools 可以使已安装的包通过称为“入口点(entry points)”的元数据片段,将自身作为插件提供给其他 Python 程序。

在初始化期间,vLLM 会自动扫描 vllm.logits_processors 入口点组,并加载其找到的任何已安装的 logits 处理器。

假设您开发了一个包含自定义 logits 处理器的 Python 包。您可以为每个 logits 处理器添加一个唯一的入口点,从而将每个 logits 处理器暴露给 vLLM。下面的示例展示了如何向项目的 pyproject.toml 文件中添加入口点:

将自定义 logits 处理器暴露为 Python 入口点
[project.entry-points."vllm.logits_processors"]
dummy_logits_processor = "your.module.path:DummyLogitsProcessor"

一旦安装了您的包,每当初始化 vLLM 时,您的自定义 logits 处理器都会自动加载。如果您的 logits 处理器以入口点方式暴露,您无需在初始化时显式地将其传递给 LLMAsyncLLM 构造函数或 vLLM 服务器。

注意

vLLM 将始终加载所有在 vllm.logits_processors 分组下通过入口点暴露的 logits 处理器。

方法 3(仅限离线模式):将 Python 类对象传递给 vLLM 构造函数

您可以将一个或多个自定义 logits 处理器类对象传递给 LLMAsyncLLM 构造函数。此选项非常灵活,因为 logits 处理器类既可以 (1) 在实例化 LLMAsyncLLM 的同一个 Python 源文件中本地定义,也可以 (2) 从 Python 包中导入。

在 Python 中将自定义 logits 处理器类对象传递给 LLMAsyncLLM
# Import custom logits processor
from some.module import DummyLogitsProcessor

# ...or...

# Define custom logits processor locally
from vllm.v1.sample.logits_processor import LogitsProcessor

class DummyLogitsProcessor(LogitsProcessor):
    # See DummyLogitsProcessor implementation above
    ...

# Pass class object to LLM constructor
llm = LLM(
    model="facebook/opt-125m",
    logits_processors=[DummyLogitsProcessor],
)

# Pass class object to AsyncLLM constructor
engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              logits_processors=[DummyLogitsProcessor])
async_llm = AsyncLLM.from_engine_args(engine_args)

针对请求调用自定义 Logits 处理器

自定义 logits 处理器的设计决定了是否必须为给定请求启用/禁用该 logits 处理器,以及必须提供哪些参数来配置它。

下面的示例展示了用户如何向 DummyLogitsProcessor 传递自定义参数 (target_token),以便 (1) 为该特定请求启用 logits 处理器,以及 (2) 控制 logits 处理器的行为。

vLLM REST API:为请求配置自定义 logits 处理器
curl https://:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "Qwen/Qwen2.5-1.5B-Instruct",
        ...
        "vllm_xargs": {"target_token": 67}
    }'
OpenAI SDK:为请求配置自定义 logits 处理器
batch = await client.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    ...,
    extra_body={
        "vllm_xargs": {
            "target_token": 67
        }
    }
)
离线模式:为 LLM 请求配置自定义 logits 处理器
outputs_logitproc = llm.generate("your prompt", 
                                 SamplingParams(...,
                                    extra_args={"target_token": 67}))
离线模式:为 AsyncLLM 请求配置自定义 logits 处理器
async for out in engine.generate(request_id="your request id",
                                 prompt="your prompt",
                                 sampling_params=SamplingParams(...,
                                    extra_args={"target_token": 67})):

    # Process async request outputs
    ...

编写自定义 Logits 处理器的最佳实践

一旦 vLLM 在初始化期间加载了 logits 处理器,它就会在每个引擎步骤中针对该处理器调用 update_state()apply()。这两个方法都作用于当前驻留在 vLLM 持续批次中的所有请求。因此,高效地实现这些方法非常重要。

  • 鉴于 logits 处理器是以批次粒度操作的,请编写高效的 apply()update_state() 实现

    • 例如,您可以使用高效的向量化操作来实现 apply(),或在 update_state() 中更新内部状态向量
    • 但是,如果您认为 logits 处理器可能很少被使用,那么使用请求状态的“稀疏”表示可能是合适的,即该类可以使用一个只存储启用了该 logits 处理器的请求元数据的字典来表示请求配置
    • 注意:封装的请求级 logits 处理器无需实现 apply()update_state();默认的 AdapterLogitsProcessor.update_state() 实现维护了请求状态的稀疏表示,其中 new_req_logits_processor() 返回 None 的请求不会在基类状态字典中表示。默认的 AdapterLogitsProcessor.apply() 实现会将请求级 logits 处理器依次应用于每一行输入 logits,并组装输出的 logits 张量。如果 AdapterLogitsProcessor 默认实现的性能不足,请避免封装您的请求级 logits 处理器,而是将其重新实现为 LogitsProcessor 子类,并针对以批次粒度运行的操作实现优化的 apply()update_state()
  • Logits 处理器作者需要自行确定:

    1. 用于配置 logits 处理器针对该请求行为的每请求属性。您的自定义 logits 处理器的 update_state() 重写决定了 SamplingParams 字段如何映射到 logits 处理器状态

      • 注意:对于封装的请求级 logits 处理器,new_req_logits_processor() 决定了 SamplingParams 字段如何用于初始化请求级 logits 处理器实例。
    2. 基于每请求启用或禁用 logits 处理器的条件。除非您的目的是让自定义 logits 处理器一直对所有请求生效,否则您编写的 logits 处理器应能够为给定请求禁用它,例如将参数默认设为 None 或传入特定的无操作参数值(即 0.0)。尽量为禁用 logits 处理器的请求节省计算和内存

      • 注意:对于封装的逐请求 logits 处理器,默认的 AdapterLogitsProcessor.update_state() 实现确保了当 new_req_logits_processor() 为该请求返回 None 时,请求级 logits 处理器会被禁用
    3. 在批次级别短路 logits 处理器的条件。即使您定义了在请求级别禁用自定义 logits 处理器的办法,将其转化为计算节省也可能很困难,例如如果您的 update_state()apply() 实现使用了在单个命令中对整个持续批次操作的高效向量化实现。例如,您不能仅仅因为一个请求禁用了 logits 处理器就跳过 apply() 中的整个向量化操作。为了在没有运行中的请求使用自定义 logits 处理器的边缘情况下节省计算,我们建议将 apply() 设计为在所有请求都禁用了 logits 处理器时返回未经修改的输入张量。同样,请考虑如果没有任何请求启用该 logits 处理器,是否可以在 update_state() 中跳过某些步骤

      • 此外,在 batch_updateNone 时尽早退出是节省 update_state() 计算的简单方法

      • 注意:对于封装的逐请求 logits 处理器,AdapterLogitsProcessor 基类默认实现了上述优化

  • 确保 logits 处理器的 update_state 方法丢弃关于已完成请求的信息(即被添加替换或被移除操作影响的请求)

    • 注意:对于封装的逐请求 logits 处理器,AdapterLogitsProcessor 基类默认处理此情况
  • 如果 logits 处理器行为一致,is_argmax_invariant() 可以硬编码为 TrueFalse。但是,argmax 不变性也可以以编程方式确定(即,如果您的 logits 处理器以某种影响其是否 argmax 不变的方式可由用户自定义)。由于这个原因,is_argmax_invariant() 不是类方法