跳到内容

Logits Processors

重要

一些 Logits Processors 的设计更改仍在进行中,API 在不久的将来可能会发生变化。我们希望尽快稳定 API 的这部分。

本文档描述了 vLLM 引擎如何与 logits processors 交互,以及 vLLM 支持的用于实现 logits processors 的编程模型。

Logits Processors 背景

Logits processor 会调整下一个 token 的概率分布,通常是为了引导模型产生期望的行为。

在 vLLM 中,logits processors 以批次粒度运行。在给定的引擎步中,logits processor 接收模型输出的 (num_requests) x (vocab_size) 大小的原始 logits 张量。对于所有启用该 logits processor 的请求,logits processor 会对 logits 张量的相应行应用变换,而其他行则保持不变。然后将变换后的 logits 张量传递给 softmax。

vLLM 引擎中的 Logits Processors

vLLM 引擎的持久批次数据结构维护着一个已加载 logits processors 的列表。

为了同时处理整个批次,每个 logits processor 可能会维护关于批次中请求的元数据(即每个请求的特定于 logits processor 的配置设置)。因此,logits processors 是有状态的。

在每个引擎步中,vLLM 引擎将(1)更新每个 logits processor 的内部状态,以及(2)将 logits processors 应用于模型输出的 logits。

更新 Logits Processor 内部状态

在每个引擎步的开始,持久批次可能会根据调度器的输出添加、丢弃和/或重新排序请求。在持久批次重新组织后,vLLM 引擎会调用每个 logits processor 的 update_state() 方法。这是为了确保 logits processor 的内部状态能够与引擎步开始时的新的持久批次状态匹配。

以下伪代码展示了 vLLM 持久批次通知每个 logits processor 批次状态变化的进程。

模型运行器更新 Logits Processor 状态
# gpu_model_runner.py

class GPUModelRunner(...):

    ...

    def execute_model(self, scheduler_output, ...):
        self._update_states(scheduler_output)

        ...

    def _update_states(...):

        ...

        # ...update persistent batch to reflect new/finished requests & reordering
        # of requests within batch...

        ...

        self.input_batch.refresh_metadata()


# gpu_input_batch.py

class InputBatch:

    ...

    def refresh_metadata(self):

        ...

        # Update each logits processor's state to reflect persistent batch state
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)

        ...


# vllm/v1/sample/logits_processor/interface.py

@dataclass(frozen=True)
class BatchUpdate:
    # Batch state-change data structure which is passed to logits processors'
    # update_state() methods

    batch_size: int

    removed: Sequence[RemovedRequest]
    added: Sequence[AddedRequest]
    moved: Sequence[MovedRequest]

将 Logits Processors 应用于模型输出 Logits

在更新持久批次状态后,vLLM 模型运行器会执行模型推理以获得 logits。然后,模型运行器会对 logits 调用 sampler。sampler 的一部分操作是针对模型输出 logits 调用 logits processors 的 apply() 方法,得到变换后的 logits(apply() 方法可以就地或非就地修改 logits,但就地修改更节省内存)。此过程如下图所示。

请注意,sampler 将通过 SamplingMetadata.logitsprocs 访问 logits processors。当 vLLM 引擎构建 SamplingMetadata 时(下图未显示),到 logits processors 列表的引用将从持久批次数据结构传递到 SamplingMetadata

将 Logits Processors 应用于模型输出 Logits
# gpu_model_runner.py

class GPUModelRunner(...):

    ...

    def execute_model(self, scheduler_output, ...):
        # (discussed in previous section)
        self._update_states(scheduler_output)

        ...

        # ...run model inference to obtain logits...

        ...

        # Invoke sampler, which applies logits processors
        sampler_output = self.sampler(logits=logits,
                                      sampling_metadata=sampling_metadata)

        ...


# sampler.py

class Sampler(nn.Module):

    ...

    def forward(self, logits, sampling_metadata):

        ...

        # Apply non-argmax-invariant logits processors to model output logits
        for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
            logits = processor.apply(logits)

        sampled = self.sample(logits, sampling_metadata)

        ...

        # ...return sampler output data structure...


    def sample(self, logits, sampling_metadta)

        ...

        # ...exit early if all requests are greedy-sampling...

        ...

        # Apply argmax-invariant logits processors
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        ...

        # ...perform sampling and return sampling result...

