
llmcompressor.utils.fsdp.helpers

Functions

  • get_fsdp_parent

    Gets the closest parent of layer_name that is wrapped by FSDP. If no FSDP wrapper is found, returns None.

  • is_fsdp_model

    Check if a model instance is wrapped by FSDP.

  • maybe_get_wrapped

    Given a model that may or may not have a distributed wrapper, return the underlying wrapped model.

  • set_wrapped_model

    Given a state whose model may or may not have a distributed wrapper, set the underlying wrapped model.

get_fsdp_parent

get_fsdp_parent(
    layer_name: str, model: Module
) -> Optional[Module]

Gets the closest parent of layer_name that is wrapped by FSDP. If no FSDP wrapper is found, returns None.

Parameters

  • layer_name

    (str) –

    layer name in the model to get the parent of

  • model

    (Module) –

    pytorch module to search through

Returns

  • Optional[Module]

    FSDP wrapped parent of layer_name if available, otherwise None.

Source code in llmcompressor/utils/fsdp/helpers.py
def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]:
    """
    Gets the closest parent of layer_name that is wrapped by FSDP. If no FSDP wrapper
    is found just return None

    :param layer_name: layer name in model to get parent of
    :param model: pytorch module to search through
    :return: FSDP wrapped parent of layer_name if available, otherwise None
    """
    if not is_fsdp_model(model):
        return None

    parent_name = layer_name
    parent = operator.attrgetter(parent_name)(model)
    while not isinstance(parent, FullyShardedDataParallel):
        if len(parent_name) == 0:  # we've reached the root module and it's not FSDP
            # this should never get hit because we check for an FSDP root above
            # but while statements without a backup are too scary
            return None
        parent_name = ".".join(parent_name.split(".")[:-1])
        parent = operator.attrgetter(parent_name)(model)

    return parent
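
A minimal usage sketch (not taken from the library's docs). Because get_fsdp_parent first checks is_fsdp_model, calling it on an unwrapped module simply returns None, so the snippet below runs without an initialized process group; resolving a layer inside a genuinely FSDP-wrapped model requires torch.distributed to be initialized before wrapping.

from torch import nn

from llmcompressor.utils.fsdp.helpers import get_fsdp_parent

# Unwrapped model: the helper short-circuits and returns None
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
print(get_fsdp_parent("0", model))  # None, since the model is not FSDP wrapped

# For an FSDP-wrapped model, the same call walks up the dotted layer name
# ("model.layers.0.self_attn" -> "model.layers.0" -> ...) until it reaches a
# FullyShardedDataParallel instance and returns that wrapper.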

is_fsdp_model

is_fsdp_model(model: Module) -> bool

Check if a model instance is wrapped by FSDP.

Parameters

  • model

    (Module) –

    pytorch model to check

Returns

  • bool

    True if the module is wrapped, False otherwise.

Source code in llmcompressor/utils/fsdp/helpers.py
def is_fsdp_model(model: Module) -> bool:
    """
    Check if a model instance is wrapped by FSDP

    :param model: pytorch model to check
    :return: True if module is wrapped, False otherwise
    """
    if not FullyShardedDataParallel:
        return False

    return isinstance(model, FullyShardedDataParallel)
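
A short sketch of both outcomes, assuming the installed torch build provides FSDP. The True case requires wrapping the model inside an initialized distributed process group, so it is only indicated in a comment.

from torch import nn

from llmcompressor.utils.fsdp.helpers import is_fsdp_model

model = nn.Linear(8, 8)
print(is_fsdp_model(model))  # False: a bare nn.Module is not FSDP wrapped

# After wrapping, e.g. fsdp_model = FullyShardedDataParallel(model) within an
# initialized process group, is_fsdp_model(fsdp_model) would return True.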

maybe_get_wrapped

maybe_get_wrapped(model: Module) -> Module

Given a model that may or may not have a distributed wrapper, return the underlying wrapped model.

Parameters

  • model

    (Module) –

    input model to get the wrapped model from

Returns

  • Module

    The wrapped model.

Source code in llmcompressor/utils/fsdp/helpers.py
def maybe_get_wrapped(model: Module) -> Module:
    """
    Given a model that may or may not have a distributed wrapper, return the underlying
    wrapped model.

    :param model: input model to get wrapped model from
    :returns: wrapped model
    """
    if is_fsdp_model(model=model):
        return model._fsdp_wrapped_module
    return model
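
A minimal sketch of the behavior: an unwrapped model is returned unchanged, while an FSDP-wrapped model yields the module stored under its _fsdp_wrapped_module attribute.

from torch import nn

from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped

model = nn.Linear(8, 8)
print(maybe_get_wrapped(model) is model)  # True: no wrapper, the model itself

# With an FSDP-wrapped model, maybe_get_wrapped(fsdp_model) would instead
# return fsdp_model._fsdp_wrapped_module, i.e. the module FSDP wraps.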

set_wrapped_model

set_wrapped_model(state: State, wrapped_model: Module)

Given a state whose model may or may not have a distributed wrapper, set the underlying wrapped model.

Parameters

  • state

    (State) –

    state to update the model of

  • wrapped_model

    (Module) –

    model to inject into the state's model

Source code in llmcompressor/utils/fsdp/helpers.py
def set_wrapped_model(state: State, wrapped_model: Module):
    """
    Given a state with a model that may or may not have a distributed wrapper, set
    the underlying wrapped model.

    :param state: state to update the model of
    :param wrapped_model: model to inject into the state's model
    """
    if is_fsdp_model(state.model):
        state.model._fsdp_wrapped_module = wrapped_model
    else:
        state.model = wrapped_model
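
A usage sketch. A real State comes from llmcompressor's compression session; the SimpleNamespace below is a stand-in used purely for illustration (that substitution is an assumption, not part of the library's API).

from types import SimpleNamespace

from torch import nn

from llmcompressor.utils.fsdp.helpers import set_wrapped_model

state = SimpleNamespace(model=nn.Linear(8, 8))  # stand-in for an llmcompressor State
new_model = nn.Linear(8, 4)

set_wrapped_model(state, new_model)
print(state.model is new_model)  # True: an unwrapped model is replaced directly

# If state.model were FSDP wrapped, only its _fsdp_wrapped_module would be
# swapped, keeping the existing FSDP wrapper around the new model.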