跳到内容

llmcompressor.modifiers.awq.mappings

  • AWQMapping

    存储要平滑的激活配置的数据类

函数

AWQMapping dataclass

AWQMapping(smooth_layer: str, balance_layers: list[str])

存储要平滑的激活配置的数据类。smooth_layer 的输出激活是 balance_layers 的输入激活。

AWQMappings 会被解析成 ResolvedMappings,后者在运行时保留指向实际 torch.nn.Modules 和附加元数据的指针。

ResolvedMapping dataclass

ResolvedMapping(
    smooth_name: str,
    smooth_layer: Module,
    balance_layers: list[Module],
    balance_names: list[str],
    parent: Module,
    parent_name: str,
)

用于存储激活层与在平滑期间必须平衡的后续权重之间的已解析映射的数据类。

参数

  • smooth_name

    (str) –

    激活层的名称

  • smooth_layer

    (Module) –

    存储激活层的 PyTorch 模块

  • balance_layers

    (list[Module]) –

    smooth_layer 输入到的 PyTorch 模块列表,必须进行平衡以抵消 smooth_layer 的平滑

  • balance_names

    (list[str]) –

    可选的 balance_layers 名称列表。

  • parent

    (Module) –

    balance_layers 的父模块。

  • parent_name

    (str) –

    父模块的名称。

get_layer_mappings_from_architecture

get_layer_mappings_from_architecture(
    architecture: str,
) -> list[AWQMapping]

参数

  • architecture

    (str) –

    str: 模型的架构

返回

  • list[AWQMapping]

    list: 给定架构的层映射

llmcompressor/modifiers/awq/mappings.py 中的源代码。
def get_layer_mappings_from_architecture(architecture: str) -> list[AWQMapping]:
    """
    :param architecture: str: The architecture of the model
    :return: list: The layer mappings for the given architecture
    """

    if architecture not in AWQ_MAPPING_REGISTRY:
        logger.info(
            f"Architecture {architecture} not found in mappings. "
            f"Using default mappings: {_default_mappings}"
        )

    return AWQ_MAPPING_REGISTRY.get(architecture, _default_mappings)