
llmcompressor.transformers.compression.helpers

Functions

infer_sparse_targets_and_ignores

infer_sparse_targets_and_ignores(
    model: Module,
    sparsity_structure: str,
    sparsity_threshold: float,
) -> tuple[list[str], list[str]]

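Infers the target and ignore layers in the given model to be used for sparsity compression.

Parameters

  • model

    (Module) –

    The model to check.

  • sparsity_structure

    (str) –

    The sparsity structure to check against.

  • sparsity_threshold

    (float) –

    The sparsity threshold.

Returns

  • tuple[list[str], list[str]]

    A tuple of (target layers, ignore layers).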

Source code in llmcompressor/transformers/compression/helpers.py
def infer_sparse_targets_and_ignores(
    model: torch.nn.Module,
    sparsity_structure: str,
    sparsity_threshold: float,
) -> tuple[list[str], list[str]]:
    """
    Infers the target and ignore layers in the given model
    to be used for sparsity compression

    :param model: model to check
    :param sparsity_structure: sparsity structure to check against
    :param sparsity_threshold: threshold for sparsity
    :return: tuple of target and ignore layers
    """

    exhaustive_targets, exhaustive_ignore = _get_sparse_targets_ignore_dicts(
        module=model,
        sparsity_structure=sparsity_structure,
        sparsity_threshold=sparsity_threshold,
    )

    return _reduce_targets_and_ignores_into_lists(
        exhaustive_targets=exhaustive_targets,
        exhaustive_ignore=exhaustive_ignore,
    )
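The two private helpers called above are not shown on this page. As a rough, hypothetical sketch of the idea only, the pure-Python snippet below classifies each named layer as a target or an ignore using the same criteria that `is_sparse_compression_target` applies (sparsity at or above the threshold, plus the n:m pattern). It assumes weights are flat lists of floats, that "sparsity" means the fraction of exactly-zero entries, and that the structure is an "n:m" string; `split_targets_and_ignores` is an illustrative name, not part of llmcompressor, and the library's `_reduce_targets_and_ignores_into_lists` compaction step is not reproduced.

```python
from __future__ import annotations


def _zero_fraction(weights: list[float]) -> float:
    # Assumed meaning of "sparsity": fraction of entries that are exactly zero
    return sum(1 for w in weights if w == 0) / len(weights)


def _follows_n_m(weights: list[float], n: int, m: int) -> bool:
    # Every consecutive chunk of m values must contain at least n zeros
    chunks = [weights[i : i + m] for i in range(0, len(weights), m)]
    return all(sum(1 for w in chunk if w == 0) >= n for chunk in chunks)


def split_targets_and_ignores(
    layers: dict[str, list[float]], structure: str, threshold: float
) -> tuple[list[str], list[str]]:
    # Hypothetical helper: one pass over named layers, no reduction step
    n, m = map(int, structure.split(":"))
    targets, ignores = [], []
    for name, weights in layers.items():
        if _zero_fraction(weights) >= threshold and _follows_n_m(weights, n, m):
            targets.append(name)
        else:
            ignores.append(name)
    return targets, ignores


layers = {
    "model.layers.0.q_proj": [0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 4.0],
    "lm_head": [1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 6.0, 7.0],
}
print(split_targets_and_ignores(layers, "2:4", 0.5))
# (['model.layers.0.q_proj'], ['lm_head'])
```

Raising the threshold above 0.5 moves the 2:4-sparse layer into the ignore list as well, because a 2:4 pattern guarantees only 50% zeros.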

infer_sparsity_structure_from_model

infer_sparsity_structure_from_model(
    model: Module,
) -> str | None

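Determines the sparsity structure, if any exists, given the model.

Parameters

  • model

    (Module) –

    The model to check for a sparsity structure.

Returns

  • str | None

    The sparsity structure as a string, or None.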

Source code in llmcompressor/transformers/compression/helpers.py
def infer_sparsity_structure_from_model(model: torch.nn.Module) -> str | None:
    """
    Determines the sparsity structure, if any exists, given the model

    :param model: model to check for sparsity structure
    :return: sparsity structure as a string or None
    """

    # check for the common sparsity structures
    structures = {"2:4"}
    for sparsity_structure in structures:
        linear_modules = get_linear_layers(model)
        offloaded_params = get_state_dict_offloaded_model(model)

        linear_modules_with_sparsity_structure = [
            tensor_follows_mask_structure(offloaded_params[f"{name}.weight"])
            for name in tqdm(
                linear_modules.keys(),
                desc="Checking whether model follows "
                f"{sparsity_structure} sparsity structure",
            )
        ]
        # if the majority of the linear modules follow the sparsity structure
        # we can assume that the model follows the sparsity structure
        # (taking into consideration the fact that some Linear layers like the
        # embedding layer might not be sparse)
        if (
            sum(linear_modules_with_sparsity_structure)
            > len(linear_modules_with_sparsity_structure) * 0.8
        ):
            return sparsity_structure

    return None
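The decision rule above is a strict vote: a structure is reported only when more than 80% of the checked Linear weights pass `tensor_follows_mask_structure`. The minimal sketch below isolates just that rule, with the per-layer results given as booleans; `passes_majority` is a hypothetical name used for illustration.

```python
from __future__ import annotations


def passes_majority(layer_results: list[bool], ratio: float = 0.8) -> bool:
    # Mirrors `sum(flags) > len(flags) * 0.8`; note the strict ">"
    return sum(layer_results) > len(layer_results) * ratio


print(passes_majority([True] * 9 + [False]))      # 9 > 8.0 -> True
print(passes_majority([True] * 8 + [False] * 2))  # 8 > 8.0 -> False
print(passes_majority([]))                        # 0 > 0.0 -> False
```

Because the comparison is strict, a model in which exactly 8 of 10 Linear layers follow the pattern is not detected, and a model with no Linear layers yields None.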

infer_sparsity_structure_from_modifiers

infer_sparsity_structure_from_modifiers(
    modifiers: list[Modifier],
) -> str | None

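Determines the sparsity structure, if any exists, given the list of modifiers.

Parameters

  • modifiers

    (list[Modifier]) –

    List of modifier instances.

Returns

  • str | None

    The sparsity structure as a string, or None.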

Source code in llmcompressor/transformers/compression/helpers.py
def infer_sparsity_structure_from_modifiers(
    modifiers: list[Modifier],  # noqa E501
) -> str | None:
    """
    Determines the sparsity structure, if any exists, given the list of modifiers.

    :param modifiers: List of modifier instances.
    :return: sparsity structure as a string or None.
    """
    for modifier in modifiers:
        if hasattr(modifier, "mask_structure"):
            return modifier.mask_structure
    return None
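The lookup above returns the `mask_structure` of the first modifier that defines that attribute, so list order matters. The sketch below reproduces the same loop against plain objects; the `SimpleNamespace` stand-ins are hypothetical and are not real llmcompressor `Modifier` classes.

```python
from __future__ import annotations

from types import SimpleNamespace


def first_mask_structure(modifiers: list[object]) -> str | None:
    # Same rule as above: the first modifier exposing `mask_structure` wins
    for modifier in modifiers:
        if hasattr(modifier, "mask_structure"):
            return modifier.mask_structure
    return None


# Hypothetical stand-ins, not real llmcompressor modifiers
quant = SimpleNamespace(scheme="W8A8")
prune_24 = SimpleNamespace(mask_structure="2:4")
prune_unstructured = SimpleNamespace(mask_structure="unstructured")

print(first_mask_structure([quant, prune_24, prune_unstructured]))  # 2:4
print(first_mask_structure([quant]))  # None
```

A consequence of using `hasattr` is that a modifier whose `mask_structure` is set to None still ends the search and makes the function return None, even if a later modifier specifies a structure.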

is_sparse_compression_target

is_sparse_compression_target(
    module: Module,
    sparsity_threshold: float,
    sparsity_structure: str,
) -> bool

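Parameters

  • module

    (Module) –

    The module to check.

  • sparsity_threshold

    (float) –

    The sparsity threshold.

  • sparsity_structure

    (str) –

    The sparsity structure to check against.

Returns

  • bool

    Whether the module is a target for sparsity compression, i.e. True if it is sparse and follows the sparsity structure, else False.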

Source code in llmcompressor/transformers/compression/helpers.py
def is_sparse_compression_target(
    module: torch.nn.Module, sparsity_threshold: float, sparsity_structure: str
) -> bool:
    """
    :param module: module to check
    :param sparsity_threshold: threshold for sparsity
    :param sparsity_structure: sparsity structure to check against
    :return: whether or not the module is a target for sparsity compression,
        i.e True if it is sparse and follows the sparsity structure, else False
    """
    with align_module_device(module):
        result = (
            hasattr(module, "weight")
            and tensor_sparsity(module.weight) >= sparsity_threshold
            and tensor_follows_mask_structure(
                tensor=module.weight, mask=sparsity_structure
            )
        )

    return result
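The target test is a conjunction: the module must have a `weight`, its sparsity must meet the threshold, and the weight must follow the mask structure. The pure-Python sketch below restates that conjunction for a flat list of weights, assuming `tensor_sparsity` means the fraction of exactly-zero entries; `looks_like_target` is a hypothetical name, and `None` stands in for a module without a `weight`. It shows why both checks matter: a 50%-sparse weight can still fail the 2:4 pattern.

```python
from __future__ import annotations


def looks_like_target(
    weights: list[float] | None, threshold: float, n: int = 2, m: int = 4
) -> bool:
    if weights is None:  # stands in for hasattr(module, "weight") being False
        return False
    zero_fraction = sum(1 for w in weights if w == 0) / len(weights)
    chunks = [weights[i : i + m] for i in range(0, len(weights), m)]
    follows = all(sum(1 for w in c if w == 0) >= n for c in chunks)
    return zero_fraction >= threshold and follows


semi_structured = [0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 4.0]  # 50% zeros, 2:4
unstructured = [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0]  # 50% zeros, not 2:4

print(looks_like_target(semi_structured, threshold=0.5))  # True
print(looks_like_target(unstructured, threshold=0.5))  # False
print(looks_like_target(semi_structured, threshold=0.6))  # False
```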

tensor_follows_mask_structure

tensor_follows_mask_structure(
    tensor: Tensor, mask: str = "2:4"
) -> bool

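Parameters

  • tensor

    (Tensor) –

    The tensor to check.

  • mask

    (str, default: '2:4' ) –

    The mask structure to check for, in the format "n:m"; "unstructured" is also accepted as a valid mask structure.

Returns

  • bool

    True if the tensor follows the mask structure, False otherwise. Note that some weights can incidentally be zero, so the check requires at least n zeros in each chunk of size m.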

Source code in llmcompressor/transformers/compression/helpers.py
def tensor_follows_mask_structure(tensor: torch.Tensor, mask: str = "2:4") -> bool:
    """
    :param tensor: tensor to check
    :param mask: mask structure to check for, in the format "n:m", also accepts
        "unstructured" as a valid mask structure
    :return: True if the tensor follows the mask structure, False otherwise.
        Note, some weights can incidentally be zero, so we check for
        at least n zeros in each chunk of size m
    """

    if mask.lower().strip() == "unstructured":
        return True

    n, m = tuple(map(int, mask.split(":")))

    # If n or m is 0, then the tensor follows the mask structure
    if n == 0 or m == 0:
        return True
    # Reshape the tensor into chunks of size m
    tensor = tensor.view(-1, m)

    # Count the number of zeros in each chunk
    zero_counts = (tensor == 0).sum(dim=1)

    # Check that each chunk contains at least n zeros.
    # ">=" (rather than "==") is needed because some weights can
    # incidentally be zero
    return torch.all(zero_counts >= n).item()
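For readers without torch at hand, the dependency-free sketch below restates the same chunked check on a flat list; `follows_mask` is a hypothetical name. One difference is deliberate: `tensor.view(-1, m)` cannot reshape a tensor whose element count is not divisible by m, so the sketch raises a `ValueError` explicitly in that case.

```python
from __future__ import annotations


def follows_mask(values: list[float], mask: str = "2:4") -> bool:
    # "unstructured" always passes, matching the function above
    if mask.lower().strip() == "unstructured":
        return True
    n, m = (int(x) for x in mask.split(":"))
    if n == 0 or m == 0:
        return True
    if len(values) % m != 0:
        # torch's view(-1, m) cannot reshape this; make the failure explicit
        raise ValueError(f"length {len(values)} is not divisible by m={m}")
    for start in range(0, len(values), m):
        chunk = values[start : start + m]
        # ">=" rather than "==": extra, incidental zeros are allowed
        if sum(1 for v in chunk if v == 0) < n:
            return False
    return True


print(follows_mask([0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 4.0]))  # True
print(follows_mask([0.0, 0.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0]))  # True (extra zeros)
print(follows_mask([0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 4.0, 5.0]))  # False
```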