
llmcompressor.modeling.moe_context

Simplified interface for MoE model calibration.

MoE (Mixture of Experts) models route tokens to different expert networks. When calibrating for quantization/compression, we need to ensure that all experts see data, not only the ones selected by the router. This module provides the infrastructure for temporarily modifying MoE modules so that calibration is done correctly.
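To make this concrete, here is a minimal sketch (not the library's implementation; router, experts, and top_k are illustrative names) of a calibration-time forward pass in which every expert runs on every token, while the returned activations still follow normal top-k routing:

import torch

def calibration_forward(hidden_states, router, experts, top_k=2):
    # hidden_states: (num_tokens, hidden_dim)
    logits = router(hidden_states)                # (num_tokens, num_experts)
    probs = torch.softmax(logits, dim=-1)
    topk_w, topk_idx = probs.topk(top_k, dim=-1)  # normal top-k routing
    output = torch.zeros_like(hidden_states)
    for i, expert in enumerate(experts):
        expert_out = expert(hidden_states)        # every expert sees every token
        # Routing weight is zero for tokens the router did not send to
        # expert i, so the output matches what normal routing would produce
        token_w = (topk_w * (topk_idx == i)).sum(dim=-1, keepdim=True)
        output = output + token_w * expert_out
    return output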

Key components:

  • MoECalibrationModule: abstract base class for calibration modules
  • moe_calibration_context: context manager that applies calibration to a model

Functions

MoECalibrationModule

Bases: ABC, Module, RegistryMixin

Abstract base class for MoE calibration modules.

Calibration modules replace the original MoE modules during the calibration phase to ensure that all experts receive data, yielding correct quantization statistics.

Subclasses must:

1. Implement __init__() with the signature (self, original, config, calibrate_all_experts=True)
2. Set is_permanent to indicate whether the module should stay in its calibrated form
3. Optionally implement restore() when is_permanent=False

A minimal subclass sketch is shown below.
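The sketch assumes the register decorator provided by RegistryMixin; the class names, the registry key "MyMoEBlock", and the experts attribute are hypothetical:

from llmcompressor.modeling.moe_context import MoECalibrationModule

@MoECalibrationModule.register("MyMoEBlock")  # hypothetical key; must match the original module's class name
class CalibrationMyMoEBlock(MoECalibrationModule):
    is_permanent = False  # swap back to the original module after calibration

    def __init__(self, original, config, calibrate_all_experts=True):
        super().__init__()
        self.original = original
        self.calibrate_all_experts = calibrate_all_experts

    def forward(self, hidden_states):
        routed_out = self.original(hidden_states)  # normal routed output
        if self.calibrate_all_experts:
            # Extra passes so every expert's observers see all tokens
            # (assumes the wrapped module exposes an `experts` ModuleList)
            for expert in self.original.experts:
                expert(hidden_states)
        return routed_out

    def restore(self, original):
        # Required because is_permanent=False
        return original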

Methods

  • restore

    Restore the original module structure.

restore

restore(original: Module) -> torch.nn.Module

Restore the original module structure.

Only needed when is_permanent=False. For permanent modules, this is a no-op.

Returns: the original module (or self if permanent)

Source code in llmcompressor/modeling/moe_context.py
def restore(self, original: torch.nn.Module) -> torch.nn.Module:
    """
    Restore the original module structure.

    Only needed if is_permanent=False. For permanent modules, this is a no-op.

    Returns:
        The original module (or self if permanent)
    """
    if self.is_permanent:
        return self
    raise NotImplementedError(
        f"{self.__class__.__name__} has is_permanent=False but doesn't "
        "implement restore()"
    )

moe_calibration_context

moe_calibration_context(
    model: PreTrainedModel,
    calibrate_all_experts: bool = True,
)

Context manager that applies MoE calibration to a model.

This function scans all modules in the model and replaces any MoE module with its calibration equivalent. After the context exits, non-permanent modules are restored to their original form.

The model is modified in-place, so the same model object should be used inside the context.

Parameters:

  • model: the model to apply MoE calibration to (modified in-place)
  • calibrate_all_experts: if True, all experts see all tokens during calibration; if False, normal routing is used (useful for some techniques)

Example:

with moe_calibration_context(model):
    # Run calibration - all experts will see data from batches in the dataloader
    for batch in dataloader:
        model(**batch)
# Model is now restored (unless permanent)
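For techniques that instead need the model's normal routing statistics, the same context can be entered with all-expert calibration turned off:

# Only the experts the router actually selects see calibration data
with moe_calibration_context(model, calibrate_all_experts=False):
    for batch in dataloader:
        model(**batch)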

Source code in llmcompressor/modeling/moe_context.py
@contextlib.contextmanager
def moe_calibration_context(
    model: PreTrainedModel,
    calibrate_all_experts: bool = True,
):
    """
    Context manager that applies MoE calibration to a model.

    This scans all modules in the model and replaces any MoE modules with their
    calibration equivalents. After the context exits, non-permanent modules are
    restored to their original form.

    The model is modified in-place, so the same model object should be used
    within the context.

    Args:
        model: The model to apply MoE calibration to (modified in-place)
        calibrate_all_experts: If True, all experts see all tokens during calibration.
                               If False, use normal routing (useful for some techniques)

    Example:
        with moe_calibration_context(model):
            # Run calibration - all experts will see data
            for batch in dataloader:
                model(**batch)
        # Model is now restored (unless permanent)
    """

    replaced = {}

    # Step 1: Collect all MoE modules that need replacement
    logger.debug("Entering MoE calibration context")
    modules_to_replace = []
    for name, module in model.named_modules():
        class_name = module.__class__.__name__
        if _is_registered(class_name, MoECalibrationModule):
            modules_to_replace.append((name, module, class_name))

    # Step 2: Replace modules with progress bar
    if modules_to_replace:
        logger.info(f"Found {len(modules_to_replace)} MoE modules to replace")
        for name, module, class_name in tqdm(
            modules_to_replace, desc="Replacing MoE modules for calibration"
        ):
            replacement = MoECalibrationModule.load_from_registry(
                class_name,
                original=module,
                config=model.config,
                calibrate_all_experts=calibrate_all_experts,
            )
            model.set_submodule(name, replacement)
            replaced[name] = (module, replacement)

    # Log what was replaced
    if replaced:
        logger.info(f"Replaced {len(replaced)} MoE modules for calibration")
        permanent_count = sum(
            1 for _, (_, repl) in replaced.items() if repl.is_permanent
        )
        if permanent_count > 0:
            logger.info(
                f"{permanent_count}/{len(replaced)} modules will remain in "
                "calibration form (permanent)"
            )
        if permanent_count < len(replaced):
            logger.info(
                f"{len(replaced) - permanent_count}/{len(replaced)} modules will "
                "be restored after calibration"
            )

    try:
        yield
    finally:
        # Step 3: Restore non-permanent modules
        for name, (original, replacement) in replaced.items():
            if not replacement.is_permanent:
                restored = replacement.restore(original)
                model.set_submodule(name, restored)
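Since replacement happens in-place via set_submodule, the swap is observable on the model object itself. A sketch, where the module path model.model.layers[0].mlp is illustrative and the block's class is assumed to be registered with a non-permanent calibration module:

moe_block = model.model.layers[0].mlp
with moe_calibration_context(model):
    # Inside the context, the MoE block has been replaced
    assert model.model.layers[0].mlp is not moe_block
    for batch in dataloader:
        model(**batch)
# On exit, restore() returned the original module, so identity is preserved
assert model.model.layers[0].mlp is moe_block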