llmcompressor.observers.moving_base

MovingAverageObserverBase

MovingAverageObserverBase(
    base_name: str,
    args: QuantizationArgs,
    module: Optional[Module] = None,
    **observer_kwargs,
)

Bases: Observer

Computes quantization parameters by taking a moving average of the observed min/max values.

Parameters

  • base_name

    (str) –

    String used to name the observer attributes

  • args

    (QuantizationArgs) –

    Quantization arguments used to calibrate and quantize the observed value

  • module

    (Optional[Module], default: None) –

    Optional module with attached quantization parameters. This argument is required in order to use existing qparams such as global_scale or g_idx

  • **observer_kwargs

    Keyword arguments for observer initialization

Methods

Source code in llmcompressor/observers/moving_base.py
def __init__(
    self,
    base_name: str,
    args: QuantizationArgs,
    module: Optional[torch.nn.Module] = None,
    **observer_kwargs,
):
    super().__init__(base_name, args, module, **observer_kwargs)
    self.avg_constant = self.args.observer_kwargs.get("averaging_constant", 0.01)

    self.past_min_vals = None
    self.past_max_vals = None
    self.past_global_min_vals = None
    self.past_global_max_vals = None

get_current_global_min_max abstractmethod

get_current_global_min_max(observed: Tensor) -> MinMaxTuple

Calculates the min and max of the observed value (without moving average) for the purposes of global scale calculation

Source code in llmcompressor/observers/moving_base.py
@abstractmethod
def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
    """
    Calculate the min and max value of the observed value (without moving average)
    for the purposes of global scale calculation
    """
    raise NotImplementedError()

get_current_min_max abstractmethod

get_current_min_max(observed: Tensor) -> MinMaxTuple

Calculates the min and max of the observed value (without moving average)

Source code in llmcompressor/observers/moving_base.py
@abstractmethod
def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
    """
    Calculate the min and max value of the observed value (without moving average)
    """
    raise NotImplementedError()
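
As a rough illustration of what a concrete subclass might compute here, the sketch below reduces an observed value of shape (num_observations, *qparam_shape, group_size) over its first and last dimensions, leaving one min and one max per quantization parameter. It uses nested Python lists with a single qparam dimension to stay dependency-free; the function name `current_min_max` is hypothetical, and real subclasses would presumably operate on `torch.Tensor` inputs instead.

```python
from typing import List, Tuple


def current_min_max(
    observed: List[List[List[float]]],
) -> Tuple[List[float], List[float]]:
    """Hypothetical reduction over the observation and group dimensions.

    observed is indexed as observed[observation][qparam][group_element].
    """
    num_qparams = len(observed[0])
    min_vals: List[float] = []
    max_vals: List[float] = []
    for q in range(num_qparams):
        # Gather every element that belongs to qparam q across all observations.
        values = [v for obs in observed for v in obs[q]]
        min_vals.append(min(values))
        max_vals.append(max(values))
    return min_vals, max_vals


# Two observations, two qparams, group_size of 2.
observed = [
    [[-1.0, 2.0], [0.5, 3.0]],
    [[-4.0, 1.0], [0.0, 2.5]],
]
mins, maxs = current_min_max(observed)
print(mins, maxs)
```

No averaging happens at this stage; the values returned here are what get_min_max later blends with the stored history.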

get_global_min_max

get_global_min_max(observed: Tensor) -> MinMaxTuple

Calculates the moving average of the min and max values of the observed value for the purposes of global scale calculation

Parameters

  • observed

    (Tensor) –

    Value being observed, with shape (num_observations, 1, group_size)

Returns

  • MinMaxTuple

    Minimum and maximum values, each with shape (1, )

Source code in llmcompressor/observers/moving_base.py
def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
    """
    Calculate moving average of min and max values from observed value
    for the purposes of global scale calculation

    :param observed: value being observed whose shape is
        (num_observations, 1, group_size)
    :return: minimum value and maximum value whose shapes are (1, )
    """
    min_vals, max_vals = self.get_current_global_min_max(observed)

    if self.past_global_min_vals is not None and self.avg_constant != 1.0:
        # FUTURE: consider scaling by num observations (first dim)
        #         rather than reducing by first dim
        min_vals = self._lerp(
            self.past_global_min_vals, min_vals, self.avg_constant
        )
        max_vals = self._lerp(
            self.past_global_max_vals, max_vals, self.avg_constant
        )

    self.past_global_min_vals = min_vals
    self.past_global_max_vals = max_vals

    return min_vals, max_vals
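
The update performed above can be written as a recurrence. This assumes that `_lerp(past, current, weight)` is ordinary linear interpolation, which the listing does not show. The symbols below are introduced only for illustration: $\alpha$ stands for `avg_constant`, $\hat{m}_t$ for the current minimum returned by `get_current_global_min_max` at call $t$, and $m_t$ for the value stored in `past_global_min_vals`.

```latex
m_0 = \hat{m}_0, \qquad
m_t = (1 - \alpha)\, m_{t-1} + \alpha\, \hat{m}_t \quad (t \ge 1,\ \alpha \ne 1)
```

The maximum follows the same recurrence. When $\alpha = 1$ the interpolation is skipped and $m_t = \hat{m}_t$, so no history is retained.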

get_min_max

get_min_max(observed: Tensor) -> MinMaxTuple

Calculates the moving average of the min and max values of the observed value

Parameters

  • observed

    (Tensor) –

    Value being observed, with shape (num_observations, *qparam_shape, group_size)

Returns

  • MinMaxTuple

    Minimum and maximum values, each with shape (*qparam_shape, )

Source code in llmcompressor/observers/moving_base.py
def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
    """
    Calculate moving average of min and max values from observed value

    :param observed: value being observed whose shape is
        (num_observations, *qparam_shape, group_size)
    :return: minimum value and maximum value whose shapes are (*qparam_shape, )
    """
    min_vals, max_vals = self.get_current_min_max(observed)

    if self.past_min_vals is not None and self.avg_constant != 1.0:
        # FUTURE: consider scaling by num observations (first dim)
        #         rather than reducing by first dim
        min_vals = self._lerp(self.past_min_vals, min_vals, self.avg_constant)
        max_vals = self._lerp(self.past_max_vals, max_vals, self.avg_constant)

    self.past_min_vals = min_vals
    self.past_max_vals = max_vals

    return min_vals, max_vals
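
The stateful update in get_min_max can be traced with plain floats. The sketch below is a minimal stand-in, not the library's implementation: `MovingValue` and `lerp` are hypothetical names, and `lerp` assumes that `_lerp` performs standard linear interpolation, which the listing does not show.

```python
from typing import Optional


def lerp(past: float, current: float, weight: float) -> float:
    # Assumed semantics of _lerp: standard linear interpolation.
    return past * (1.0 - weight) + current * weight


class MovingValue:
    """Hypothetical scalar analogue of the past_min_vals bookkeeping."""

    def __init__(self, avg_constant: float = 0.01) -> None:
        self.avg_constant = avg_constant
        self.past: Optional[float] = None

    def update(self, current: float) -> float:
        value = current
        # The first call has no history, so the current value is used as-is.
        if self.past is not None and self.avg_constant != 1.0:
            value = lerp(self.past, current, self.avg_constant)
        self.past = value
        return value


tracker = MovingValue(avg_constant=0.5)
first = tracker.update(-2.0)   # no history yet
second = tracker.update(-4.0)  # halfway between -2.0 and -4.0
third = tracker.update(-1.0)   # halfway between -3.0 and -1.0
print(first, second, third)
```

With the default `averaging_constant` of 0.01 each new observation would shift the stored value by only 1% of the gap, so the first calibration batch has a long-lasting influence on the result.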