llmcompressor.utils.transformers

Functions

  • get_embeddings

    Returns the input and output embeddings of a model. If get_input_embeddings/get_output_embeddings is not implemented on the model, None is returned instead.

  • targets_embeddings

    Returns True if the given targets point to the model's word embeddings

  • untie_word_embeddings

    Unties the word embeddings, if possible. This function issues a warning if the embeddings cannot be found in the model definition.

get_embeddings

get_embeddings(
    model: PreTrainedModel,
) -> tuple[torch.nn.Module | None, torch.nn.Module | None]

Returns the input and output embeddings of a model. If get_input_embeddings/get_output_embeddings is not implemented on the model, None is returned instead.

Parameters

  • model

    (PreTrainedModel) –

    Model to get the embeddings from

Returns

  • tuple[Module | None, Module | None]

    Tuple containing the embedding modules or None

Source code in llmcompressor/utils/transformers.py
def get_embeddings(
    model: PreTrainedModel,
) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
    """
    Returns input and output embeddings of a model. If `get_input_embeddings`/
    `get_output_embeddings` is not implemented on the model, then None will be returned
    instead.

    :param model: model to get embeddings from
    :return: tuple containing embedding modules or None
    """
    try:
        input_embed = model.get_input_embeddings()
    except (AttributeError, NotImplementedError):
        input_embed = None

    try:
        output_embed = model.get_output_embeddings()
    except (AttributeError, NotImplementedError):
        output_embed = None

    return input_embed, output_embed
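
Example

A minimal usage sketch, not part of the library documentation; the "gpt2" checkpoint is an illustrative assumption, and any PreTrainedModel works.

from transformers import AutoModelForCausalLM

from llmcompressor.utils.transformers import get_embeddings

model = AutoModelForCausalLM.from_pretrained("gpt2")

# for GPT-2, the input embedding is a torch.nn.Embedding and the output
# embedding is the LM head, a torch.nn.Linear
input_embed, output_embed = get_embeddings(model)
print(type(input_embed), type(output_embed))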

targets_embeddings

targets_embeddings(
    model: PreTrainedModel,
    targets: NamedModules,
    check_input: bool = True,
    check_output: bool = True,
) -> bool

Returns True if the given targets point to the model's word embeddings

Parameters

  • model

    (PreTrainedModel) –

    Model containing the word embeddings

  • targets

    (NamedModules) –

    Named modules to check

  • check_input

    (bool, default: True) –

    Whether to check if the input embeddings are targeted

  • check_output

    (bool, default: True) –

    Whether to check if the output embeddings are targeted

Returns

  • bool

    True if the embeddings are targeted, False otherwise

Source code in llmcompressor/utils/transformers.py
def targets_embeddings(
    model: PreTrainedModel,
    targets: NamedModules,
    check_input: bool = True,
    check_output: bool = True,
) -> bool:
    """
    Returns True if the given targets target the word embeddings of the model

    :param model: model containing word embeddings
    :param targets: named modules to check
    :param check_input: whether to check if input embeddings are targeted
    :param check_output: whether to check if output embeddings are targeted
    :return: True if embeddings are targeted, False otherwise
    """
    input_embed, output_embed = get_embeddings(model)
    if (check_input and input_embed is None) or (
        check_output and output_embed is None
    ):
        logger.warning(
            "Cannot check embeddings. If this model has word embeddings, please "
            "implement `get_input_embeddings` and `get_output_embeddings`"
        )
        return False

    targets = set(module for _, module in targets)
    return (check_input and input_embed in targets) or (
        check_output and output_embed in targets
    )
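
Example

A minimal usage sketch, not part of the library documentation. Building targets from model.named_modules() is an assumption about the NamedModules type, based on how the function iterates (name, module) pairs; the "gpt2" checkpoint is likewise illustrative.

import torch
from transformers import AutoModelForCausalLM

from llmcompressor.utils.transformers import targets_embeddings

model = AutoModelForCausalLM.from_pretrained("gpt2")

# illustrative target set: every Embedding module in the model
targets = [
    (name, module)
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Embedding)
]

# True: the input embedding module is among the targets
print(targets_embeddings(model, targets))

# False: the output embedding (GPT-2's LM head, a Linear) is not among them
print(targets_embeddings(model, targets, check_input=False))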

untie_word_embeddings

untie_word_embeddings(model: PreTrainedModel)

Unties the word embeddings, if possible. This function issues a warning if the embeddings cannot be found in the model definition.

The model config will be updated to reflect that the embeddings are now untied.

Parameters

  • model

    (PreTrainedModel) –

    Transformers model containing the word embeddings

Source code in llmcompressor/utils/transformers.py
def untie_word_embeddings(model: PreTrainedModel):
    """
    Untie word embeddings, if possible. This function raises a warning if
    embeddings cannot be found in the model definition.

    The model config will be updated to reflect that embeddings are now untied

    :param model: transformers model containing word embeddings
    """
    input_embed, output_embed = get_embeddings(model)
    if input_embed is None or output_embed is None:
        logger.warning(
            "Cannot untie embeddings. If this model has word embeddings, please "
            "implement `get_input_embeddings` and `get_output_embeddings`"
        )
        return

    # clone data to untie
    for module in (input_embed, output_embed):
        if not has_offloaded_params(module):
            data = module.weight.data
        else:
            data = module._hf_hook.weights_map["weight"]

        requires_grad = module.weight.requires_grad
        untied_param = Parameter(data.clone(), requires_grad=requires_grad)
        register_offload_parameter(module, "weight", untied_param)

    # modify model config
    if hasattr(model.config, "tie_word_embeddings"):
        model.config.tie_word_embeddings = False
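
Example

A minimal usage sketch, not part of the library documentation; "gpt2" is an illustrative assumption, chosen because its LM head is tied to the input embeddings by default.

from transformers import AutoModelForCausalLM

from llmcompressor.utils.transformers import get_embeddings, untie_word_embeddings

model = AutoModelForCausalLM.from_pretrained("gpt2")

input_embed, output_embed = get_embeddings(model)
print(input_embed.weight is output_embed.weight)  # True: weights are tied

untie_word_embeddings(model)

input_embed, output_embed = get_embeddings(model)
print(input_embed.weight is output_embed.weight)  # False: each module now owns its own copy
print(model.config.tie_word_embeddings)           # False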