在采样时,sampler 会检查持久批次中的所有请求是否都使用了贪婪采样。如果是这种情况,sampler 会通过跳过“argmax 不变”的 logits processors 来节省计算。此处,“argmax”是给定 logits 张量行中具有最高 logit 值的 token ID 的简称(即模型为给定请求加权最高的 token)。

  • argmax 不变 Logits Processor 是一个不会改变 argmax 的 logits processor(例如 Min-P)。例如,一个屏蔽低概率 token 的 logits processor 不会改变具有最高 logit 的 token ID。贪婪采样总是选择具有最高 logit 值的 token ID,因此概念上,对于贪婪采样请求,可以跳过 argmax 不变的 logits processor。

  • 非 argmax 不变 Logits Processor 是一个可能改变 argmax 的 logits processor。例如,一个在一定步数后屏蔽除 EOS 之外所有 token 以强制解码终止的 logits processor,可能会屏蔽最高 logit 值的 token,从而改变 argmax。概念上,这些 logits processors 不能为贪婪采样请求而跳过。

vLLM logits processor 抽象要求引擎以批次粒度应用 logits processors;因此,实际上,只有当整个批次都使用贪婪采样时,才能跳过 argmax 不变的 logits processors。

Logits Processor 编程模型

前面的章节暗示了 vLLM logits processors 必须支持的接口。本节将全面介绍用于实现与 vLLM 引擎兼容的 logits processors 的编程模型,包括 LogitsProcessor 基类及其接口方法,以及用于表示持久批次状态变化的 BatchUpdate 数据结构,这两者都显示在下面的代码中。

LogitsProcessor 基类和 BatchUpdate 数据结构
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING

import torch

from vllm import SamplingParams

if TYPE_CHECKING:
    from vllm.config import VllmConfig


class MoveDirectionality(Enum):
    # One-way i1->i2 req move within batch
    UNIDIRECTIONAL = auto()
    # Two-way i1<->i2 req swap within batch
    SWAP = auto()


# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int], list[int]]

# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]

# Batch indices of any removed requests.
RemovedRequest = int


@dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs"""
    batch_size: int  # Current num reqs in batch

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Key assumption: the `output_tok_ids` list (which is an element of each
    # tuple in `added`) is a reference to the request's running output tokens
    # list; via this reference, the logits processors always see the latest
    # list of generated output tokens
    removed: Sequence[RemovedRequest]
    moved: Sequence[MovedRequest]
    added: Sequence[AddedRequest]


class LogitsProcessor(ABC):

    @abstractmethod
    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                is_pin_memory: bool) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """True if logits processor has no impact on the
        argmax computation in greedy sampling.
        NOTE: may or may not have the same value for all
        instances of a given LogitsProcessor subclass,
        depending on subclass implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def update_state(
        self,
        batch_update: "BatchUpdate" | None,
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update is non-None iff there have been
            changes to the batch makeup.
        """
        raise NotImplementedError

    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        """Validate sampling params for this logits processor.

        Raise ValueError for invalid ones.
        """
        return None

vLLM logits processor 必须继承 LogitsProcessor 并定义(至少)以下方法:

  • __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)

    • vllm_config:引擎配置数据结构。
    • device:硬件加速器设备信息。
    • is_pin_memory:指示是否可用 pinned memory 来支持 logits processor 实现的标志。
  • apply(self, logits: torch.Tensor) -> torch.Tensor:

    • 接收一个 (num_requests) x (vocab_size) 大小的 logits 张量(logits)。
    • 以批次粒度应用 logits processor 变换。
    • 返回一个变换后的 (num_requests) x (vocab_size) 大小的 logits 张量。
    • 您可以就地或非就地修改输入的 logits processor;就地修改更节省内存。
  • is_argmax_invariant(self) -> bool:

    • 如果 logits processor 是 argmax 不变的(从不改变给定请求的最高 logit 值的 token ID),则返回 True,如果 logits processor 可能会修改 argmax,则返回 False
    • is_argmax_invariant() 在启动时评估一次;如果为 True,vLLM 将在所有请求使用贪婪采样时跳过应用此 logits processor。
  • update_state(self, batch_update: "BatchUpdate" | None) -> None:

    • 接收一个 BatchUpdate 数据结构,表示当前引擎步开始时的持久批次状态变化。
    • 使用 BatchUpdate 成员来更新 logits processor 内部状态。
    • 注意: batch update 数据结构可能为 None,表示批次构成没有变化。在这种情况下,LogitsProcessor 可能仍然希望根据它在添加时可能保留的更新的 output_token_ids 列表来更新其状态。
  • validate_params(cls, sampling_params: SamplingParams):

    • 如果 SamplingParams 包含 logits processor 使用的无效参数(尤其是自定义参数),则引发 ValueError
    • 当请求发送到入口点时,validate_params() 将验证 SamplingParams 并拒绝带有无效参数的请求。

BatchUpdate 数据结构

BatchUpdate 抽象模型将持久批次表示为请求列表,支持以下操作来改变批次状态(请注意,下面操作的顺序反映了它们在 update_state() 中应处理的顺序)。

  • 移除:在索引 i 处移除请求(不替换)。

    • Batchupdate.removed 中的移除操作由一个 int 表示(代表 i)。

    • remove-at-index 对批次的影响。

      Batch: [A,B,C]
      Remove @ i:  1
      
      =>
      
      New Batch: [A,x,C] # Discard B and leave an empty slot
      
  • 添加:在索引 i 处添加(或替换现有请求为)一个新请求。如果替换了请求,其关联的状态应被丢弃。

    • Batchupdate.added 中的添加操作表示为包含以下元素的元组:

      (index, new request SamplingParams, prompt token ids, output token ids)
      
    • prompt token idsoutput token ids 分别是对请求的 prompt token ids 和 output token ids 列表的引用。请注意,output token ids 列表会随着每个引擎步的进行而增长,并且 Logits Processor 可以看到这种增长,因为 output token ids 是按引用传递的。这对于那些考虑了迄今为止生成 token 的 LogitsProcessors 很重要

    • 特定 logits processor 子类的实现决定了如何或是否将添加的请求元组中的字段解析为其内部表示。例如,一个不使用 prompt 或 output token ids 的 logits processor 可能只需要使用 indexSamplingParams,而丢弃其他元组字段。

    • 如果索引 i 当前包含一个请求,则会发生替换。

      Batch: [A,B,C]
      New request to be added @ i: D @ 1
      
      =>
      
      New Batch: [A,D,C] # Add D, discard B
      
    • 如果索引 i 当前不包含请求(因为 i 超出了当前批次大小的范围)。

      Batch: [A,B,C]
      New request to be added @ i: D @ 3
      
      =>
      
      New Batch: [A,B,C,D] # Add D, extending batch
      
  • 移动:将索引 s 的请求移动到索引 d,或者交换索引 sd 的请求。

    • Batchupdate.moved 中的移动操作表示为包含以下元素的元组:

      (s, d, UNIDIRECTIONAL or SWAP)
      
    • 如果移动指定了 UNIDRECTIONAL

      • 索引 s 的请求被移动到索引 d;索引 s 变成一个空槽。

        Batch: [A,x,C,D]
        Unidirectionally Move s -> d:  3 -> 1
        
        =>
        
        New Batch: [A,D,C,x] # Move D to 1, leaving empty slot at 3
        
      • 如果索引 d 已经存在一个请求,它将被替换并丢弃。

        Batch: [A,B,C,D]
        Unidirectionally Move s -> d:  3 -> 1
        
        =>
        
        New Batch: [A,D,C,x] # Move D to 1, discarding B and leaving empty slot at 3
        
    • 如果移动指定了 SWAP,则索引 sd 的请求交换索引。

      Batch: [A,B,C,D]
      Swap Move s <-> d:  3 <-> 1
      
      =>
      
      New Batch: [A,D,C,B] # Swap B and D
      

此外,BatchUpdate 数据结构还包括引擎步开始时持久批次大小的表示(batch_size)。

vLLM 引擎如何构建 BatchUpdate 数据结构

Logits processor update_state() 的实现应假定模型运行器更新持久批次状态的模型如下(此处以 BatchUpdate 抽象来表示):

  1. 识别当前引擎步中完成的请求的索引。

  2. 识别当前步引入的新请求。

  3. 使用 Add 操作,按被替换请求的升序索引(从最小索引开始)将尽可能多的已完成请求替换为新请求。

  4. 基于新请求和已完成请求的数量。

    1. 如果新请求和已完成请求的数量相同,则继续下一步。

    2. 如果新请求多于已完成请求:应用 Add 操作,用剩余未替换已完成请求的新请求扩展批次。为这些新请求分配连续索引,从 current_max_batch_index + 1 开始。

    3. 如果新请求少于已完成请求。

      • 对未被新请求替换的已完成请求应用 Remove 操作。这些移除请求的索引必然大于上一步被替换的已完成请求的最大索引。移除操作可能会使批次处于非连续状态。

      • “压缩”批次使其连续:从最低索引的空槽(由 Remove 操作引起)开始,应用一个单向移动(Unidirectional Move),从当前批次中最高非空槽填充空槽。按空槽目标索引的递增顺序和非空槽源索引的递减顺序进行其他单向移动操作,直到批次连续。

      • 缩小批次:压缩批次的一个副作用是,由 Remove 操作产生的空槽会聚集在批次数组的末尾形成一个连续块。因此,压缩后,更新 BatchUpdate.batch_size 以反映非空槽的数量。

  5. 重新排序批次以提高效率。根据注意力后端实现和当前批次的特性,可能会应用零个或多个 Swap Move 操作来重新排序批次。

注意事项

  • Logits processor update_state() 方法必须按以下顺序处理批次更新操作:移除、添加、移动。

  • Add 操作的索引参数指的是 Add 操作发生时的索引,即在任何 Move 操作之前。

    • 示例:如果一个请求在索引 5 处被添加,然后与索引 3 交换,那么 BatchUpdate.added 中的 Add 操作将与索引 5 相关联,而不是 3。
    • 换句话说,可以假定 Move 操作是在 Add 和 Remove 操作之后应用的。
  • 可以假定 Move 操作是按照它们在 BatchUpdate.moved 中出现的顺序应用的。

  • 如果没有新/已完成请求,也没有批次重新排序,那么 logits processors 的批次更新将是 None

示例:新请求少于完成请求的批次更新

以下示例模拟了一个引擎步,其中引入了 1 个新请求,并移除了 2 个已完成请求,此外,注意力后端执行了交换以优化批次排序。

Batch state (beginning of engine step): [A,B,C,D]
Batch size: 4

New requests: E

Finished requests: A, C

Processing steps (using BatchUpdate abstraction):

1. Add E at index 0

[E,B,C,D] # Discard A
Batch size: 4

2. Remove at index 2

[E,B,x,D] # Discard C, empty slot at index 2
Batch size: 4

3. Condense batch with a Unidirectional Move 3 -> 2 operation and shrink batch

[E,B,D] x # Empty slot is now outside batch
Batch size: 3

4. Attention backend optimization: reorder batch with Swap 0 <-> 1

[B,E,D]
Batch size: 3

生成的 BatchUpdate 数据结构将如下所示:

BatchUpdate instance
* added: [(0,E's SamplingParams,E's prompt tokens ref,E's output tokens ref)]
* removed: [2] # request C was removed without replacement
* moved: [(3,2,UNIDIRECTIONAL),(0,1,SWAP)]

示例:新请求多于完成请求的批次更新

以下示例模拟了一个引擎步,其中引入了 2 个新请求,并移除了 1 个已完成请求,此外,注意力后端执行了交换以优化批次排序。

Batch state (beginning of engine step): [A,B,C,D]
Batch size: 4

New requests: E,F

Finished requests: C

Processing steps (using BatchUpdate abstraction):

1. Add E at index 2

[A,B,E,D] # Discard C
Batch size: 4

2. Add F at index 4 (current max batch index + 1)

[A,B,E,D,F] # Extend batch by 1
Batch size: 5

4. Attention backend optimization: reorder batch with Swap 0 <-> 1

[B,A,E,D,F]
Batch size: 5

请注意,由于 Remove 操作没有留下空槽,因此跳过了批次压缩。

生成的 BatchUpdate 数据结构将如下所示:

BatchUpdate instance
* added: [(2,E's SamplingParams,E's prompt tokens ref,E's output tokens ref),(4,F's SamplingParams,F's prompt tokens ref,F's output tokens ref)]
* removed: [] # no requests were removed without replacement
* moved: [(0,1,SWAP)]

如何向 vLLM 引入新的 Logits Processor

编写内置 Logits Processors 的最佳实践

  • 考虑到 logits processors 以批次粒度运行,请编写高效的 apply()update_state() 实现。

    • 例如,您可能可以使用高效的向量化操作来实现 apply() 或在 update_state() 中更新内部状态向量。
    • 但是,如果您认为某个 logits processor 可能不常使用,则可以使用“稀疏”的请求状态表示,即该类可以使用字典来表示请求配置,该字典仅存储启用 logits processor 的请求的元数据。
  • 这取决于 logits processor 的作者来决定:

    1. 配置 logits processor 对该请求行为的每个请求属性。 例如,如果您正在为 vLLM 编写一个新的内置 logits processor,您可能需要向 SamplingParams 和 vLLM REST API 添加额外的字段,也可能不需要。

    2. Logits processor 在每个请求的基础上启用或不启用的条件。 除非您的目的是让内置 logits processor 始终对所有请求起作用,否则您应该编写您的 logits processor,使其能够为特定请求禁用 logits processor,例如通过将参数默认设置为 None 或传递特定的无操作参数值,即 0.0。请尝试为禁用 logits processor 的请求节省计算和内存。

    3. Logits processor 在批次级别短路的条件。 即使您已定义了在请求级别禁用内置 logits processor 的方法,也很难将其转化为计算节省,例如,如果您的 update_state()apply() 实现使用了在整个持久批次上执行一次的向量化实现。例如,即使一个请求禁用了 logits processor,您也不能仅凭此跳过 apply() 中的整个向量化操作。为了在没有运行请求使用内置 logits processor 的边缘情况下节省计算,我们建议将 apply() 设计为在所有请求都禁用 logits processor 时返回未修改的输入张量。同样,考虑在没有请求启用 logits processor 的情况下是否可以跳过 update_state() 中的步骤。

      • 此外,在 update_state() 中节省计算的一种简单方法是在 batch_updateNone 时提前退出。
  • 确保 logits processor update_state 方法丢弃已完成请求(即被 Add 操作替换或受 Remove 操作影响的请求)的信息。

  • 如果 logits processor 具有一致的行为,is_argmax_invariant() 可以硬编码为 TrueFalse。但是,argmax 的不变性也可以通过编程方式确定(例如,如果您的 logits processor 是用户可自定义的,并且以某种方式影响了 logits processor 是否是 argmax 不变的)。因此,is_argmax_invariant() 不是一个类方法。

内置 Logits Processors

内置 logits processors 在 vLLM 引擎启动时始终会加载。请参阅 vllm/v1/sample/logits_processor/builtin.py 中的现有 vLLM 内置 logits processors,以获取关于如何编写新的内置 vLLM logits processor 的示例。如果某个 logits processor 可能对广大用户有用,那么将其作为内置处理器引入是合理的。vLLM 目前根据上述编程模型使用以下内置 logits processors:

  • Min-P

  • Logit 偏差

  • Min-tokens

请参考这些 logits processor 的实现,以获得编写内置 logits processors 的指导。

此外,以下类似 logits processor 的功能已被硬编码到 sampler 中,尚未利用上述编程模型。其中大部分将被重构为使用上述 logits processor 编程模型。

  • 允许的 token IDs

  • 不良词汇

  • 重复惩罚

  • 频率惩罚

  • 存在惩罚

  • 温度

  • Top-K

  • Top-P

自定义 Logits Processors

vLLM 可以通过用户提供的自定义 logits processors 进行扩展。