多Token预测 (MTP)#
为什么我们需要 MTP#
MTP通过并行预测多个Token来提升推理性能,将生成方式从单Token转向多Token生成。这种方法显著提高了生成吞吐量,实现了推理速度的倍数级加速,同时不牺牲输出质量。
如何使用 MTP#
要为DeepSeek-V3模型启用MTP,请在启动服务时添加以下参数:
–speculative_config ‘ {“method”: “mtp”, “num_speculative_tokens”: 1, “disable_padded_drafter_batch”: False} ‘
num_speculative_tokens: 如果提供,则为模型一次预测多个Token的数量。如果草稿模型配置中存在,则默认为该值;否则,此参数是必需的。disable_padded_drafter_batch: 禁用用于推测解码的输入填充。如果设置为True,推测输入批次可以包含不同长度的序列,这可能仅被某些注意力后端支持。目前这仅影响推测的MTP方法,默认为False。
工作原理#
模块架构#
vllm_ascend
├── sample
│ ├── rejection_sample.py
├── spec_decode
│ ├── mtp_proposer.py
└───────────
1. 采样
rejection_sample.py: 在解码过程中,主模型将前一轮的输出Token与预测的Token一起处理(同时计算1+k个Token)。第一个Token总是正确的,而第二个Token—称为奖励Token—是不确定的,因为它来自推测性预测。因此,我们采用贪婪策略和拒绝采样策略来决定是否接受奖励Token。模块结构包含一个
AscendRejectionSampler类,其forward方法实现了具体的采样逻辑。
rejection_sample.py
├── AscendRejectionSampler
│ ├── forward
2. spec_decode
此部分包括用于spec-decode的模型预处理,主要结构如下:包括加载模型、执行一次模拟运行以及生成Token ID。这些步骤共同构成了单个spec-decode操作的模型数据构建和前向调用。
mtp_proposer.py: 配置vLLM-Ascend使用推测性解码,其中草案由deepseek mtp层生成。
mtp_proposer.py
├── Proposer
│ ├── load_model
│ ├── dummy_run
│ ├── generate_token_ids
│ ├── _prepare_inputs
│ ├── _propose
算法#
1. 拒绝采样
贪婪策略
验证主模型生成的Token是否与MTP在前一轮预测的推测Token匹配。如果完全匹配,则接受奖励Token;否则,拒绝该Token及其后续基于该推测的所有Token。
拒绝采样策略
该方法在拒绝采样中引入了随机性。
对于每个草案Token,通过验证不等式P_target / P_draft ≥ U是否成立来确定是否接受,其中P_target表示目标模型为当前草案Token分配的概率,P_draft表示草案模型分配的概率,而U是从区间[0, 1)均匀采样的随机数。
每个草案Token的决策逻辑如下:如果不等式P_target / P_draft ≥ U成立,则接受该草案Token作为输出;反之,如果P_target / P_draft < U,则拒绝该草案Token。
当草案Token被拒绝时,将触发一个恢复采样过程,从调整后的概率分布Q = max(P_target - P_draft, 0)中重新采样一个“恢复Token”。在当前的MTP实现中,由于P_draft未提供且默认为1,公式简化为当P_target ≥ U时接受Token,恢复分布变为Q = max(P_target - 1, 0)。
2. 性能
如果奖励Token被接受,MTP模型将进行(num_speculative +1)个Token的推理,包括原始主模型输出Token和奖励Token。如果被拒绝,则根据接受的Token数量进行较少Token的推理。
DFX#
方法验证#
目前,spec_decode场景仅支持ngram、eagle、eagle3和mtp等方法。如果为方法传递了不正确的参数,代码将引发错误以警告用户提供了错误的方法。
def get_spec_decode_method(method,
vllm_config,
device,
runner):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
elif method in ["eagle", "eagle3"]:
return EagleProposer(vllm_config, device, runner)
elif method == 'mtp':
return MtpProposer(vllm_config, device, runner)
else:
raise ValueError("Unknown speculative decoding method: "
f"{method}")
整数验证#
当前的npu_fused_infer_attention_score算子每解码轮次仅支持小于16的整数。因此,MTP的最大支持值为15。如果提供了大于15的值,代码将引发错误并警告用户。
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
限制#
由于DeepSeek的MTP仅暴露单层权重,当MTP > 1(尤其是MTP ≥ 3)的情况下,精度和性能未得到有效保证。此外,由于当前的算子限制,MTP最多支持15。
在fullgraph模式下,当MTP > 1时,每个aclgraph的捕获大小必须是(num_speculative_tokens + 1)的整数倍。