
Classification Usage

Classification involves predicting which predefined category, class, or label best corresponds to a given input.

Summary

  • Model usage: (sequence) classification
  • Pooling task: classify
  • Offline APIs:
    • LLM.classify(...)
    • LLM.encode(..., pooling_task="classify")

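The key difference between (sequence) classification and token classification is output granularity. (Sequence) classification produces a single result for the entire input sequence, while token classification produces a result for each individual token in the sequence.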
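The following framework-free Python sketch makes the difference in granularity concrete. It uses made-up logits for a three-token input and two labels. It is only an illustration: using the last token for the sequence-level result is an assumption (one common pooling strategy), and nothing here calls vLLM.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical per-token logits: 3 tokens x 2 labels.
token_logits = [
    [2.0, 0.5],
    [0.1, 1.2],
    [1.5, 1.5],
]

# Token classification: one probability vector per token.
token_probs = [softmax(row) for row in token_logits]

# Sequence classification: a single probability vector for the whole input.
# Here we assume last-token pooling, i.e. we reuse the final token's logits.
sequence_probs = softmax(token_logits[-1])

print("per-token results:", len(token_probs))
print("sequence result size:", len(sequence_probs))
```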

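Many classification models support both (sequence) classification and token classification. For more details on token classification, see this page.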

A classification model can be used as a scoring model, with the Score API enabled, only if its num_labels equals 1. For details, see this page.

Typical Use Cases

Classification

The most basic application of classification models is to sort input data into predefined categories.

Supported Models

Text-only Models

| Architecture | Models | Example HF Models | LoRA | PP |
| --- | --- | --- | --- | --- |
| ErnieForSequenceClassification | BERT-like Chinese ERNIE | Forrest20231206/ernie-3.0-base-zh-cls |  |  |
| GPT2ForSequenceClassification | GPT2 | nie3e/sentiment-polish-gpt2-small |  |  |
| Qwen2ForSequenceClassification^C | Qwen2-based | jason9693/Qwen2.5-1.5B-apeach |  |  |
| *Model^C, *ForCausalLM^C | Generative models | N/A | * | * |

Multimodal Models

Note

For more information about multimodal model inputs, see this page.

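| Architecture | Models | Inputs | Example HF Models | LoRA | PP |
| --- | --- | --- | --- | --- | --- |
| Qwen2_5_VLForSequenceClassification^C | Qwen2_5_VL-based | T + I^E+ + V^E+ | muziyongshixin/Qwen2.5-VL-7B-for-VideoCls |  |  |
| *ForConditionalGeneration^C, *ForCausalLM^C | Generative models | * | N/A | * | * |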

^C Automatically converted into a classification model via --convert classify. (details)
* Feature support is the same as that of the original model.

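If your model is not in the list above, we will try to convert it automatically using as_seq_cls_model. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.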
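The default behavior described above can be sketched in plain Python: take the last token's hidden state, turn it into class scores, and apply softmax. This is a conceptual illustration under stated assumptions, not vLLM's as_seq_cls_model code. It assumes a linear classifier head maps the pooled hidden state to one logit per class, and the head weights below are hypothetical.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify_last_token(
    hidden_states: list[list[float]],
    head_weight: list[list[float]],
    head_bias: list[float],
) -> list[float]:
    """Conceptual last-token sequence classification.

    hidden_states: seq_len x hidden_size, one vector per token
    head_weight:   num_labels x hidden_size (hypothetical classifier head)
    head_bias:     num_labels bias terms
    """
    last = hidden_states[-1]  # last-token pooling: earlier tokens are ignored
    logits = [
        sum(w * h for w, h in zip(row, last)) + b
        for row, b in zip(head_weight, head_bias)
    ]
    return softmax(logits)


hidden = [[0.3, -1.2], [1.0, 2.0]]  # 2 tokens, hidden_size = 2
W = [[0.5, -0.5], [-0.5, 0.5]]      # 2 labels
b = [0.0, 0.0]
probs = classify_last_token(hidden, W, b)
print(probs)
```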

Cross-encoder Models

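Cross-encoder (also known as reranker) models are a subset of classification models. They accept two prompts as input and produce a single output (num_labels equal to 1). Most classification models can also be used as cross-encoder models. For more information on cross-encoder models, see this page.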
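A cross-encoder reads both texts together and returns one logit per pair. Reranking therefore reduces to scoring each (query, document) pair and sorting the results. The sketch below makes two assumptions:

* A sigmoid is applied to that single logit, which is a common choice for rerankers.
* The model is replaced by a hypothetical word-overlap scorer, toy_score_pair, purely so the example runs without any model.

```python
import math
from typing import Callable


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def rerank(
    query: str,
    documents: list[str],
    score_pair: Callable[[str, str], float],
) -> list[tuple[str, float]]:
    """Rank documents by a cross-encoder-style relevance score.

    score_pair(query, doc) stands in for a cross-encoder that returns a
    single logit for the pair (num_labels == 1).
    """
    scored = [(doc, sigmoid(score_pair(query, doc))) for doc in documents]
    return sorted(scored, key=lambda item: item[1], reverse=True)


def toy_score_pair(query: str, doc: str) -> float:
    # Hypothetical stand-in for a real model: shared-word count minus one.
    shared = set(query.lower().split()) & set(doc.lower().split())
    return float(len(shared)) - 1.0


docs = [
    "vLLM serves models",
    "Paris is in France",
    "Serve models with vLLM quickly",
]
ranking = rerank("how to serve models with vLLM", docs, toy_score_pair)
for doc, score in ranking:
    print(f"{score:.3f}  {doc}")
```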

Text-only Models

| Architecture | Models | Example HF Models | Score template (see note) | LoRA | PP |
| --- | --- | --- | --- | --- | --- |
| BertForSequenceClassification | BERT-based | cross-encoder/ms-marco-MiniLM-L-6-v2 | N/A |  |  |
| GemmaForSequenceClassification | Gemma-based | BAAI/bge-reranker-v2-gemma (see note), etc. | bge-reranker-v2-gemma.jinja | ✅︎ | ✅︎ |
| GteNewForSequenceClassification | mGTE-TRM (see note) | Alibaba-NLP/gte-multilingual-reranker-base | N/A |  |  |
| LlamaBidirectionalForSequenceClassification^C | Llama-based with bidirectional attention | nvidia/llama-nemotron-rerank-1b-v2 | nemotron-rerank.jinja | ✅︎ | ✅︎ |
| Qwen2ForSequenceClassification^C | Qwen2-based | mixedbread-ai/mxbai-rerank-base-v2 (see note), etc. | mxbai_rerank_v2.jinja | ✅︎ | ✅︎ |
| Qwen3ForSequenceClassification^C | Qwen3-based | tomaarsen/Qwen3-Reranker-0.6B-seq-cls, Qwen/Qwen3-Reranker-0.6B (see note), etc. | qwen3_reranker.jinja | ✅︎ | ✅︎ |
| RobertaForSequenceClassification | RoBERTa-based | cross-encoder/quora-roberta-base | N/A |  |  |
| XLMRobertaForSequenceClassification | XLM-RoBERTa-based | BAAI/bge-reranker-v2-m3 | N/A |  |  |
| *Model^C, *ForCausalLM^C | Generative models | N/A | N/A | * | * |

^C Automatically converted into a classification model via --convert classify. (details)
* Feature support is the same as that of the original model.

Note

Some models require a specific prompt format to work correctly.

You can find the score templates for the example HF models in examples/pooling/score/template/

Examples: examples/pooling/score/using_template_offline.py and examples/pooling/score/using_template_online.py

Note

Use the following command to load the official original BAAI/bge-reranker-v2-gemma:

vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'

Note

The second-generation GTE model (mGTE-TRM) is named NewForSequenceClassification. Because that name is too generic, set --hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}' to explicitly select the GteNewForSequenceClassification architecture.
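Hand-writing JSON inside shell quotes is error-prone. One alternative, using only the Python standard library, is to build the --hf-overrides value with json.dumps and let shlex.join handle the quoting. The helper name serve_command is hypothetical. The snippet only builds the argument list and does not launch vllm serve.

```python
import json
import shlex


def serve_command(model: str, hf_overrides: dict) -> list[str]:
    # json.dumps guarantees the overrides value is valid JSON.
    return ["vllm", "serve", model, "--hf-overrides", json.dumps(hf_overrides)]


argv = serve_command(
    "Alibaba-NLP/gte-multilingual-reranker-base",
    {"architectures": ["GteNewForSequenceClassification"]},
)

# A shell-safe rendering of the same command, e.g. for copy-pasting.
print(shlex.join(argv))

# To actually launch it (not done here): subprocess.run(argv, check=True)
```

Passing argv as a list to subprocess avoids shell quoting entirely. shlex.join is only needed to display the equivalent shell command.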

Note

Use the following command to load the official original mxbai-rerank-v2:

vllm serve mixedbread-ai/mxbai-rerank-base-v2 --hf_overrides '{"architectures": ["Qwen2ForSequenceClassification"],"classifier_from_token": ["0", "1"], "method": "from_2_way_softmax"}'
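Here is some intuition for "method": "from_2_way_softmax" with the token pair ["0", "1"]. Under a softmax over two logits, the probability of the second token equals the sigmoid of their difference. That is why a two-way softmax can be collapsed into a single relevance score. The sketch below checks this identity numerically. It shows the underlying math only and makes no claim about how vLLM builds the classifier head internally. The same reasoning applies to any two-token pair.

```python
import math


def two_way_softmax_score(logit_0: float, logit_1: float) -> float:
    # Probability of the second token under a softmax over the two logits.
    m = max(logit_0, logit_1)
    e0 = math.exp(logit_0 - m)
    e1 = math.exp(logit_1 - m)
    return e1 / (e0 + e1)


def sigmoid_of_difference(logit_0: float, logit_1: float) -> float:
    # Equivalent single-logit form: sigmoid(logit_1 - logit_0).
    return 1.0 / (1.0 + math.exp(-(logit_1 - logit_0)))


pairs = [(0.2, 1.7), (3.0, -1.0), (0.0, 0.0)]
for l0, l1 in pairs:
    a = two_way_softmax_score(l0, l1)
    b = sigmoid_of_difference(l0, l1)
    print(f"logits=({l0}, {l1})  softmax={a:.6f}  sigmoid(diff)={b:.6f}")
```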

Note

Use the following command to load the official original Qwen3 Reranker. For more information, see examples/pooling/score/qwen3_reranker_offline.py and examples/pooling/score/qwen3_reranker_online.py

vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'

Multimodal Models

Note

For more information about multimodal model inputs, see this page.

| Architecture | Models | Inputs | Example HF Models | LoRA | PP |
| --- | --- | --- | --- | --- | --- |
| JinaVLForSequenceClassification | JinaVL-based | T + I^E+ | jinaai/jina-reranker-m0 | ✅︎ | ✅︎ |
| LlamaNemotronVLForSequenceClassification | Llama Nemotron Reranker + SigLIP | T + I^E+ | nvidia/llama-nemotron-rerank-vl-1b-v2 |  |  |
| Qwen3VLForSequenceClassification | Qwen3-VL-Reranker | T + I^E+ + V^E+ | Qwen/Qwen3-VL-Reranker-2B (see note), etc. | ✅︎ | ✅︎ |

^C Automatically converted into a classification model via --convert classify. (details)
* Feature support is the same as that of the original model.

Note

As with Qwen3-Reranker, use the following --hf_overrides to load the official original Qwen3-VL-Reranker:

vllm serve Qwen/Qwen3-VL-Reranker-2B --hf_overrides '{"architectures": ["Qwen3VLForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'

Reward Models

(Sequence) classification models can be used as reward models. For more information, see Reward Models.

| Architecture | Models | Example HF Models | LoRA | PP |
| --- | --- | --- | --- | --- |
| JambaForSequenceClassification | Jamba | ai21labs/Jamba-tiny-reward-dev | ✅︎ | ✅︎ |
| Qwen3ForSequenceClassification^C | Qwen3-based | Skywork/Skywork-Reward-V2-Qwen3-0.6B | ✅︎ | ✅︎ |
| LlamaForSequenceClassification^C | Llama-based | Skywork/Skywork-Reward-V2-Llama-3.2-1B | ✅︎ | ✅︎ |
| *Model^C, *ForCausalLM^C | Generative models | N/A | * | * |

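^C Automatically converted into a classification model via --convert classify. (details)
* Feature support is the same as that of the original model.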

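If your model is not in the list above, we will try to convert it automatically using as_seq_cls_model. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.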

Offline Inference

Pooling Parameters

The following pooling parameters are supported:

    use_activation: bool | None = None

LLM.classify

The classify method outputs a probability vector for each prompt.

from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")
(output,) = llm.classify("Hello, my name is")

probs = output.outputs.probs
print(f"Class Probabilities: {probs!r} (size={len(probs)})")

A code example is available here: examples/basic/offline_inference/classify.py

LLM.encode

The encode method is available to all pooling models in vLLM.

When using LLM.encode for inference with a classification model, set pooling_task="classify".

from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")
(output,) = llm.encode("Hello, my name is", pooling_task="classify")

data = output.outputs.data
print(f"Data: {data!r}")

Online Serving

Classification API

The online /classify API is similar to LLM.classify.

Completion Parameters

The following Classification API parameters are supported:

    model: str | None = None
    user: str | None = None
    input: list[int] | list[list[int]] | str | list[str]
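The client-side sketch below uses only the standard library. It assembles a /classify request from the parameters listed above and leaves sending it to the caller. The function name build_classify_request is hypothetical. The URL, model, and field names follow the examples in this document, and any extra keyword arguments are copied into the JSON body unchanged.

```python
import json
import urllib.request


def build_classify_request(base_url, inputs, model=None, **extra):
    """Build (but do not send) a POST request for the /classify endpoint."""
    payload = {"input": inputs}  # str, list[str], list[int], or list[list[int]]
    if model is not None:
        payload["model"] = model
    payload.update(extra)  # e.g. truncate_prompt_tokens=..., use_activation=...
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/classify",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_classify_request(
    "http://127.0.0.1:8000",
    ["This update broke everything. Frustrating.", "The coffee was great."],
    model="jason9693/Qwen2.5-1.5B-apeach",
)

# Against a running server, the request could be sent like this:
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp))
```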

The following extra parameters are supported:

    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    truncation_side: Literal["left", "right"] | None = Field(
        default=None,
        description=(
            "Which side to truncate from when truncate_prompt_tokens is active. "
            "'right' keeps the first N tokens. "
            "'left' keeps the last N tokens."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    priority: int = Field(
        default=0,
        ge=-(2**63),
        le=2**63 - 1,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for the pooler outputs. "
        "`None` uses the pooler's default, which is `True` in most cases.",
    )
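The truncation_side description can be illustrated on a plain list of token IDs. This sketch mirrors only the documented behavior: 'right' keeps the first N tokens and 'left' keeps the last N. The helper truncate_tokens is hypothetical, and defaulting to 'right' when no side is given is an assumption of the sketch. The special value -1 permitted by ge=-1 is deliberately not modeled.

```python
def truncate_tokens(token_ids, truncate_prompt_tokens=None, truncation_side="right"):
    """Apply the documented truncation semantics to a list of token ids."""
    if truncate_prompt_tokens is None:
        return list(token_ids)  # no truncation requested
    if truncate_prompt_tokens < 0:
        raise NotImplementedError("truncate_prompt_tokens=-1 is not modeled in this sketch")
    n = min(truncate_prompt_tokens, len(token_ids))
    if truncation_side == "right":
        return list(token_ids[:n])  # keep the first N tokens
    if truncation_side == "left":
        return list(token_ids[len(token_ids) - n:])  # keep the last N tokens
    raise ValueError(f"unknown truncation_side: {truncation_side!r}")


ids = [101, 7592, 1010, 2026, 2171, 102]
print(truncate_tokens(ids, 3, "right"))
print(truncate_tokens(ids, 3, "left"))
```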

Chat Parameters

For chat-like input (i.e. if messages is passed), the following parameters are supported:

    model: str | None = None
    user: str | None = None
    messages: list[ChatCompletionMessageParam]

The following extra parameters are supported instead:

    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    truncation_side: Literal["left", "right"] | None = Field(
        default=None,
        description=(
            "Which side to truncate from when truncate_prompt_tokens is active. "
            "'right' keeps the first N tokens. "
            "'left' keeps the last N tokens."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    priority: int = Field(
        default=0,
        ge=-(2**63),
        le=2**63 - 1,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
        default=None,
        description=(
            "Additional kwargs to pass to the media IO connectors, "
            "keyed by modality. Merged with engine-level media_io_kwargs."
        ),
    )
    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for the pooler outputs. "
        "`None` uses the pooler's default, which is `True` in most cases.",
    )
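A chat-style request body can be built the same way as a plain one, with messages in place of input. The sketch below also generates a cache_salt with secrets.token_urlsafe(32). That call encodes 32 random bytes (256 bits) as 43 URL-safe base64 characters, which matches the length suggested in the cache_salt description. The helper name build_chat_classify_payload is hypothetical, and the sketch does not check whether a given model ships a chat template.

```python
import json
import secrets


def build_chat_classify_payload(messages, model=None, cache_salt=None, **extra):
    """Assemble a JSON body for /classify with chat-style `messages` input."""
    payload = {"messages": messages}
    if model is not None:
        payload["model"] = model
    if cache_salt is not None:
        payload["cache_salt"] = cache_salt
    payload.update(extra)  # e.g. add_generation_prompt=False, use_activation=True
    return json.dumps(payload)


# 32 random bytes -> 43 URL-safe base64 characters (no padding).
salt = secrets.token_urlsafe(32)

body = build_chat_classify_payload(
    [{"role": "user", "content": "This update broke everything. Frustrating."}],
    cache_salt=salt,
)
print(body)
```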

Example Requests

Code example: examples/pooling/classify/classification_online.py

You can classify multiple texts by passing an array of strings:

curl -v "http://127.0.0.1:8000/classify" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": [
      "Loved the new café—coffee was great.",
      "This update broke everything. Frustrating."
    ]
  }'
Response
{
  "id": "classify-7c87cac407b749a6935d8c7ce2a8fba2",
  "object": "list",
  "created": 1745383065,
  "model": "jason9693/Qwen2.5-1.5B-apeach",
  "data": [
    {
      "index": 0,
      "label": "Default",
      "probs": [
        0.565970778465271,
        0.4340292513370514
      ],
      "num_classes": 2
    },
    {
      "index": 1,
      "label": "Spoiled",
      "probs": [
        0.26448777318000793,
        0.7355121970176697
      ],
      "num_classes": 2
    }
  ],
  "usage": {
    "prompt_tokens": 20,
    "total_tokens": 20,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}
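To turn a response like the one above into a decision, you can take the highest-probability class for each input. This minimal sketch parses the data part of the example response with the standard json module. The helper top_predictions is hypothetical. In this example response, the returned label coincides with the highest-probability class, which the output lets you check.

```python
import json


def top_predictions(response: dict) -> list[tuple[int, int, str, float]]:
    """Return (input index, argmax class id, label, top probability) per input."""
    results = []
    for item in sorted(response["data"], key=lambda d: d["index"]):
        probs = item["probs"]
        best = max(range(len(probs)), key=probs.__getitem__)
        results.append((item["index"], best, item["label"], probs[best]))
    return results


# The "data" part of the example response above.
response = json.loads("""
{"data": [
  {"index": 0, "label": "Default",
   "probs": [0.565970778465271, 0.4340292513370514], "num_classes": 2},
  {"index": 1, "label": "Spoiled",
   "probs": [0.26448777318000793, 0.7355121970176697], "num_classes": 2}
]}
""")

for idx, cls, label, p in top_predictions(response):
    print(f"input {idx}: class {cls} ({label}), p={p:.3f}")
```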

You can also pass a single string directly to the input field:

curl -v "http://127.0.0.1:8000/classify" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "jason9693/Qwen2.5-1.5B-apeach",
    "input": "Loved the new café—coffee was great."
  }'
Response
{
  "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
  "object": "list",
  "created": 1745383213,
  "model": "jason9693/Qwen2.5-1.5B-apeach",
  "data": [
    {
      "index": 0,
      "label": "Default",
      "probs": [
        0.565970778465271,
        0.4340292513370514
      ],
      "num_classes": 2
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "total_tokens": 10,
    "completion_tokens": 0,
    "prompt_tokens_details": null
  }
}

More Examples

More examples can be found here: examples/pooling/classify

Supported Features

Enable/disable activation

You can enable or disable activation via use_activation.
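The sketch below shows what toggling use_activation changes. When activation is enabled, it applies a softmax to hypothetical raw classifier logits. When activation is disabled, it returns the logits untouched. Softmax is assumed here as the single-label case; the activation a real model uses depends on its configuration. The helper pooler_output is hypothetical.

```python
import math


def pooler_output(logits: list[float], use_activation: bool = True) -> list[float]:
    """Illustrate the effect of enabling or disabling the output activation."""
    if not use_activation:
        return list(logits)  # raw logits, unbounded
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]  # probabilities that sum to 1


logits = [1.0, -0.5]  # hypothetical raw classifier outputs
print(pooler_output(logits, use_activation=True))
print(pooler_output(logits, use_activation=False))
```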

Problem type (e.g. multi_label_classification)

You can change the problem type via problem_type in the Hugging Face config. The supported problem types are single_label_classification, multi_label_classification, and regression.

This behavior is aligned with ForSequenceClassificationLoss in transformers.
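In transformers' sequence classification models, the three problem types correspond to three losses:

* single_label_classification uses CrossEntropyLoss.
* multi_label_classification uses BCEWithLogitsLoss.
* regression uses MSELoss.

Each loss implies a natural output activation. The mapping below is a hedged illustration of that correspondence, not vLLM's implementation. The helper activation_for is hypothetical.

```python
import math
from typing import Callable


def activation_for(problem_type: str) -> Callable[[list[float]], list[float]]:
    """Map a Hugging Face problem_type to the activation implied by its loss."""
    if problem_type == "single_label_classification":
        # CrossEntropyLoss -> mutually exclusive labels -> softmax
        def softmax(xs: list[float]) -> list[float]:
            m = max(xs)
            exps = [math.exp(x - m) for x in xs]
            s = sum(exps)
            return [e / s for e in exps]
        return softmax
    if problem_type == "multi_label_classification":
        # BCEWithLogitsLoss -> independent labels -> per-label sigmoid
        return lambda xs: [1.0 / (1.0 + math.exp(-x)) for x in xs]
    if problem_type == "regression":
        # MSELoss -> raw values, no activation
        return lambda xs: list(xs)
    raise ValueError(f"unsupported problem_type: {problem_type!r}")


scores = [0.0, 2.0]
for pt in ("single_label_classification", "multi_label_classification", "regression"):
    print(pt, activation_for(pt)(scores))
```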

Logit bias

You can modify the logit bias (also known as sigmoid_normalize) through the logit_bias parameter in vllm.config.PoolerConfig.

Removed Features

Removed softmax from PoolingParams

We have removed softmax and activation from PoolingParams. Use use_activation instead, since classify and token_classify now allow any activation function.