Logits 处理器¶
重要
部分 logits 处理器设计变更仍在进行中,API 在不久的将来可能会发生变化。我们希望尽快稳定这部分 API。
本文档介绍了 vLLM 引擎如何与 logits 处理器交互,以及 vLLM 支持的用于实现 logits 处理器的编程模型。
Logits 处理器背景¶
logits 处理器用于调整下一个 token 的概率分布,通常目的是引导模型实现期望的行为。
在 vLLM 中,logits 处理器以批处理(batch)粒度运行。在给定的引擎步骤中,logits 处理器会消耗模型输出的 (num_requests) x (vocab_size) 原始 logits 张量。对于所有启用了 logits 处理器的请求,处理器会对相应的 logits 张量行进行转换,而保持其他行不变。转换后的 logits 张量随后被送入 softmax。
vLLM 引擎中的 Logits 处理器¶
vLLM 引擎的持久化批处理数据结构维护了一个已加载的 logits 处理器列表。
为了能够一次性处理整个批次,每个 logits 处理器可能会维护有关批中请求的元数据(即每个请求特定的 logits 处理器配置设置)。因此,logits 处理器是有状态的。
在每个引擎步骤中,vLLM 引擎将 (1) 更新每个 logits 处理器的内部状态,以及 (2) 将 logits 处理器应用于模型输出的 logits。
更新 Logits 处理器内部状态¶
在每个引擎步骤开始时,持久化批处理会根据调度器的输出添加、丢弃和/或重新排序请求。在持久化批处理重组后,vLLM 引擎会调用每个 logits 处理器的 update_state() 方法。这对于确保 logits 处理器的内部状态在引擎步骤开始时与新的持久化批处理状态相匹配是必要的。
下方的伪代码展示了 vLLM 持久化批处理通知每个 logits 处理器批状态变更的过程
模型运行器更新 Logits 处理器状态
# gpu_model_runner.py
class GPUModelRunner(...):
...
def execute_model(self, scheduler_output, ...):
self._update_states(scheduler_output)
...
def _update_states(...):
...
# ...update persistent batch to reflect new/finished requests & reordering
# of requests within batch...
...
self.input_batch.refresh_metadata()
# gpu_input_batch.py
class InputBatch:
...
def refresh_metadata(self):
...
# Update each logits processor's state to reflect persistent batch state
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
...
# vllm/v1/sample/logits_processor/interface.py
@dataclass(frozen=True)
class BatchUpdate:
# Batch state-change data structure which is passed to logits processors'
# update_state() methods
batch_size: int
removed: Sequence[RemovedRequest]
added: Sequence[AddedRequest]
moved: Sequence[MovedRequest]
将 Logits 处理器应用于模型输出 Logits¶
更新持久化批处理状态后,vLLM 模型运行器执行模型推理以获取 logits。然后,模型运行器针对这些 logits 调用采样器(sampler)。采样器操作的一部分是针对模型输出的 logits 调用 logits 处理器的 apply() 方法,从而产生转换后的 logits(apply() 方法可以就地或非就地修改 logits,尽管就地修改更节省内存)。此过程见下方伪代码。
请注意,采样器将通过 SamplingMetadata.logitsprocs 访问 logits 处理器。当 vLLM 引擎构建 SamplingMetadata(未在下方代码中显示)时,对 logits 处理器列表的引用会从持久化批处理数据结构传递到 SamplingMetadata。
将 logits 处理器应用于模型输出的 logits
# gpu_model_runner.py
class GPUModelRunner(...):
...
def execute_model(self, scheduler_output, ...):
# (discussed in previous section)
self._update_states(scheduler_output)
...
# ...run model inference to obtain logits...
...
# Invoke sampler, which applies logits processors
sampler_output = self.sampler(logits=logits,
sampling_metadata=sampling_metadata)
...
# sampler.py
class Sampler(nn.Module):
...
def forward(self, logits, sampling_metadata):
...
# Apply non-argmax-invariant logits processors to model output logits
for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
logits = processor.apply(logits)
sampled = self.sample(logits, sampling_metadata)
...
# ...return sampler output data structure...
def sample(self, logits, sampling_metadata)
...
# ...exit early if all requests are greedy-sampling...
...
# Apply argmax-invariant logits processors
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
...
# ...perform sampling and return sampling result...
在采样时,采样器会检查持久化批处理中的所有请求是否都使用贪婪采样(greedy sampling)。如果是这种情况,采样器可以通过跳过“argmax 不变”的 logits 处理器来节省计算资源。此处,“argmax”是给定 logits 张量行中 logit 值最高的 token ID 的缩写(即模型为给定请求加权最高的 token)。
-
argmax 不变 logits 处理器是指一种不修改 argmax 的 logits 处理器(例如 Min-P)。例如,屏蔽掉概率最低的 token 的 logits 处理器不会改变具有最大 logit 值的 token ID。贪婪采样总是选择 logit 值最高的 token ID,因此从概念上讲,对于贪婪采样请求,可以跳过 argmax 不变的 logits 处理器。
-
非 argmax 不变 logits 处理器是指可能修改 argmax 的 logits 处理器。例如,为了强制终止解码,在一定步骤后屏蔽除 EOS 之外的所有 token 的 logits 处理器,可能会最终屏蔽掉最大 logit 值的 token,从而改变 argmax。从概念上讲,这些 logits 处理器不能被贪婪采样请求跳过。
vLLM 的 logits 处理器抽象要求引擎以批处理粒度应用 logits 处理器;因此,实际上只有在整个批处理都使用贪婪采样时,才能跳过 argmax 不变的 logits 处理器。
Logits 处理器编程模型¶
前几节提到了 vLLM logits 处理器必须支持的接口。本节完整介绍了实现与 vLLM 引擎兼容的 logits 处理器的编程模型,包括 LogitsProcessor 基类及其接口方法,以及用于表示持久化批状态变更的 BatchUpdate 数据结构,这两者均显示在下面的代码中
LogitsProcessor 基类和 BatchUpdate 数据结构
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING
import torch
from vllm import SamplingParams
if TYPE_CHECKING:
from vllm.config import VllmConfig
class MoveDirectionality(Enum):
# One-way i1->i2 req move within batch
UNIDIRECTIONAL = auto()
# Two-way i1<->i2 req swap within batch
SWAP = auto()
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int
@dataclass(frozen=True)
class BatchUpdate:
"""Persistent batch state change info for logitsprocs"""
batch_size: int # Current num reqs in batch
# Metadata for requests added to, removed from, and moved
# within the persistent batch.
#
# Key assumption: the `output_tok_ids` list (which is an element of each
# tuple in `added`) is a reference to the request's running output tokens
# list; via this reference, the logits processors always see the latest
# list of generated output tokens
removed: Sequence[RemovedRequest]
moved: Sequence[MovedRequest]
added: Sequence[AddedRequest]
class LogitsProcessor(ABC):
@abstractmethod
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool) -> None:
raise NotImplementedError
@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def is_argmax_invariant(self) -> bool:
"""True if logits processor has no impact on the
argmax computation in greedy sampling.
NOTE: may or may not have the same value for all
instances of a given LogitsProcessor subclass,
depending on subclass implementation.
"""
raise NotImplementedError
@abstractmethod
def update_state(
self,
batch_update: "BatchUpdate" | None,
) -> None:
"""Called when there are new output tokens, prior
to each forward pass.
Args:
batch_update is non-None iff there have been
changes to the batch makeup.
"""
raise NotImplementedError
@classmethod
def validate_params(cls, sampling_params: SamplingParams):
"""Validate sampling params for this logits processor.
Raise ValueError for invalid ones.
"""
return None
vLLM logits 处理器必须继承 LogitsProcessor 并定义(至少)以下方法
-
__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)vllm_config: 引擎配置数据结构device: 硬件加速器设备信息is_pin_memory: 指示是否可以使用锁页内存(pin memory)来支持 logits 处理器实现的标志
-
apply(self, logits: torch.Tensor) -> torch.Tensor:- 消耗一个
(num_requests) x (vocab_size)的 logits 张量 (logits) - 以批处理粒度应用 logits 处理器转换
- 返回转换后的
(num_requests) x (vocab_size)logits 张量 - 您可以就地或非就地修改输入 logits;就地修改更节省内存
- 消耗一个
-
is_argmax_invariant(self) -> bool:- 如果 logits 处理器是 argmax 不变的(永远不会改变给定请求中具有最高 logit 值的 token ID),则返回
True;如果可能修改 argmax,则返回False is_argmax_invariant()在启动时仅评估一次;如果为True,当所有请求都使用贪婪采样时,vLLM 将在给定步骤中跳过应用此 logits 处理器
- 如果 logits 处理器是 argmax 不变的(永远不会改变给定请求中具有最高 logit 值的 token ID),则返回
-
update_state(self, batch_update: "BatchUpdate" | None) -> None:- 消耗一个表示当前引擎步骤开始时持久化批处理状态变更的
BatchUpdate数据结构 - 使用
BatchUpdate成员更新 logits 处理器内部状态 - 注意:批量更新数据结构可能为
None,表示批处理组成没有变化。在这种情况下,LogitsProcessor 可能仍希望根据其添加时保留的已更新output_token_ids列表来更新其状态。
- 消耗一个表示当前引擎步骤开始时持久化批处理状态变更的
-
validate_params(cls, sampling_params: SamplingParams):- 如果
SamplingParams具有 logits 处理器使用的无效参数(尤其是自定义参数),则引发ValueError。 - 当请求发送到入口点时,
validate_params()将验证SamplingParams并拒绝具有无效参数的请求。
- 如果
BatchUpdate 数据结构¶
BatchUpdate 抽象将持久化批处理建模为请求列表,支持以下操作来更改批处理状态(请注意,下面提到的操作顺序反映了它们在 update_state() 中应处理的顺序)
-
移除 (Remove):移除索引为
i的请求(不进行替换)-
移除操作在
Batchupdate.removed中由int(表示i)表示 -
按索引移除对批处理的影响
-
-
添加 (Add):在索引
i处添加(或替换现有)新请求。如果请求被替换,其关联状态应被丢弃。-
添加操作在
Batchupdate.added中表示为包含以下内容的元组 -
prompt token ids和output token ids分别是对请求的 prompt token id 和 output token id 列表的引用。请注意,output token id 列表随着每个引擎步骤而增长,且此增长对 logits 处理器是可见的,因为 output token id 是按引用传递的。这对于考虑迄今为止已生成 token 的 LogitsProcessor 来说非常重要。 -
特定的 logits 处理器子类的实现决定了添加请求元组中的字段如何被消化为内部表示。例如,不使用 prompt 或 output token id 的 logits 处理器可能只需要使用
index和SamplingParams,而丢弃其他元组字段 -
如果索引
i当前持有一个请求,则会发生替换 -
如果索引
i当前未持有一个请求(因为i超出了当前批处理大小的范围)
-
-
移动 (Move):将索引
s处的请求移动到索引d,或者交换索引s和d处的请求-
移动操作在
Batchupdate.moved中表示为包含以下内容的元组 -
如果移动指定为
UNIDIRECTIONAL(单向)-
索引
s处的请求被移动到索引d;索引s变成空槽 -
如果索引
d处已经存在另一个请求,它将被替换并丢弃
-
-
如果移动指定为
SWAP(交换),索引s和d处的请求交换索引
-
此外,BatchUpdate 数据结构包含引擎步骤开始时持久化批处理大小的表示(batch_size)。
vLLM 引擎如何构建 BatchUpdate 数据结构¶
Logits 处理器 update_state() 的实现应假定模型运行器更新持久化批处理状态的以下模型(此处以 BatchUpdate 抽象表示)
-
识别在当前引擎步骤中完成的请求的索引
-
识别当前步骤中引入的新请求
-
使用“添加”操作替换尽可能多的已完成请求,按被替换请求的索引从小到大顺序进行
-
根据新请求和完成请求的相对数量
-
如果新请求和完成请求的数量相同,则进入下一步
-
如果新请求多于完成请求:应用“添加”操作,用未替换已完成请求的剩余新请求来扩展批处理。为这些新请求分配连续的索引,从
current_max_batch_index + 1开始 -
如果新请求少于完成请求
-
对未被新请求替换的已完成请求应用“移除”操作。这些已移除请求的索引必然大于前一步骤中被替换的已完成请求的最大索引。移除操作可能会使批处理处于非连续状态
-
“压缩”批处理使其连续:从最低索引的空槽(由移除操作引起)开始,应用单向移动,将批处理中当前最高的非空槽填入该空槽。按照空槽目标索引递增和非空槽源索引递减的顺序,继续进行额外的单向移动操作,直到批处理连续
-
缩小批处理:压缩批处理的一个副作用是,由于移除操作产生的空槽被聚集在批处理数组末尾的一个连续块中。因此,在压缩后,更新
BatchUpdate.batch_size以反映非空槽的数量
-
-
-
重新排序批处理以提高效率。根据注意力后端实现和批处理的当前特征,可能会应用零个或多个交换移动操作来重新排序批处理
注意事项
-
Logits 处理器
update_state()方法必须按以下顺序处理批处理更新操作:移除、添加、移动 -
“添加”操作的索引参数是指发生添加时的索引,即在任何移动操作之前
- 示例:如果一个请求在索引 5 处被添加,然后与索引 3 交换,则
BatchUpdate.added中的添加操作将与索引 5 相关联,而不是索引 3 - 换句话说,移动操作可以假定在添加和移除之后执行
- 示例:如果一个请求在索引 5 处被添加,然后与索引 3 交换,则
-
移动操作可以假定按照它们在
BatchUpdate.moved中出现的顺序执行 -
如果没有新/完成请求且没有批处理重排序,则 logits 处理器的批处理更新将为
None
示例:新请求少于完成请求时的批量更新¶
以下示例模拟了一个引擎步骤,其中引入了 1 个新请求并消除了 2 个已完成请求,此外注意力后端执行了交换以优化批处理顺序。
Batch state (beginning of engine step): [A,B,C,D]
Batch size: 4
New requests: E
Finished requests: A, C
Processing steps (using BatchUpdate abstraction):
1. Add E at index 0
[E,B,C,D] # Discard A
Batch size: 4
2. Remove at index 2
[E,B,x,D] # Discard C, empty slot at index 2
Batch size: 4
3. Condense batch with a Unidirectional Move 3 -> 2 operation and shrink batch
[E,B,D] x # Empty slot is now outside batch
Batch size: 3
4. Attention backend optimization: reorder batch with Swap 0 <-> 1
[B,E,D]
Batch size: 3
生成的 BatchUpdate 数据结构将如下所示
BatchUpdate instance
* added: [(0,E's SamplingParams,E's prompt tokens ref,E's output tokens ref)]
* removed: [2] # request C was removed without replacement
* moved: [(3,2,UNIDIRECTIONAL),(0,1,SWAP)]
示例:新请求多于完成请求时的批量更新¶
以下示例模拟了一个引擎步骤,其中引入了 2 个新请求并消除了 1 个已完成请求,此外注意力后端执行了交换以优化批处理顺序。
Batch state (beginning of engine step): [A,B,C,D]
Batch size: 4
New requests: E,F
Finished requests: C
Processing steps (using BatchUpdate abstraction):
1. Add E at index 2
[A,B,E,D] # Discard C
Batch size: 4
2. Add F at index 4 (current max batch index + 1)
[A,B,E,D,F] # Extend batch by 1
Batch size: 5
4. Attention backend optimization: reorder batch with Swap 0 <-> 1
[B,A,E,D,F]
Batch size: 5
请注意,由于移除操作没有留下空槽,因此跳过了批处理压缩。
生成的 BatchUpdate 数据结构将如下所示
BatchUpdate instance
* added: [(2,E's SamplingParams,E's prompt tokens ref,E's output tokens ref),(4,F's SamplingParams,F's prompt tokens ref,F's output tokens ref)]
* removed: [] # no requests were removed without replacement
* moved: [(0,1,SWAP)]
如何向 vLLM 引入新的 Logits 处理器¶
编写内置 Logits 处理器的最佳实践¶
-
鉴于 logits 处理器以批处理粒度运行,请编写高效的
apply()和update_state()实现- 例如,您可以使用高效的向量化操作来实现
apply(),或在update_state()中更新内部状态向量 - 但是,如果您认为 logits 处理器可能很少使用,那么使用请求状态的“稀疏”表示可能是合适的,即该类可以使用仅存储启用该 logits 处理器的请求元数据的字典来表示请求配置
- 例如,您可以使用高效的向量化操作来实现
-
由 logits 处理器作者决定
-
配置 logits 处理器对该请求行为的按请求属性。 例如,如果您正在为 vLLM 编写新的内置 logits 处理器,您可能需要也可能不需要向
SamplingParams和 vLLM REST API 添加其他字段 -
在按请求基础上启用或禁用 logits 处理器的条件。 除非您打算让内置 logits 处理器始终作用于所有请求,否则您应该以一种可以为特定请求禁用 logits 处理器的方式编写处理器,例如通过将参数默认为
None或传入特定的“无操作”参数值(例如0.0)。尝试为禁用该 logits 处理器的请求节省计算和内存资源 -
在批处理级别短路 logits 处理器的条件。 即使您已经定义了一种在请求级别禁用内置 logits 处理器的方法,也很难将其转化为计算节省,即如果您的
update_state()和apply()实现使用在单个命令中操作整个持久化批处理的高效向量化实现。例如,您不能仅仅因为一个请求禁用了 logits 处理器就跳过apply()中的整个向量化操作。为了在没有运行请求使用内置 logits 处理器这种极端情况下节省计算,我们建议设计apply(),使其在所有请求都禁用该 logits 处理器时返回未修改的输入张量。同样,如果没有任何请求启用 logits 处理器,请考虑是否可以跳过update_state()中的步骤- 此外,在
update_state()中节省计算的一种简单方法是在 batch_update 为None时尽早退出
- 此外,在
-
-
确保 logits 处理器
update_state方法丢弃有关已完成请求的信息(即被“添加”替换或受到“移除”的请求) -
如果 logits 处理器具有一致的行为,可以将
is_argmax_invariant()硬编码为True或False。但是,argmax 不变性也可以通过编程确定(例如,如果您的 logits 处理器以某种影响它是否为 argmax 不变的方式实现用户自定义)。因此,is_argmax_invariant()不是一个类方法
内置 Logits 处理器¶
vLLM 引擎启动时始终会加载内置 logits 处理器。查看 vllm/v1/sample/logits_processor/builtin.py 中现有的 vLLM 内置 logits 处理器,获取如何编写新的内置 vLLM logits 处理器的示例。如果新的 logits 处理器可能对广大受众有用,那么提交 PR 将其作为内置处理器引入是有意义的。vLLM 目前基于上述编程模型采用以下内置 logits 处理器
-
Min-P
-
Logit 偏置 (Logit bias)
-
最小 token 数 (Min-tokens)
请查阅这些 logits 处理器实现,以获取有关编写内置 logits 处理器的指导。
此外,以下类似 logits 处理器的功能被硬编码到采样器中,尚未使用上述编程模型。它们中的大多数将被重构以使用上述 logits 处理器编程模型。
-
允许的 token ID
-
坏词屏蔽 (Bad words)
-
重复惩罚 (Repetition penalty)
-
频率惩罚 (Frequency penalty)
-
存在惩罚 (Presence penalty)
-
温度 (Temperature)
-
Top-K
-
Top-P
自定义 Logits 处理器¶
可以通过 用户提供的自定义 logits 处理器 来增强 vLLM。