跳到内容

llmcompressor.observers.helpers

用于观察者令牌计数和分析的辅助函数。

提供用于分析模型模块中观察者统计信息和令牌计数的实用函数。用于监控压缩效果以及理解量化和剪枝操作期间的模型行为。

函数

flatten_for_calibration

flatten_for_calibration(
    value: Tensor,
    base_name: str,
    args: QuantizationArgs,
    g_idx: Optional[Tensor] = None,
) -> torch.Tensor

根据量化策略重塑值,用于缩放/零点校准。展平后的值具有以下形状:

(num_observations, *qparam_shape, group_size)

第一个维度是观测次数(通常是批次大小乘以令牌数),中间的维度是缩放器的维度,最后一个维度是每组要量化的元素数。

参数

  • (Tensor) –

    要展平的值

  • base_name

    (str) –

    权重、输入、输出、q/k/v。用于将值表征为权重、激活或注意力状态

  • args

    (QuantizationArgs) –

    用于确定值如何展平的量化参数

  • g_idx

    (Optional[Tensor], 默认值: None ) –

    用于权重激活排序的可选 gidx

返回

  • Tensor

    已为校准重塑的值

源代码在 llmcompressor/observers/helpers.py
def flatten_for_calibration(
    value: torch.Tensor,
    base_name: str,
    args: QuantizationArgs,
    g_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Reshapes the value according to the quantization strategy for the purposes of
    scale/zp calibration. The value after flattening has the following shape:

    `(num_observations, *qparam_shape, group_size)`

    The first dim is the number of observations (usually the batch size times number of
    tokens), the middle dims are the dimension of the scales, and the last dim is the
    number of elements being quantized per group.

    :param value: value being flattened
    :param base_name: weight, input, output, q/k/v. Used to characterize the value as
        being a weight, activation, or attention state
    :param args: quantization args for determining how the value is flattened
    :param g_idx: optional gidx for weight activation ordering
    :return: value which has been reshaped for calibration
    """
    if base_name == "weight":
        return _flatten_weight(value, args, g_idx)
    elif base_name in ("input", "output"):
        return _flatten_activation(value, args)
    elif base_name in ("q", "k", "v"):
        return _flatten_attention(value, args)
    else:
        raise ValueError(f"Unknown quantization base name: {base_name}")