llmcompressor.pipelines.sequential.transformers_helpers

  • HFCacheProxy

    Proxy that represents an instance of transformers.cache_utils.Cache.

  • HFProxy

    Proxy that uses metadata to handle data-dependent control flow.

  • HFProxyableClassMeta

    Metaclass used to create a class with its main methods wrapped to be proxyable.

  • HFTracer

    Tracer that is able to symbolically trace models from the library. To do so, it uses HFProxy instead of the regular PyTorch torch.fx.Proxy.

Functions

  • gen_constructor_wrapper

    Wraps target to be proxyable. Used for tensor creators like torch.ones, torch.arange and so on.

  • symbolic_trace

    Performs symbolic tracing on the model.

HFCacheProxy

Bases: HFProxy

Proxy that represents an instance of transformers.cache_utils.Cache.

HFProxy

Bases: Proxy

Proxy that uses metadata to handle data-dependent control flow.
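
The metadata matters because a plain torch.fx proxy carries no concrete values, so a branch that depends on tensor data cannot be resolved during tracing. A minimal illustration of the problem, using vanilla torch.fx rather than HFProxy (the function f is invented for the example):

```python
import torch.fx as fx

def f(x):
    # Data-dependent control flow: the branch depends on the tensor's values.
    if x.sum() > 0:
        return x + 1
    return x - 1

try:
    fx.symbolic_trace(f)
except fx.proxy.TraceError as err:
    # A vanilla Proxy cannot be converted to bool, so tracing aborts here.
    # HFProxy instead carries metadata (derived from the meta-tensor inputs
    # set up in trace() below) that lets such checks be resolved.
    print("plain torch.fx fails:", err)
```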

HFProxyableClassMeta

Bases: type

Metaclass used to create a class with its main methods wrapped to be proxyable.
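
As a rough, generic sketch of the metaclass pattern involved (not the actual implementation; WrappingMeta and Example are invented names), a metaclass can rewrite a class's public methods while the class is being created:

```python
import functools

class WrappingMeta(type):
    def __new__(mcs, name, bases, namespace):
        # Wrap every public callable defined on the class being created.
        for attr, value in list(namespace.items()):
            if callable(value) and not attr.startswith("__"):
                namespace[attr] = mcs._wrap(value)
        return super().__new__(mcs, name, bases, namespace)

    @staticmethod
    def _wrap(fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            # A proxyable wrapper would record this call on the active tracer;
            # this sketch simply forwards to the original method.
            return fn(*args, **kwargs)
        return wrapped

class Example(metaclass=WrappingMeta):
    def update(self, value):
        return value

assert Example().update(3) == 3
```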

HFTracer

HFTracer(autowrap_modules=(math,), autowrap_functions=())

Bases: Tracer

Tracer that is able to symbolically trace models from the library. To do so, it uses HFProxy instead of the regular PyTorch torch.fx.Proxy.

Methods

  • keys

    Called when a proxy object has the keys() method called.

  • path_of_module

    Helper method to find the qualified name of mod in the Module hierarchy of root. For example, if root has a submodule named foo, which has a submodule named bar, passing bar into this function will return the string "foo.bar".

  • trace

    Traces root and returns the corresponding FX torch.fx.Graph representation. root can either be a torch.nn.Module instance or a Python callable.

Source code in llmcompressor/pipelines/sequential/transformers_helpers.py

```python
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
    super().__init__(
        autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions
    )
```

keys

keys(obj: Proxy) -> Any

Called when a proxy object has the keys() method called. This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in your custom tracer.

Source code in llmcompressor/pipelines/sequential/transformers_helpers.py

```python
@compatibility(is_backward_compatible=True)
def keys(self, obj: "Proxy") -> Any:
    """Called when a proxy object is has the keys() method called.
    This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in
    your custom tracer.
    """
    attribute = HFAttribute(obj, "keys")()
    if obj.node.target.startswith("**"):
        return attribute._metadata
    return attribute
```

path_of_module

path_of_module(mod: Module) -> str

Helper method to find the qualified name of mod in the Module hierarchy of root. For example, if root has a submodule named foo, which has a submodule named bar, passing bar into this function will return the string "foo.bar".

Parameters:

  • mod (str): The Module to retrieve the qualified name for.
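
To make the foo.bar example concrete, here is a small self-contained sketch built on the base torch.fx.Tracer, whose behaviour this override preserves for modules that are reachable from root (the Root, Foo and Bar classes are invented for illustration):

```python
import torch.fx as fx
import torch.nn as nn

class Bar(nn.Module):
    def forward(self, x):
        return x + 1

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.bar = Bar()

    def forward(self, x):
        return self.bar(x)

class Root(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = Foo()

    def forward(self, x):
        return self.foo(x)

root = Root()
tracer = fx.Tracer()
tracer.trace(root)                           # tracing records `root` on the tracer
print(tracer.path_of_module(root.foo.bar))   # prints "foo.bar"
```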

Source code in llmcompressor/pipelines/sequential/transformers_helpers.py

```python
def path_of_module(self, mod: nn.Module) -> str:
    """
    Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has
    a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the
    string "foo.bar".

    Args:
        mod (str): The `Module` to retrieve the qualified name for.
    """
    try:
        return super().path_of_module(mod)
    except NameError as e:
        if (
            self.allow_insert_stateless_mods
            and len(list(mod.parameters())) == 0
            and len(list(mod.buffers())) == 0
        ):
            path = self._insert_module_as_submodule(mod)
            return path
        raise e
```

trace

trace(
    root: Module | Callable[..., Any],
    concrete_args: dict[str, Any] | None = None,
    dummy_inputs: dict[str, Any] | None = None,
    complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph

Traces root and returns the corresponding FX torch.fx.Graph representation. root can either be a torch.nn.Module instance or a Python callable. Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), we will create a torch.nn.Module instance to use as the root and add embedded constants to it.

Parameters:

  • root (torch.nn.Module or Callable): Either a torch.nn.Module or a function to be traced through. If root is not a [~transformers.PreTrainedModel], then dummy_inputs must be passed, otherwise tracing will fail.

  • concrete_args (dict[str, Any], optional): Concrete arguments that should not be treated as Proxies.

  • dummy_inputs (dict[str, Any], optional): The dummy inputs needed to handle data-dependent control flow if root is not a [~transformers.PreTrainedModel]. It can also be used when root is a [~transformers.PreTrainedModel] to specify custom dummy inputs for a subset or all of the model inputs.

  • complete_concrete_args_with_inputs_not_in_dummy_inputs (bool, optional, defaults to True): If True, and dummy_inputs is specified, every argument that root can take that is not in dummy_inputs and not in concrete_args will be added to concrete_args; otherwise does nothing.

Returns: torch.fx.Graph: An FX torch.fx.Graph representing the semantics of the passed-in root.
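
A minimal usage sketch, assuming model is a transformers.PreTrainedModel that has already been loaded elsewhere and that a (1, 8) input_ids tensor is an acceptable dummy input for it:

```python
import torch
from llmcompressor.pipelines.sequential.transformers_helpers import HFTracer

# `model` is assumed to be a transformers.PreTrainedModel loaded elsewhere.
tracer = HFTracer()
graph = tracer.trace(
    model,
    dummy_inputs={"input_ids": torch.ones(1, 8, dtype=torch.long)},
)
# Wrap the traced graph into an executable GraphModule, as symbolic_trace does below.
graph_module = torch.fx.GraphModule(model, graph)
print(graph_module.graph)
```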

Source code in llmcompressor/pipelines/sequential/transformers_helpers.py

```python
def trace(
    self,
    root: torch.nn.Module | Callable[..., Any],
    concrete_args: dict[str, Any] | None = None,
    dummy_inputs: dict[str, Any] | None = None,
    complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph:
    """
    Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
    `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
    the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
    `torch.nn.Module` instance to use as the root and add embedded constants to.

    Args:
        root (`torch.nn.Module` or  `Callable`):
            Either a `torch.nn.Module`` or a function to be traced through. If root is not a
            [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
        concrete_args (`dict[str, Any], *optional*):
            Concrete arguments that should not be treated as Proxies
        dummy_inputs (`dict[str, Any]`, *optional*):
            The dummy inputs needed to handle data-dependent control-flow if `root` is not a
            [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
            [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
        complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
            If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
            `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.

    Returns:
        `torch.fx.Graph`:
            A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.

    """
    sig = inspect.signature(
        root.forward if isinstance(root, torch.nn.Module) else root
    )

    if concrete_args is None:
        concrete_args = {}

    if (
        dummy_inputs is not None
        and complete_concrete_args_with_inputs_not_in_dummy_inputs
    ):
        for param in sig.parameters.values():
            if param.name in dummy_inputs:
                continue
            if param.default is inspect.Parameter.empty:
                raise ValueError(
                    f"You need to specify a default value for the parameter {param.name}."
                )
        concrete_args.update(
            {
                p.name: p.default
                for p in sig.parameters.values()
                if (p.name not in dummy_inputs and p.name not in concrete_args)
            }
        )

    input_names = sig.parameters.keys() - concrete_args.keys()

    # Creating a random input shape to generate dummy inputs.
    batch_size = _generate_random_int()
    sequence_length = _generate_random_int()
    shape = [batch_size, sequence_length]

    if root.__class__.__name__ in get_values(
        MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
    ):
        num_choices = _generate_random_int(low=2, high=5)
        shape.insert(1, num_choices)

    inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
    for input_name in input_names:
        if input_name in inputs:
            continue
        # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
        # be able to use HFTracer._generate_dummy_input.
        if isinstance(root, self.supported_archs) or type(
            root
        ).__qualname__.startswith(("_deserialize_graph_module", "_CodeOnlyModule")):
            inputs.update(
                self._generate_dummy_input(
                    root, input_name, shape, input_names=input_names
                )
            )
        else:
            raise RuntimeError(
                f"Could not generate input named {input_name} for because root is not a"
                " transformers.PreTrainedModel."
            )

    def to_meta(value):
        if isinstance(value, torch.Tensor):
            return value.to("meta")
        return value

    concrete_metas = pytree.tree_map(to_meta, inputs)

    for param in sig.parameters.values():
        if (
            param.kind == inspect.Parameter.VAR_KEYWORD
            and param.name not in input_names
        ):
            concrete_metas[f"**{param.name}"] = {}
    self.meta_args = concrete_metas

    global _CURRENT_TRACER
    _CURRENT_TRACER = self
    with self.patch_for_tracing(root):
        try:
            self.graph = super().trace(root, concrete_args=concrete_args)
        finally:
            _CURRENT_TRACER = None

    # This is necessary because concrete args are added as input to the traced module since
    # https://github.com/pytorch/pytorch/pull/55888.
    for node in self.graph.nodes:
        if node.op == "placeholder":
            # Removing default values for inputs as the forward pass will fail with them.
            if node.target in input_names:
                node.args = ()
                # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                # It cannot infer on the attributes and methods the input should have, and fails.
                node.type = torch.Tensor
            # It is a concrete arg so it is not used and should be removed.
            else:
                to_visit = [node]
                to_delete = collections.OrderedDict()
                while to_visit:
                    n = to_visit.pop(0)
                    to_delete[n] = None
                    to_visit += list(n.users.keys())

                for user in reversed(to_delete.keys()):
                    self.graph.erase_node(user)

        # TODO: solves GraphModule creation.
        # Without this, return type annotation "Tuple" is causing code execution failure.
        if node.op == "output":
            node.type = None

    return self.graph
```

gen_constructor_wrapper

gen_constructor_wrapper(
    target: Callable,
) -> tuple[Callable, Callable]

Wraps target to be proxyable. Used for tensor creators like torch.ones, torch.arange and so on.
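
A sketch of how the returned pair is typically used; the patch-and-restore shown here is illustrative, since the tracer normally manages this itself while tracing:

```python
import torch

from llmcompressor.pipelines.sequential.transformers_helpers import gen_constructor_wrapper

# While the wrapper is installed, calls such as torch.ones(...) made inside the
# traced model can be recorded as graph nodes instead of producing eager tensors.
wrapper, original = gen_constructor_wrapper(torch.ones)
torch.ones = wrapper
try:
    pass  # run the tracer here
finally:
    torch.ones = original  # always restore the original constructor
```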

Source code in llmcompressor/pipelines/sequential/transformers_helpers.py

```python
def gen_constructor_wrapper(target: Callable) -> tuple[Callable, Callable]:
    """
    Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
    """
    wrapper = create_wrapper(target, "call_function")
    return wrapper, target
```

symbolic_trace

symbolic_trace(
    model: PreTrainedModel,
    input_names: list[str] | None = None,
    disable_check: bool = False,
    tracer_cls: type[HFTracer] = HFTracer,
) -> GraphModule

Performs symbolic tracing on the model.

Parameters:

  • model ([PretrainedModel]): The model to trace.

  • input_names (list[str], optional): The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.

  • disable_check (bool, optional, defaults to False): If True, no check is done before trying to trace the model; this is mostly useful for debugging purposes.

  • tracer_cls (Type[HFTracer], optional, defaults to HFTracer): The tracer class to use for instantiating the tracer. If unset, HFTracer is used instead.

Returns: torch.fx.GraphModule: A GraphModule constructed by recording operations seen while tracing the model.

Example

```python
from transformers.utils.fx import symbolic_trace

traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
```
Source code in llmcompressor/pipelines/sequential/transformers_helpers.py

```python
def symbolic_trace(
    model: "PreTrainedModel",
    input_names: list[str] | None = None,
    disable_check: bool = False,
    tracer_cls: type[HFTracer] = HFTracer,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PretrainedModel`]):
            The model to trace.
        input_names (`list[str]`, *optional*):
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
        tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
            The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.

    Returns:
        `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example:

        ```python
        from transformers.utils.fx import symbolic_trace

        traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
        ```
    """
    if input_names is None:
        input_names = model.dummy_inputs.keys()

    input_names = list(input_names)
    concrete_args = get_concrete_args(model, input_names)

    if not disable_check:
        check_if_model_is_supported(model)

    if "past_key_values" in input_names and not getattr(
        model.config, "use_cache", False
    ):
        logger.warning(
            "`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to "
            "unexpected behavior."
        )
    if "past_key_values" not in input_names and getattr(
        model.config, "use_cache", False
    ):
        logger.warning(
            "`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting "
            "model.config.use_cache = False."
        )
        model.config.use_cache = False

    # Tracing.
    tracer = tracer_cls()
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, traced_graph)

    traced.config = model.config
    # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
    # _generate_dummy_input, where the model class is needed.
    traced.class_for_deserialization = model.__class__
    traced.device = model.device

    return traced
```