跳到内容

llmcompressor.modifiers.utils.pytorch_helpers

PyTorch 特定的模型压缩辅助函数。

提供 PyTorch 模型操作的实用函数,包括批处理、填充掩码应用和模型架构检测。支持 MoE(专家混合)模型和用于压缩工作流的专用张量操作。

函数

apply_pad_mask_to_batch

apply_pad_mask_to_batch(
    batch: Dict[str, Tensor],
) -> Dict[str, torch.Tensor]

将掩码应用于批次的输入 ID。这用于将填充标记清零,以免它们对 GPTQ 和 SparseGPT 算法中的 Hessian 计算产生贡献。

假设 attention_mask 只包含零和一。

参数

  • batch

    (Dict[str, Tensor]) –

    如果存在,则应用于填充的批次。

返回

  • Dict[str, Tensor]

    input_ids 中填充已清零的批次。

源文件位于 llmcompressor/modifiers/utils/pytorch_helpers.py
def apply_pad_mask_to_batch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Apply a mask to the input ids of a batch. This is used to zero out
    padding tokens so they do not contribute to the hessian calculation in the
    GPTQ and SparseGPT algorithms

    Assumes that `attention_mask` only contains zeros and ones

    :param batch: batch to apply padding to if it exists
    :return: batch with padding zeroed out in the input_ids
    """
    if "attention_mask" in batch:
        for key in ("input_ids", "decoder_input_ids"):
            if key in batch:
                batch[key] = batch[key] * batch["attention_mask"]

    return batch

is_moe_model

is_moe_model(model: Module) -> bool

检查模型是否为专家混合模型

参数

  • model

    (Module) –

    要检查的模型

返回

  • bool

    如果模型是专家混合模型,则为 True。

源文件位于 llmcompressor/modifiers/utils/pytorch_helpers.py
def is_moe_model(model: Module) -> bool:
    """
    Check if the model is a mixture of experts model

    :param model: the model to check
    :return: True if the model is a mixture of experts model
    """

    # Check for MoE components
    for _, module in model.named_modules():
        module_name = module.__class__.__name__
        if "MoE" in module_name or "Expert" in module_name:
            return True

    # Check config for MoE attributes
    if hasattr(model, "config"):
        if any(
            "moe" in attr.lower() or "expert" in attr.lower()
            for attr in dir(model.config)
        ):
            return True

    return False