多Token预测 (MTP)#

为什么我们需要 MTP#

MTP通过并行预测多个Token来提升推理性能,将生成方式从单Token转向多Token生成。这种方法显著提高了生成吞吐量,实现了推理速度的倍数级加速,同时不牺牲输出质量。

如何使用 MTP#

要为DeepSeek-V3模型启用MTP,请在启动服务时添加以下参数:

–speculative_config ‘ {“method”: “mtp”, “num_speculative_tokens”: 1, “disable_padded_drafter_batch”: False} ‘

  • num_speculative_tokens: 如果提供,则为模型一次预测多个Token的数量。如果草稿模型配置中存在,则默认为该值;否则,此参数是必需的。

  • disable_padded_drafter_batch: 禁用用于推测解码的输入填充。如果设置为True,推测输入批次可以包含不同长度的序列,这可能仅被某些注意力后端支持。目前这仅影响推测的MTP方法,默认为False。

工作原理#

模块架构#

vllm_ascend
├── sample
│   ├── rejection_sample.py
├── spec_decode
│   ├── mtp_proposer.py
└───────────

1. 采样

  • rejection_sample.py: 在解码过程中,主模型将前一轮的输出Token与预测的Token一起处理(同时计算1+k个Token)。第一个Token总是正确的,而第二个Token—称为奖励Token—是不确定的,因为它来自推测性预测。因此,我们采用贪婪策略拒绝采样策略来决定是否接受奖励Token。模块结构包含一个AscendRejectionSampler类,其forward方法实现了具体的采样逻辑。

rejection_sample.py
├── AscendRejectionSampler
│   ├── forward

2. spec_decode

此部分包括用于spec-decode的模型预处理,主要结构如下:包括加载模型、执行一次模拟运行以及生成Token ID。这些步骤共同构成了单个spec-decode操作的模型数据构建和前向调用。

  • mtp_proposer.py: 配置vLLM-Ascend使用推测性解码,其中草案由deepseek mtp层生成。

mtp_proposer.py
├── Proposer
│   ├── load_model
│   ├── dummy_run
│   ├── generate_token_ids
│   ├── _prepare_inputs
│   ├── _propose

算法#

1. 拒绝采样

  • 贪婪策略

验证主模型生成的Token是否与MTP在前一轮预测的推测Token匹配。如果完全匹配,则接受奖励Token;否则,拒绝该Token及其后续基于该推测的所有Token。

  • 拒绝采样策略

该方法在拒绝采样中引入了随机性。

对于每个草案Token,通过验证不等式P_target / P_draft U是否成立来确定是否接受,其中P_target表示目标模型为当前草案Token分配的概率,P_draft表示草案模型分配的概率,而U是从区间[0, 1)均匀采样的随机数。

每个草案Token的决策逻辑如下:如果不等式P_target / P_draft U成立,则接受该草案Token作为输出;反之,如果P_target / P_draft < U,则拒绝该草案Token。

当草案Token被拒绝时,将触发一个恢复采样过程,从调整后的概率分布Q = max(P_target - P_draft, 0)中重新采样一个“恢复Token”。在当前的MTP实现中,由于P_draft未提供且默认为1,公式简化为当P_target U时接受Token,恢复分布变为Q = max(P_target - 1, 0)

2. 性能

如果奖励Token被接受,MTP模型将进行(num_speculative +1)个Token的推理,包括原始主模型输出Token和奖励Token。如果被拒绝,则根据接受的Token数量进行较少Token的推理。

DFX#

方法验证#

  • 目前,spec_decode场景仅支持ngram、eagle、eagle3和mtp等方法。如果为方法传递了不正确的参数,代码将引发错误以警告用户提供了错误的方法。

def get_spec_decode_method(method,
                           vllm_config,
                           device,
                           runner):
    if method == "ngram":
        return NgramProposer(vllm_config, device, runner)
    elif method in ["eagle", "eagle3"]:
        return EagleProposer(vllm_config, device, runner)
    elif method == 'mtp':
        return MtpProposer(vllm_config, device, runner)
    else:
        raise ValueError("Unknown speculative decoding method: "
                         f"{method}")

整数验证#

  • 当前的npu_fused_infer_attention_score算子每解码轮次仅支持小于16的整数。因此,MTP的最大支持值为15。如果提供了大于15的值,代码将引发错误并警告用户。

if self.speculative_config:
    spec_token_num = self.speculative_config.num_speculative_tokens
    self.decode_threshold += spec_token_num
    assert self.decode_threshold <= 16, f"decode_threshold exceeded \
        npu_fused_infer_attention_score TND layout's limit of 16, \
        got {self.decode_threshold}"

限制#

  • 由于DeepSeek的MTP仅暴露单层权重,当MTP > 1(尤其是MTP ≥ 3)的情况下,精度和性能未得到有效保证。此外,由于当前的算子限制,MTP最多支持15。

  • 在fullgraph模式下,当MTP > 1时,每个aclgraph的捕获大小必须是(num_speculative_tokens + 1)的整数倍。