分类用法¶
分类涉及预测给定输入最符合的预定义类别、类或标签。
摘要¶
- 模型用法:(序列) 分类
- 池化任务:
classify - 离线 API
LLM.classify(...)LLM.encode(..., pooling_task="classify")
- 在线 API
- 分类 API (
/classify) - 池化 API (
/pooling)
- 分类 API (
(序列)分类和令牌(token)分类之间的关键区别在于输出粒度:(序列)分类为整个输入序列生成单个结果,而令牌分类则为序列中的每个单独令牌生成结果。
许多分类模型同时支持(序列)分类和令牌分类。有关令牌分类的更多详细信息,请参阅此页面。
只有当分类模型的 num_labels 等于 1 时,它才能用作评分模型并启用评分 API,请参阅此页面。
典型用例¶
分类¶
分类模型最基础的应用是将输入数据分类到预定义的类别中。
支持的模型¶
纯文本模型¶
| 架构 | 模型 | 示例 HF 模型 | LoRA | PP |
|---|---|---|---|---|
ErnieForSequenceClassification | 基于 BERT 的中文 ERNIE | Forrest20231206/ernie-3.0-base-zh-cls | ||
GPT2ForSequenceClassification | GPT2 | nie3e/sentiment-polish-gpt2-small | ||
Qwen2ForSequenceClassificationC | 基于 Qwen2 | jason9693/Qwen2.5-1.5B-apeach | ||
*ModelC, *ForCausalLMC 等 | 生成式模型 | 不适用 | * | * |
多模态模型¶
注意
有关多模态模型输入的更多信息,请参阅此页面。
| 架构 | 模型 | 输入 | 示例 HF 模型 | LoRA | PP |
|---|---|---|---|---|---|
Qwen2_5_VLForSequenceClassificationC | 基于 Qwen2_5_VL | T + IE+ + VE+ | muziyongshixin/Qwen2.5-VL-7B-for-VideoCls | ||
*ForConditionalGenerationC, *ForCausalLMC 等 | 生成式模型 | * | 不适用 | * | * |
C 通过 --convert classify 自动转换为分类模型。(详情)
* 功能支持与原始模型相同。
如果您的模型不在上述列表中,我们将尝试使用 as_seq_cls_model 自动转换该模型。默认情况下,类概率是从对应于最后一个令牌的 softmax 后的隐藏状态中提取的。
交叉编码器模型¶
交叉编码器(也称为重排序器,reranker)模型是分类模型的一个子集,它接受两个提示作为输入并输出 num_labels 等于 1 的结果。大多数分类模型也可以用作交叉编码器模型。有关交叉编码器模型的更多信息,请参阅此页面。
纯文本模型¶
| 架构 | 模型 | 示例 HF 模型 | 评分模板(见说明) | LoRA | PP |
|---|---|---|---|---|---|
BertForSequenceClassification | 基于 BERT | cross-encoder/ms-marco-MiniLM-L-6-v2 等 | 不适用 | ||
GemmaForSequenceClassification | 基于 Gemma | BAAI/bge-reranker-v2-gemma(见说明)等 | bge-reranker-v2-gemma.jinja | ✅︎ | ✅︎ |
GteNewForSequenceClassification | mGTE-TRM(见说明) | Alibaba-NLP/gte-multilingual-reranker-base 等 | 不适用 | ||
LlamaBidirectionalForSequenceClassificationC | 基于 Llama,带双向注意力 | nvidia/llama-nemotron-rerank-1b-v2 等 | nemotron-rerank.jinja | ✅︎ | ✅︎ |
Qwen2ForSequenceClassificationC | 基于 Qwen2 | mixedbread-ai/mxbai-rerank-base-v2(见说明)等 | mxbai_rerank_v2.jinja | ✅︎ | ✅︎ |
Qwen3ForSequenceClassificationC | 基于 Qwen3 | tomaarsen/Qwen3-Reranker-0.6B-seq-cls, Qwen/Qwen3-Reranker-0.6B(见说明)等 | qwen3_reranker.jinja | ✅︎ | ✅︎ |
RobertaForSequenceClassification | 基于 RoBERTa | cross-encoder/quora-roberta-base 等 | 不适用 | ||
XLMRobertaForSequenceClassification | 基于 XLM-RoBERTa | BAAI/bge-reranker-v2-m3 等 | 不适用 | ||
*ModelC, *ForCausalLMC 等 | 生成式模型 | 不适用 | 不适用 | * | * |
C 通过 --convert classify 自动转换为分类模型。(详情)
* 功能支持与原始模型相同。
注意
某些模型需要特定的提示词格式才能正常工作。
您可以在以下位置找到示例 HF 模型的相应评分模板: examples/pooling/score/template/
示例: examples/pooling/score/using_template_offline.py examples/pooling/score/using_template_online.py
注意
使用以下命令加载官方原始的 BAAI/bge-reranker-v2-gemma。
注意
第二代 GTE 模型 (mGTE-TRM) 名为 NewForSequenceClassification。由于 NewForSequenceClassification 这个名称过于通用,您应该设置 --hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}' 以明确指定使用 GteNewForSequenceClassification 架构。
注意
使用以下命令加载官方原始的 mxbai-rerank-v2。
注意
使用以下命令加载官方原始的 Qwen3 Reranker。更多信息请查看: examples/pooling/score/qwen3_reranker_offline.py examples/pooling/score/qwen3_reranker_online.py。
多模态模型¶
注意
有关多模态模型输入的更多信息,请参阅此页面。
| 架构 | 模型 | 输入 | 示例 HF 模型 | LoRA | PP |
|---|---|---|---|---|---|
JinaVLForSequenceClassification | 基于 JinaVL | T + IE+ | jinaai/jina-reranker-m0 等 | ✅︎ | ✅︎ |
LlamaNemotronVLForSequenceClassification | Llama Nemotron Reranker + SigLIP | T + IE+ | nvidia/llama-nemotron-rerank-vl-1b-v2 | ||
Qwen3VLForSequenceClassification | Qwen3-VL-Reranker | T + IE+ + VE+ | Qwen/Qwen3-VL-Reranker-2B(见说明)等 | ✅︎ | ✅︎ |
C 通过 --convert classify 自动转换为分类模型。(详情)
* 功能支持与原始模型相同。
注意
与 Qwen3-Reranker 类似,您需要使用以下 --hf_overrides 来加载官方原始的 Qwen3-VL-Reranker。
奖励模型¶
将 (序列) 分类模型用作奖励模型。欲了解更多信息,请参阅奖励模型。
| 架构 | 模型 | 示例 HF 模型 | LoRA | PP |
|---|---|---|---|---|
JambaForSequenceClassification | Jamba | ai21labs/Jamba-tiny-reward-dev 等 | ✅︎ | ✅︎ |
Qwen3ForSequenceClassificationC | 基于 Qwen3 | Skywork/Skywork-Reward-V2-Qwen3-0.6B 等 | ✅︎ | ✅︎ |
LlamaForSequenceClassificationC | 基于 Llama | Skywork/Skywork-Reward-V2-Llama-3.2-1B 等 | ✅︎ | ✅︎ |
*ModelC, *ForCausalLMC 等 | 生成式模型 | 不适用 | * | * |
C 通过 --convert classify 自动转换为分类模型。(详情)
如果您的模型不在上述列表中,我们将尝试使用 as_seq_cls_model 自动转换该模型。默认情况下,类概率是从对应于最后一个令牌的 softmax 后的隐藏状态中提取的。
离线推理¶
池化参数¶
支持以下池化参数。
LLM.classify¶
classify 方法为每个提示输出一个概率向量。
from vllm import LLM
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")
(output,) = llm.classify("Hello, my name is")
probs = output.outputs.probs
print(f"Class Probabilities: {probs!r} (size={len(probs)})")
代码示例见此: examples/basic/offline_inference/classify.py
LLM.encode¶
encode 方法适用于 vLLM 中的所有池化模型。
在使用 LLM.encode 进行分类模型推理时,请设置 pooling_task="classify"
from vllm import LLM
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")
(output,) = llm.encode("Hello, my name is", pooling_task="classify")
data = output.outputs.data
print(f"Data: {data!r}")
在线服务¶
分类 API¶
在线 /classify API 与 LLM.classify 类似。
补全参数¶
支持以下分类 API 参数
代码
支持以下额外参数
代码
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
truncation_side: Literal["left", "right"] | None = Field(
default=None,
description=(
"Which side to truncate from when truncate_prompt_tokens is active. "
"'right' keeps the first N tokens. "
"'left' keeps the last N tokens."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
priority: int = Field(
default=0,
ge=-(2**63),
le=2**63 - 1,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description="Additional kwargs to pass to the HF processor.",
)
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
)
use_activation: bool | None = Field(
default=None,
description="Whether to use activation for the pooler outputs. "
"`None` uses the pooler's default, which is `True` in most cases.",
)
对话参数¶
对于类似对话的输入(即如果传入了 messages),支持以下参数
转而支持这些额外参数
代码
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
truncation_side: Literal["left", "right"] | None = Field(
default=None,
description=(
"Which side to truncate from when truncate_prompt_tokens is active. "
"'right' keeps the first N tokens. "
"'left' keeps the last N tokens."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
priority: int = Field(
default=0,
ge=-(2**63),
le=2**63 - 1,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description="Additional kwargs to pass to the HF processor.",
)
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
continue_final_message: bool = Field(
default=False,
description=(
"If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
'This allows you to "prefill" part of the model\'s response for it. '
"Cannot be used at the same time as `add_generation_prompt`."
),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)
chat_template: str | None = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."
),
)
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
default=None,
description=(
"Additional kwargs to pass to the media IO connectors, "
"keyed by modality. Merged with engine-level media_io_kwargs."
),
)
use_activation: bool | None = Field(
default=None,
description="Whether to use activation for the pooler outputs. "
"`None` uses the pooler's default, which is `True` in most cases.",
)
请求示例¶
代码示例: examples/pooling/classify/classification_online.py
您可以通过传递字符串数组来分类多个文本
curl -v "http://127.0.0.1:8000/classify" \
-H "Content-Type: application/json" \
-d '{
"model": "jason9693/Qwen2.5-1.5B-apeach",
"input": [
"Loved the new café—coffee was great.",
"This update broke everything. Frustrating."
]
}'
响应
{
"id": "classify-7c87cac407b749a6935d8c7ce2a8fba2",
"object": "list",
"created": 1745383065,
"model": "jason9693/Qwen2.5-1.5B-apeach",
"data": [
{
"index": 0,
"label": "Default",
"probs": [
0.565970778465271,
0.4340292513370514
],
"num_classes": 2
},
{
"index": 1,
"label": "Spoiled",
"probs": [
0.26448777318000793,
0.7355121970176697
],
"num_classes": 2
}
],
"usage": {
"prompt_tokens": 20,
"total_tokens": 20,
"completion_tokens": 0,
"prompt_tokens_details": null
}
}
您也可以直接向 input 字段传递字符串
curl -v "http://127.0.0.1:8000/classify" \
-H "Content-Type: application/json" \
-d '{
"model": "jason9693/Qwen2.5-1.5B-apeach",
"input": "Loved the new café—coffee was great."
}'
响应
{
"id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
"object": "list",
"created": 1745383213,
"model": "jason9693/Qwen2.5-1.5B-apeach",
"data": [
{
"index": 0,
"label": "Default",
"probs": [
0.565970778465271,
0.4340292513370514
],
"num_classes": 2
}
],
"usage": {
"prompt_tokens": 10,
"total_tokens": 10,
"completion_tokens": 0,
"prompt_tokens_details": null
}
}
更多示例¶
更多示例见此: examples/pooling/classify
支持的功能¶
启用/禁用激活¶
您可以通过 use_activation 来启用或禁用激活。
问题类型 (如 multi_label_classification)¶
您可以通过 Hugging Face 配置中的 problem_type 修改 problem_type。支持的问题类型有:single_label_classification、multi_label_classification 和 regression。
实现了与 transformers ForSequenceClassificationLoss 的对齐。
Logit 偏置¶
您可以通过 vllm.config.PoolerConfig 中的 logit_bias 参数来修改 logit_bias(也称为 sigmoid_normalize)。
已移除的功能¶
从 PoolingParams 中移除 softmax¶
我们已经从 PoolingParams 中移除了 softmax 和 activation。请改用 use_activation,因为我们允许 classify 和 token_classify 使用任何激活函数。