flatten_for_calibration(
value: Tensor,
base_name: str,
args: QuantizationArgs,
g_idx: Optional[Tensor] = None,
) -> torch.Tensor
Reshapes the value according to the quantization strategy for the purposes of scale/zero-point calibration. The flattened value has the following shape:
`(num_observations, *qparam_shape, group_size)`
The first dim is the number of observations (usually the batch size times the number of tokens), the middle dims are the dimensions of the scales, and the last dim is the number of elements quantized per group.
Parameters

- value (Tensor) – the value being flattened
- base_name (str) – weight, input, output, q/k/v. Used to characterize the value as a weight, activation, or attention state
- args (QuantizationArgs) – quantization args that determine how the value is flattened
- g_idx (Optional[Tensor], default: None) – optional g_idx for weight activation ordering

Returns

- the value reshaped for calibration
Source code in llmcompressor/observers/helpers.py
def flatten_for_calibration(
    value: torch.Tensor,
    base_name: str,
    args: QuantizationArgs,
    g_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Reshapes the value according to the quantization strategy for the purposes of
    scale/zp calibration. The value after flattening has the following shape:
    `(num_observations, *qparam_shape, group_size)`
    The first dim is the number of observations (usually the batch size times number of
    tokens), the middle dims are the dimension of the scales, and the last dim is the
    number of elements being quantized per group.

    :param value: value being flattened
    :param base_name: weight, input, output, q/k/v. Used to characterize the value as
        being a weight, activation, or attention state
    :param args: quantization args for determining how the value is flattened
    :param g_idx: optional g_idx for weight activation ordering
    :return: value which has been reshaped for calibration
    """
    if base_name == "weight":
        return _flatten_weight(value, args, g_idx)
    elif base_name in ("input", "output"):
        return _flatten_activation(value, args)
    elif base_name in ("q", "k", "v"):
        return _flatten_attention(value, args)
    else:
        raise ValueError(f"Unknown quantization base name: {base_name}")
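Example

A minimal usage sketch. The tensor shapes and QuantizationArgs fields below are illustrative assumptions, as is the expected output shape, which follows from the shape contract described above rather than from a verified run; it assumes QuantizationArgs is the compressed-tensors model and that group quantization with group_size=128 is configured.

import torch
from compressed_tensors.quantization import QuantizationArgs
from llmcompressor.observers.helpers import flatten_for_calibration

# Hypothetical 4-bit group quantization: one scale per group of 128 columns.
args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)
weight = torch.randn(512, 1024)  # (out_features, in_features)

flat = flatten_for_calibration(weight, base_name="weight", args=args)

# A weight is a single observation, and each (row, group) pair gets its own
# scale, so qparam_shape is (512, 8) and the last dim is the group size.
print(flat.shape)  # expected per the contract above: torch.Size([1, 512, 8, 128])

For activations ("input"/"output") the first dim instead collects the observed tokens, so the observer can reduce over it when computing scales and zero points.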