跳到内容

llmcompressor.modifiers.pruning.wanda

模块

WandaPruningModifier

基类:SparsityModifierBase

用于对模型应用一次性 WANDA 算法的修饰符,详见论文:https://arxiv.org/abs/2306.11695

示例 yaml

test_stage:
  sparsity_modifiers:
    WandaPruningModifier:
      sparsity: 0.5
      mask_structure: "2:4"

生命周期

  • on_initialize
    • register_hook(module, calibrate_module, "forward")
    • run_sequential / run_basic
      • make_empty_row_scalars
      • accumulate_row_scalars
  • on_sequential_batch_end
    • sparsify_weight
  • on_finalize
    • remove_hooks()

参数

  • 稀疏度

    模型压缩到的稀疏度

  • 稀疏度配置文件

    可以设置为“owl”以使用离群值加权分层稀疏度(OWL),更多信息可在论文 https://arxiv.org/pdf/2310.05175 中找到

  • 掩码结构

    定义要应用的掩码结构的字符串。必须是 N:M 形式,其中 N、M 是定义自定义块形状的整数。默认为 0:0,表示非结构化掩码。

  • owl_m

    OWL 中使用的离群值数量

  • owl_lmbda

    OWL 中使用的 Lambda 值

  • sequential_targets

    OBCQ 期间要压缩的层名称列表,或 'ALL' 表示压缩模型中的所有层。targets 的别名

  • targets

    OBCQ 期间要压缩的层名称列表,或 'ALL' 表示压缩模型中的所有层。sequential_targets 的别名

  • ignore

    可选的模块类名称或子模块名称列表,即使它们与目标匹配,也不进行量化。默认为空列表。

方法

calibrate_module

calibrate_module(
    module: Module,
    args: Tuple[Tensor, ...],
    _output: Tensor,
)

用于累积模块输入行标量的校准钩子

参数

  • module

    (Module) –

    正在校准的模块

  • args

    (Tuple[Tensor, ...]) –

    模块的输入,其中第一个元素是规范输入

  • _输出

    (Tensor) –

    未压缩的模块输出,未使用

源代码位于 llmcompressor/modifiers/pruning/wanda/base.py
def calibrate_module(
    self,
    module: torch.nn.Module,
    args: Tuple[torch.Tensor, ...],
    _output: torch.Tensor,
):
    """
    Calibration hook used to accumulate the row scalars of the input to the module

    :param module: module being calibrated
    :param args: inputs to the module, the first element of which is the
        canonical input
    :param _output: uncompressed module output, unused
    """
    # Assume that the first argument is the input
    inp = args[0]

    # Initialize row scalars if not present
    if module not in self._num_samples:
        device = get_execution_device(module)
        self._row_scalars[module] = make_empty_row_scalars(module, device=device)
        self._num_samples[module] = 0

    # Accumulate scalars using data
    self._row_scalars[module], self._num_samples[module] = accumulate_row_scalars(
        inp,
        module,
        self._row_scalars[module],
        self._num_samples[module],
    )

compress_modules

compress_modules()

稀疏化已校准的模块

源代码位于 llmcompressor/modifiers/pruning/wanda/base.py
def compress_modules(self):
    """
    Sparsify modules which have been calibrated
    """
    for module in list(self._num_samples.keys()):
        name = self._module_names[module]
        sparsity = self._module_sparsities[module]
        num_samples = self._num_samples[module]

        logger.info(f"Sparsifying {name} using {num_samples} samples")
        with torch.no_grad(), align_module_device(module), CompressionLogger(
            module
        ):
            sparsified_weight = sparsify_weight(
                module=module,
                row_scalars_dict=self._row_scalars,
                sparsity=sparsity,
                prune_n=self._prune_n,
                prune_m=self._prune_m,
            )

        update_offload_parameter(module, "weight", sparsified_weight)

        # self._row_scalars[module] already deleted by sparsify_weight
        del self._num_samples[module]