
llmcompressor.pipelines.sequential.helpers

  • Subgraph

    Dataclass specifying an executable subgraph of a model graph

Functions

SequentialTracer

SequentialTracer(
    ancestors: Set[Module], offloaded: Set[Module]
)

Bases: HFTracer

Get a tracer specialized for the given model. The resulting tracer will not trace inside sequential targets, nor inside any modules which are not call-graph ancestors of sequential targets.

Tracing within sequential targets is unnecessary, and tracing within offloaded modules may result in meta tensors being added to the model graph.

Parameters

  • ancestors

    (Set[Module]) –

    modules which are ancestors of sequential targets

  • offloaded

    (Set[Module]) –

    modules which have offloaded params and should not be traced

Source code in llmcompressor/pipelines/sequential/helpers.py
def __init__(self, ancestors: Set[Module], offloaded: Set[Module]):
    self.ancestors = ancestors
    self.offloaded = offloaded

    # skip any mask creation functions not already caught by the autowrapper
    super().__init__(autowrap_functions=_get_autowrap_functions())

    # check unlikely case that ancestors have direct params which are offloaded
    offloaded_ancestors = offloaded & ancestors
    for ancestor in offloaded_ancestors:
        remove_hook_from_module(ancestor, recurse=False)
        self.offloaded.remove(ancestor)
        logger.warning(
            f"Direct parameters attached to {ancestor.__class__.__name__} have "
            "been onloaded in order to ensure safe graph capture and execution"
        )
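The leaf-module rule described above can be sketched in plain Python. This is a hypothetical illustration, not the actual SequentialTracer leaf check. Modules are represented by name strings. The helper reconcile_offloaded mirrors how the constructor onloads offloaded ancestors. should_trace_inside encodes the assumed rule: descend only into non-offloaded ancestors of sequential targets.

```python
from typing import Set, Tuple


def reconcile_offloaded(
    ancestors: Set[str], offloaded: Set[str]
) -> Tuple[Set[str], Set[str]]:
    # Ancestors with offloaded direct params are onloaded by the constructor,
    # so they drop out of the offloaded set.
    onloaded = offloaded & ancestors
    return ancestors, offloaded - onloaded


def should_trace_inside(module: str, ancestors: Set[str], offloaded: Set[str]) -> bool:
    # Assumed rule: descend only into call-graph ancestors of sequential targets
    # that are not offloaded; anything else would be recorded as one leaf call.
    return module in ancestors and module not in offloaded


ancestors = {"model", "model.layers"}
offloaded = {"model", "model.layers.0.mlp"}
ancestors, offloaded = reconcile_offloaded(ancestors, offloaded)

print(offloaded)  # {'model.layers.0.mlp'}
print(should_trace_inside("model", ancestors, offloaded))  # True
print(should_trace_inside("model.layers.0", ancestors, offloaded))  # False (a target)
print(should_trace_inside("model.embed_tokens", ancestors, offloaded))  # False
```

Because the constructor removes offloaded ancestors from the offloaded set, the two conditions never conflict for an ancestor. Any module reached during tracing is either an ancestor to descend into or a leaf.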

Subgraph dataclass

Subgraph(
    graph: Graph,
    input_names: Set[str],
    consumed_names: Set[str],
    _code: Optional[PythonCode] = None,
)

Dataclass specifying an executable subgraph of a model graph

Parameters

  • graph

    (Graph) –

    subgraph of the model graph

  • input_names

    (Set[str]) –

    argument names of the compiled forward function

  • consumed_names

    (Set[str]) –

    argument names which are not used by any subsequent subgraphs and can therefore be deleted from the intermediates cache
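A minimal sketch of how a sequential pipeline might use input_names and consumed_names to manage an intermediates cache. ToySubgraph and run_sequentially are hypothetical, and a plain Python callable stands in for the compiled graph. Each subgraph reads its inputs from the cache by name. It then deletes the names it consumes and adds its outputs.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Set


@dataclass
class ToySubgraph:
    # Hypothetical stand-in for Subgraph; `fn` replaces the compiled graph code.
    fn: Callable[..., Dict[str, Any]]
    input_names: Set[str]
    consumed_names: Set[str] = field(default_factory=set)


def run_sequentially(
    subgraphs: List[ToySubgraph], model_inputs: Dict[str, Any]
) -> Dict[str, Any]:
    cache: Dict[str, Any] = dict(model_inputs)  # the intermediates cache
    for subgraph in subgraphs:
        inputs = {name: cache[name] for name in subgraph.input_names}
        outputs = subgraph.fn(**inputs)
        # free values that no later subgraph will read
        for name in subgraph.consumed_names:
            del cache[name]
        cache.update(outputs)
    return cache


subgraphs = [
    ToySubgraph(lambda x: {"h": x + 1}, {"x"}, {"x"}),
    ToySubgraph(lambda h: {"out": h * 2}, {"h"}, {"h"}),
]
print(run_sequentially(subgraphs, {"x": 3}))  # {'out': 8}
```

After each step the cache holds only values that some later subgraph still needs, plus the newest outputs. This is the memory saving that consumed_names exists for.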

Methods

  • forward

    Execute the operations within the subgraph

forward

forward(*args, **kwargs) -> Dict[str, Any]

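Execute the operations within the subgraph

Parameters

  • *args

    argument inputs to the subgraph forward function

  • **kwargs

    keyword inputs to the subgraph forward function

Returns

  • Dict[str, Any]

    keyword outputs of the subgraph forward function (non-consumed variables)

Source code in llmcompressor/pipelines/sequential/helpers.py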
def forward(self, *args, **kwargs) -> Dict[str, Any]:
    """
    Execute the operations within the subgraph

    :param \\*args: argument inputs to subgraph forward function
    :param \\**kwargs: keyword inputs to subgraph forward function
    :return: keyword outputs of subgraph forward function (non-consumed variables)
    """
    if self._code is None:
        self._code = self.graph.python_code("self")
        exec(self._code.src, self._code.globals)

    forward_fn = self._code.globals.get("forward")

    with append_autowrap_source_on_fail():
        return forward_fn(*args, **kwargs)
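Subgraph.forward compiles its graph lazily. On the first call, it renders Python source with torch.fx's Graph.python_code and execs that source into the code object's globals. Every call then looks up the generated forward function in that namespace. The sketch below reproduces only this exec-and-lookup pattern. The hard-coded GENERATED_SRC string and the LazySubgraph class are hypothetical, and no torch.fx graph is involved.

```python
from typing import Any, Dict, Optional

# Hypothetical generated source standing in for graph.python_code("self").src
GENERATED_SRC = """
def forward(x, y):
    add = x + y
    mul = add * 2
    return {"add": add, "mul": mul}
"""


class LazySubgraph:
    def __init__(self, src: str):
        self.src = src
        self._globals: Optional[Dict[str, Any]] = None  # filled on first call

    def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        if self._globals is None:
            # exec defines `forward` inside the namespace dict; compile only once
            self._globals = {}
            exec(self.src, self._globals)
        forward_fn = self._globals["forward"]
        return forward_fn(*args, **kwargs)


subgraph = LazySubgraph(GENERATED_SRC)
print(subgraph.forward(1, 2))  # {'add': 3, 'mul': 6}
print(subgraph.forward(x=2, y=3))  # {'add': 5, 'mul': 10}
```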

dispatch_for_sequential

dispatch_for_sequential(
    model: PreTrainedModel,
) -> PreTrainedModel

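Dispatch a model for sequential calibration using a sequential pipeline. The model will be offloaded to the CPU and dispatched to a CUDA/XPU device if one is available. Removes any existing hooks.

Parameters

  • model

    (PreTrainedModel) –

    model to dispatch

Returns

  • PreTrainedModel

    dispatched model

Source code in llmcompressor/pipelines/sequential/helpers.py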
def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
    """
    Dispatch a model for sequential calibration using a sequential pipeline.
    The model will be offloaded to the CPU and dispatched to CUDA/XPU device
    if available. Removes any existing hooks.

    :param model: model to dispatch
    :return: dispatched model
    """
    remove_dispatch(model)

    if torch.cuda.is_available():
        offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        offloaded_dispatch(model, execution_device=torch.device("xpu:0"))
    else:
        logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")

    return model
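The device fallback in dispatch_for_sequential can be restated as a pure function, so the priority order can be checked without any hardware. pick_execution_device is a hypothetical helper. In the real function, the booleans come from torch.cuda.is_available() and torch.xpu.is_available(), and the chosen device is passed to offloaded_dispatch as execution_device.

```python
from typing import Optional


def pick_execution_device(cuda_available: bool, xpu_available: bool) -> Optional[str]:
    # Same priority as dispatch_for_sequential: CUDA first, then XPU.
    if cuda_available:
        return "cuda:0"
    if xpu_available:
        return "xpu:0"
    return None  # no accelerator: the model is compressed on the CPU


print(pick_execution_device(True, True))  # cuda:0
print(pick_execution_device(False, True))  # xpu:0
print(pick_execution_device(False, False))  # None
```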

find_target_nodes

find_target_nodes(
    graph: GraphModule, targets: Set[Module]
) -> Set[Node]

Find all nodes whose execution is equivalent to executing the target modules. Note that these nodes are guaranteed to be treated as leaf nodes by SequentialTracer.

Parameters

  • graph

    (GraphModule) –

    graph containing target nodes

  • targets

    (Set[Module]) –

    modules whose nodes are being searched for

Returns

  • Set[Node]

    set of all nodes which call the target modules

Source code in llmcompressor/pipelines/sequential/helpers.py
def find_target_nodes(graph: GraphModule, targets: Set[Module]) -> Set[Node]:
    """
    Find all nodes whose execution is equivalent to executing the target modules.
    Note that these nodes are guaranteed to be treated as leaf nodes by SequentialTracer

    :param graph: graph containing target nodes
    :param targets: modules whose nodes are being searched for
    :return: set of all nodes which call the target modules
    """
    return set(
        node
        for node in graph.graph.nodes
        if node.op == "call_module" and graph.get_submodule(node.target) in targets
    )
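The filter above keeps only call_module nodes whose resolved submodule is one of the targets. A stdlib sketch of the same selection follows. ToyNode and find_target_nodes_sketch are hypothetical, and a plain dict lookup stands in for GraphModule.get_submodule. Membership is checked by object identity, which is what the default hashing of torch modules gives in the real code.

```python
from dataclasses import dataclass
from typing import Dict, List, Set


@dataclass(frozen=True)
class ToyNode:
    name: str
    op: str  # "placeholder", "call_module", "call_function", or "output"
    target: str  # submodule path when op == "call_module"


def find_target_nodes_sketch(
    nodes: List[ToyNode], submodules: Dict[str, object], targets: Set[object]
) -> Set[ToyNode]:
    # the op check runs first, so non-module targets are never looked up
    return {
        node
        for node in nodes
        if node.op == "call_module" and submodules[node.target] in targets
    }


embed, layer0, layer1 = object(), object(), object()
submodules = {"embed": embed, "layers.0": layer0, "layers.1": layer1}
nodes = [
    ToyNode("input_ids", "placeholder", "input_ids"),
    ToyNode("embed", "call_module", "embed"),
    ToyNode("layers_0", "call_module", "layers.0"),
    ToyNode("add", "call_function", "add"),
    ToyNode("layers_1", "call_module", "layers.1"),
    ToyNode("output", "output", "output"),
]
found = find_target_nodes_sketch(nodes, submodules, {layer0, layer1})
print(sorted(node.name for node in found))  # ['layers_0', 'layers_1']
```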

get_sequential_ancestors

get_sequential_ancestors(
    model: Module, targets: Set[Module]
) -> Set[Module]

Find modules which are call-graph ancestors of the given sequential targets.

Parameters

  • model

    (Module) –

    model containing sequential targets

  • targets

    (Set[Module]) –

    sequential targets to find ancestors of

Returns

  • Set[Module]

    call graph ancestors of sequential targets

Source code in llmcompressor/pipelines/sequential/helpers.py
def get_sequential_ancestors(model: Module, targets: Set[Module]) -> Set[Module]:
    """
    Find modules which are call graph ancestors of the given sequential targets

    :param model: model containing sequential targets
    :param targets: sequential targets to find ancestors of
    :return: call graph ancestors of sequential targets
    """
    ancestors = set()

    def is_ancestor(module: Module) -> bool:
        if module in ancestors or module in targets:
            return True

        # eagerly compute list in order to avoid early stopping and :. missing ancestors
        _is_ancestor = any([is_ancestor(child) for child in module.children()])
        if _is_ancestor:
            ancestors.add(module)

        return _is_ancestor

    is_ancestor(model)
    return ancestors
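The comment about computing the list eagerly is the important detail here. Passing a generator to any() stops at the first child that contains a target, so sibling subtrees are never visited and their ancestors are missed. The stdlib sketch below uses a toy module tree keyed by name; get_ancestors and the tree layout are hypothetical. It compares both variants on a layout with targets in two separate branches.

```python
from typing import Dict, List, Set


def get_ancestors(
    tree: Dict[str, List[str]], root: str, targets: Set[str], eager: bool = True
) -> Set[str]:
    ancestors: Set[str] = set()

    def is_ancestor(module: str) -> bool:
        if module in ancestors or module in targets:
            return True
        children = tree.get(module, [])
        if eager:
            # list comprehension: every child subtree is visited
            found = any([is_ancestor(child) for child in children])
        else:
            # generator: any() stops at the first True, skipping later siblings
            found = any(is_ancestor(child) for child in children)
        if found:
            ancestors.add(module)
        return found

    is_ancestor(root)
    return ancestors


# a toy multimodal layout with targets in two separate branches
tree = {
    "model": ["vision", "lm"],
    "vision": ["vision.blocks"],
    "vision.blocks": ["vision.blocks.0"],
    "lm": ["lm.layers"],
    "lm.layers": ["lm.layers.0"],
}
targets = {"vision.blocks.0", "lm.layers.0"}
print(sorted(get_ancestors(tree, "model", targets)))
# ['lm', 'lm.layers', 'model', 'vision', 'vision.blocks']
print(sorted(get_ancestors(tree, "model", targets, eager=False)))
# ['model', 'vision', 'vision.blocks']
```

With the lazy variant, the lm branch would be missing from the ancestor set. The tracer would then treat lm as a leaf and never reach lm.layers.0.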

get_sequential_targets

get_sequential_targets(
    modifiers: List[Modifier],
    model: PreTrainedModel,
    args: DatasetArguments,
) -> List[str]

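Infer sequential targets from the modifiers list and dataset arguments.

Parameters

  • modifiers

    (List[Modifier]) –

    list of modifiers being applied during calibration

  • model

    (PreTrainedModel) –

    model being calibrated

  • args

    (DatasetArguments) –

    dataset arguments passed by the user

Returns

  • List[str]

    list of sequential targets

Source code in llmcompressor/pipelines/sequential/helpers.py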
def get_sequential_targets(
    modifiers: List[Modifier], model: PreTrainedModel, args: "DatasetArguments"
) -> List[str]:
    """
    Infer sequential targets from modifiers list and dataset args

    :param model: model being calibrated
    :param modifiers: list of modifiers being applied during calibration
    :param args: dataset arguments passed by user
    :return: list of sequential targets
    """
    modifier_targets = [
        (modifier, modifier.sequential_targets)
        for modifier in modifiers
        if getattr(modifier, "sequential_targets", None) is not None
    ]

    # deprecation warning
    if len(modifier_targets) >= 1:
        logger.warning(
            "Passing sequential targets through modifiers is deprecated, "
            "please use `oneshot(sequential_targets=...)`"
        )

    # cannot infer from multiple modifiers
    if len(modifier_targets) >= 2:
        types = [type(modifier) for modifier, _ in modifier_targets]
        raise ValueError(
            "Cannot infer sequential targets from multiple sequential modifiers "
            f"({types})"
        )

    # resolve single modifier
    if len(modifier_targets) == 1:
        if args.sequential_targets is not None:
            raise ValueError(
                f"Got sequential targets from both {type(modifier_targets[0][0])} "
                "and dataset arguments `sequential_targets`"
            )

        sequential_targets = modifier_targets[0][1]

    # if no modifiers, use data args
    else:
        sequential_targets = args.sequential_targets  # may be `None`

    # validate and infer
    if sequential_targets is None:
        return get_no_split_params(model)
    elif isinstance(sequential_targets, str):
        return [sequential_targets]
    else:
        return sequential_targets
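The precedence rules above can be condensed into one resolution function. The sketch below is hypothetical. modifier_targets holds only the non-None sequential_targets values found on modifiers. dataset_targets plays the role of args.sequential_targets, and no_split_default stands in for the result of get_no_split_params(model). Unlike the original, the sketch copies any sequence of targets into a fresh list.

```python
from typing import List, Optional, Sequence, Union

TargetSpec = Union[str, List[str]]


def resolve_sequential_targets(
    modifier_targets: Sequence[TargetSpec],
    dataset_targets: Optional[TargetSpec],
    no_split_default: List[str],
) -> List[str]:
    if len(modifier_targets) >= 2:
        raise ValueError("Cannot infer sequential targets from multiple modifiers")
    if len(modifier_targets) == 1:
        if dataset_targets is not None:
            raise ValueError("Got sequential targets from both a modifier and args")
        targets: Optional[TargetSpec] = modifier_targets[0]
    else:
        targets = dataset_targets  # may be None
    if targets is None:
        return no_split_default  # fall back to the model's no-split modules
    if isinstance(targets, str):
        return [targets]  # a single pattern becomes a one-element list
    return list(targets)


print(resolve_sequential_targets([], None, ["LlamaDecoderLayer"]))  # ['LlamaDecoderLayer']
print(resolve_sequential_targets([], "MyBlock", ["LlamaDecoderLayer"]))  # ['MyBlock']
print(resolve_sequential_targets([["A", "B"]], None, []))  # ['A', 'B']
```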

graph_is_well_formed

graph_is_well_formed(graph: Graph) -> bool

A graph is well formed if and only if nodeA in nodeB.users <=> nodeB in nodeA.all_input_nodes, and no node lists the same user or input node more than once.

Parameters

  • graph

    (Graph) –

    graph being checked

Returns

  • bool

    True if the graph is well formed, False otherwise

Source code in llmcompressor/pipelines/sequential/helpers.py
def graph_is_well_formed(graph: Graph) -> bool:
    """
    A graph is well formed if and only if
    `nodeA in nodeB.users <=> nodeB in nodeA.all_input_nodes`

    :param graph: graph being checked
    :return: True if the graph is well formed, False otherwise
    """
    for node in graph.nodes:
        for user in node.users:
            if node not in user.all_input_nodes:
                return False

        for input_node in node.all_input_nodes:
            if node not in input_node.users:
                return False

        if len(node.users) != len(set(node.users)) or len(node.all_input_nodes) != len(
            set(node.all_input_nodes)
        ):
            return False

    return True
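The well-formedness check verifies that every edge is recorded on both of its endpoints and that no edge is recorded twice. The sketch below uses hypothetical list-based stand-ins for torch.fx nodes; ToyNode, connect, and is_well_formed are illustrative names. It shows a one-sided edge being detected.

```python
from typing import List


class ToyNode:
    # users / all_input_nodes loosely mimic the torch.fx.Node attributes
    def __init__(self, name: str):
        self.name = name
        self.users: List["ToyNode"] = []
        self.all_input_nodes: List["ToyNode"] = []


def connect(src: ToyNode, dst: ToyNode) -> None:
    # record the edge on both endpoints, as a consistent graph must
    src.users.append(dst)
    dst.all_input_nodes.append(src)


def is_well_formed(nodes: List[ToyNode]) -> bool:
    for node in nodes:
        if any(node not in user.all_input_nodes for user in node.users):
            return False
        if any(node not in inp.users for inp in node.all_input_nodes):
            return False
        if len(node.users) != len(set(node.users)):
            return False
        if len(node.all_input_nodes) != len(set(node.all_input_nodes)):
            return False
    return True


a, b, c = ToyNode("a"), ToyNode("b"), ToyNode("c")
connect(a, b)
connect(b, c)
print(is_well_formed([a, b, c]))  # True

c.all_input_nodes.append(a)  # one-sided edge: a.users does not list c
print(is_well_formed([a, b, c]))  # False
```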

match_modules

match_modules(
    model: Module, target_names: List[str]
) -> Set[Module]

Find modules whose names match the patterns given by target_names.

Parameters

  • model

    (Module) –

    model containing submodules to find

  • target_names

    (List[str]) –

    target patterns to find

Returns

  • Set[Module]

    all submodules matching target_names

Source code in llmcompressor/pipelines/sequential/helpers.py
def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
    """
    Find modules whose names match the patterns given by `target_names`

    :param model: model containing submodules to find
    :param target_names: target patterns to find
    :return: all submodules matching `target_names`
    """
    return set(
        module
        for name, module in model.named_modules()
        if match_targets(name, module, target_names)
    )
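The exact rules of match_targets are not shown on this page. The stdlib sketch below assumes the convention seen in llmcompressor recipes: a "re:" prefix means a regular expression matched against the module name, and any other string must equal either the module name or the module's class name. The functions matches and match_modules_sketch are hypothetical. The (name, class name) pairs stand in for model.named_modules(), which yields (name, module) pairs.

```python
import re
from typing import Iterable, List, Set, Tuple


def matches(name: str, class_name: str, targets: List[str]) -> bool:
    # Assumed rules: "re:" prefix = regex on the module name,
    # otherwise an exact module name or class name
    for target in targets:
        if target.startswith("re:"):
            if re.match(target[3:], name):
                return True
        elif target in (name, class_name):
            return True
    return False


def match_modules_sketch(
    named_modules: Iterable[Tuple[str, str]], targets: List[str]
) -> Set[str]:
    return {name for name, cls in named_modules if matches(name, cls, targets)}


named = [
    ("", "LlamaForCausalLM"),
    ("model.layers.0", "LlamaDecoderLayer"),
    ("model.layers.0.mlp", "LlamaMLP"),
    ("model.layers.1", "LlamaDecoderLayer"),
    ("lm_head", "Linear"),
]
print(sorted(match_modules_sketch(named, ["LlamaDecoderLayer"])))
# ['model.layers.0', 'model.layers.1']
print(sorted(match_modules_sketch(named, ["re:.*mlp$"])))
# ['model.layers.0.mlp']
```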

partition_graph

partition_graph(
    model: Module, partitions: List[List[Node]]
) -> List[Subgraph]

Convert each partition into a Subgraph. Each Subgraph returns a dictionary mapping output node names to their computed values. Note that the consumed_names attribute of each Subgraph remains empty, to be populated later by trace_consumed_names.

Parameters

  • model

    (Module) –

    model which owns the produced Subgraphs

  • partitions

    (List[List[Node]]) –

    list of partitions, where each partition is a list of nodes belonging to that partition

Returns

  • List[Subgraph]

    list of subgraphs in order of execution

Source code in llmcompressor/pipelines/sequential/helpers.py
def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgraph]:
    """
    Convert each partition into a Subgraph. Each Subgraph returns a dictionary mapping
    of output node names to their computed values. Note that the `consumed_names`
    attribute of each Subgraph remains empty, to be later populated by
    `trace_consumed_names`

    :param model: model which owns the produced Subgraphs
    :param partitions: list of partitions, where each partition is a list of nodes
        belonging to that partition
    :return: list of subgraphs in order of execution
    """
    subgraphs = []

    # create subgraphs
    for partition_nodes in partitions:
        # create a new graph for the partition
        graph = Graph(model)
        node_map = {}

        # add placeholders for inputs not in this subgraph. use set to deduplicate
        new_input_nodes = {
            input_node
            for node in partition_nodes
            for input_node in node.all_input_nodes
            if input_node not in partition_nodes and input_node.op
        }
        for input_node in new_input_nodes:
            node_map[input_node] = graph.placeholder(input_node.name)

        # add the nodes to subgraph
        for node in partition_nodes:
            node_map[node] = graph.node_copy(node, lambda n: node_map[n])

        # add an output node to collect all subgraph outputs into a dictionary
        if len(graph.find_nodes(op="output")) <= 0:
            output_dict = {
                node.name: node_map[node]
                for node in partition_nodes
                if any(user not in partition_nodes for user in node.users.keys())
            }
            graph.output(output_dict)

        # save the subgraph for this partition
        graph.lint()
        input_names = set(node.name for node in graph.nodes if node.op == "placeholder")
        subgraphs.append(
            Subgraph(
                graph=graph,
                input_names=input_names,
                consumed_names=set(),  # populated later
            )
        )

        assert graph_is_well_formed(graph)

    return subgraphs
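Each partition's boundary is determined by two sets. Inputs produced outside the partition become placeholders. Nodes with a user outside the partition become entries in the output dictionary. The stdlib sketch below computes both sets for a toy DAG given as a name-to-inputs mapping. partition_io and the graph are hypothetical, and no torch.fx Graph is built. In the real function, the last partition already contains the graph's output node, so no output dictionary is added there.

```python
from typing import Dict, List, Set, Tuple


def partition_io(
    inputs: Dict[str, List[str]], partitions: List[List[str]]
) -> List[Tuple[Set[str], Set[str]]]:
    # invert the inputs mapping to get each node's users
    users: Dict[str, Set[str]] = {}
    for node, sources in inputs.items():
        for source in sources:
            users.setdefault(source, set()).add(node)

    result = []
    for partition in partitions:
        members = set(partition)
        # inputs produced outside this partition become placeholders
        placeholders = {
            source
            for node in partition
            for source in inputs.get(node, [])
            if source not in members
        }
        # nodes read outside this partition go into the output dictionary
        outputs = {
            node
            for node in partition
            if any(user not in members for user in users.get(node, set()))
        }
        result.append((placeholders, outputs))
    return result


inputs = {
    "x": [],
    "embed": ["x"],
    "mask": ["x"],
    "layer0": ["embed", "mask"],
    "layer1": ["layer0", "mask"],
    "norm": ["layer1"],
    "out": ["norm"],
}
partitions = [["x", "embed", "mask", "layer0"], ["layer1"], ["norm", "out"]]
for placeholders, outputs in partition_io(inputs, partitions):
    print(sorted(placeholders), sorted(outputs))
# [] ['layer0', 'mask']
# ['layer0', 'mask'] ['layer1']
# ['layer1'] []
```

Note that mask is an output of the first partition even though the next partition also reads layer0. Every value crossing a boundary must be exported, not just the target's output.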

populate_concrete_args

populate_concrete_args(
    model: Module, sample_input: Dict
) -> Dict

Creates concrete args. Unlike the equivalent function provided by transformers.utils.fx, this function creates default values for variadic arguments, which some models require.

Parameters

  • model

    (Module) –

    model being traced

  • sample_input

    (Dict) –

    values used to symbolically trace the model. All arguments to the model.forward function which are not in sample_input are considered concrete args

Returns

  • Dict

    dictionary mapping concrete argument names to their default values

Source code in llmcompressor/pipelines/sequential/helpers.py
def populate_concrete_args(model: Module, sample_input: Dict) -> Dict:
    """
    Creates concrete args which, unlike the equivalent function provided by
    transformers.utils.fx, creates default values for variadic arguments, which are
    needed by some models.

    :param model: model being traced
    :param sample_input: values used to symbolically trace the model. All arguments
        to the model.forward function which are not in the sample_input are considered
        concrete args
    :return: dictionary mapping concrete argument names to their default values
    """
    sig = inspect.signature(model.forward)

    concrete_args = {}
    for parameter in sig.parameters.values():
        if parameter.name in sample_input:
            continue
        if parameter.kind == inspect._ParameterKind.VAR_POSITIONAL:
            value = list()
        elif parameter.kind == inspect._ParameterKind.VAR_KEYWORD:
            value = dict()
        elif parameter.name == "use_cache":
            value = False
        else:
            value = parameter.default

        concrete_args[parameter.name] = value

    return concrete_args
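Because this function only needs inspect, its behavior can be checked on a toy class. The copy below differs in two small ways: it takes the bound forward method directly, and it uses the public inspect.Parameter kind constants instead of inspect._ParameterKind. ToyModel and populate_concrete_args_sketch are hypothetical.

```python
import inspect
from typing import Any, Callable, Dict


def populate_concrete_args_sketch(
    forward: Callable[..., Any], sample_input: Dict[str, Any]
) -> Dict[str, Any]:
    concrete_args: Dict[str, Any] = {}
    for parameter in inspect.signature(forward).parameters.values():
        if parameter.name in sample_input:
            continue  # traced symbolically, not concrete
        if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
            value: Any = []
        elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
            value = {}
        elif parameter.name == "use_cache":
            value = False  # tracing runs without a KV cache
        else:
            value = parameter.default
        concrete_args[parameter.name] = value
    return concrete_args


class ToyModel:
    # hypothetical forward signature mixing defaults and variadic arguments
    def forward(self, input_ids, attention_mask=None, use_cache=True, *args, **kwargs):
        return input_ids


concrete = populate_concrete_args_sketch(ToyModel().forward, {"input_ids": [1, 2, 3]})
print(concrete)  # {'attention_mask': None, 'use_cache': False, 'args': [], 'kwargs': {}}
```

A required parameter that is absent from sample_input would receive inspect.Parameter.empty, since that is what parameter.default reports for parameters without a default.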

topological_partition

topological_partition(
    graph: GraphModule, targets: Set[Module]
) -> List[List[Node]]

Partition the graph into partitions such that each target belongs to exactly one partition and executing each partition depends only on intermediate values produced by executing the partitions before it.

Parameters

  • graph

    (GraphModule) –

    graph being partitioned

  • targets

    (Set[Module]) –

    target modules which will be assigned to disjoint partitions

Returns

  • List[List[Node]]

    list of partitions, where each partition is a list of nodes belonging to that partition

Source code in llmcompressor/pipelines/sequential/helpers.py
def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List[Node]]:
    """
    Partition the graph into partitions such that each `target` belongs to exactly one
    partition and executing each partition depends only on intermediate values produced
    by executing the partitions before it.

    :param graph: graph being partitioned
    :param targets: target modules which will be assigned to disjoint partitions
    :return: list of partitions, where each partition is a list of nodes belonging to
        that partition
    """
    assert graph_is_well_formed(graph.graph)
    target_nodes = find_target_nodes(graph, targets)

    partitions: List[List[Node]] = [[]]
    remaining_indegrees = {
        node: len([node for node in node.all_input_nodes if node.op != "get_attr"])
        for node in graph.graph.nodes
    }
    partition_index = 0  # global counter

    # start with graph input nodes,
    # but delay the `get_attr` nodes as long as possible
    queue = deque(
        node
        for node in graph.graph.nodes
        if remaining_indegrees[node] == 0 and node.op != "get_attr"
    )
    while len(queue) > 0:
        node = queue.popleft()

        # assign to partition
        partitions[partition_index].append(node)

        # guarantee targets are assigned to disjoint partitions
        if node in target_nodes:
            partition_index += 1
            partitions.append([])

        # recurse on last indegree only in order to guarantee that
        # the node is assigned to maximal partition
        for user in node.users:
            remaining_indegrees[user] -= 1
            if remaining_indegrees[user] == 0:
                queue.append(user)

    # an ideal implementation would involve implicitly consolidating partition indices
    # so that each node is assigned to the maximum partition possible (in order to delay
    # execution as long as possible), but saving these nodes for last covers the most
    # common and costly case (get_attr)
    for node in graph.graph.find_nodes(op="get_attr"):
        user_partitions = []
        for user in node.users:
            for index in range(len(partitions)):
                if user in partitions[index]:
                    user_partitions.append(index)
                    break

        # workaround
        if len(user_partitions):
            partition_index = min(user_partitions)
            partitions[partition_index].insert(0, node)

    return partitions
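The core of this function is Kahn's topological sort with one addition: after a target node is emitted, a new partition is opened. A node enters the queue only once its last input has been processed. The stdlib sketch below reproduces that loop on a toy DAG of node names. It deliberately omits the get_attr deferral handled at the end of the real function. topological_partition_sketch and the graph are hypothetical.

```python
from collections import deque
from typing import Dict, List, Set


def topological_partition_sketch(
    nodes: List[str], inputs: Dict[str, List[str]], targets: Set[str]
) -> List[List[str]]:
    users: Dict[str, List[str]] = {node: [] for node in nodes}
    for node in nodes:
        for source in inputs.get(node, []):
            users[source].append(node)

    remaining_indegrees = {node: len(inputs.get(node, [])) for node in nodes}
    queue = deque(node for node in nodes if remaining_indegrees[node] == 0)
    partitions: List[List[str]] = [[]]
    while queue:
        node = queue.popleft()
        partitions[-1].append(node)
        if node in targets:
            partitions.append([])  # the next node starts a new partition
        for user in users[node]:
            remaining_indegrees[user] -= 1
            if remaining_indegrees[user] == 0:  # all inputs are now available
                queue.append(user)
    return partitions


nodes = ["x", "embed", "mask", "layer0", "layer1", "norm", "out"]
inputs = {
    "embed": ["x"],
    "mask": ["x"],
    "layer0": ["embed", "mask"],
    "layer1": ["layer0", "mask"],
    "norm": ["layer1"],
    "out": ["norm"],
}
print(topological_partition_sketch(nodes, inputs, {"layer0", "layer1"}))
# [['x', 'embed', 'mask', 'layer0'], ['layer1'], ['norm', 'out']]
```

With two targets, this toy graph yields len(targets) + 1 partitions, the count that trace_subgraphs checks for.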

trace_consumed_names

trace_consumed_names(subgraphs: List[Subgraph])

Populate the consumed_names attribute of each Subgraph according to when inputs are last used, in order to vacate the intermediates cache and save memory.

Parameters

  • subgraphs

    (List[Subgraph]) –

    list of subgraphs with empty consumed_names attributes

Source code in llmcompressor/pipelines/sequential/helpers.py
def trace_consumed_names(subgraphs: List[Subgraph]):
    """
    Populate the `consumed_names` attribute of each Subgraph according to when inputs
    are last used in order to vacate the `intermediates` cache and save memory

    :param subgraphs: list of subgraphs with empty `consumed_names` attributes
    """
    # populate consumed_names according to when inputs are last used
    # in order to vacate the `intermediates` cache and save memory
    all_input_names = set().union(*(subgraph.input_names for subgraph in subgraphs))
    for input_name in all_input_names:
        for subgraph in reversed(subgraphs):
            if input_name in subgraph.input_names:
                subgraph.consumed_names.add(input_name)
                break
        else:
            raise ValueError(f"Could not find input name {input_name} in subgraphs")
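The "last use" rule can be seen on plain sets of names. For each input name, the loop walks the subgraphs backwards and assigns the name to the first subgraph that reads it, which is the last one in execution order. consumed_names_sketch is a hypothetical stdlib version. It returns the consumed sets instead of mutating Subgraph objects, and it omits the error branch, which cannot trigger when names come from the union of inputs.

```python
from typing import List, Set


def consumed_names_sketch(input_names: List[Set[str]]) -> List[Set[str]]:
    consumed: List[Set[str]] = [set() for _ in input_names]
    for name in set().union(*input_names):
        # scan backwards: the first match is the last subgraph that reads `name`
        for index in reversed(range(len(input_names))):
            if name in input_names[index]:
                consumed[index].add(name)
                break
    return consumed


per_subgraph_inputs = [
    {"input_ids", "attention_mask"},
    {"hidden_0", "mask"},
    {"hidden_1", "mask"},
    {"hidden_2"},
]
for names in consumed_names_sketch(per_subgraph_inputs):
    print(sorted(names))
# ['attention_mask', 'input_ids']
# ['hidden_0']
# ['hidden_1', 'mask']
# ['hidden_2']
```

Here mask is read by two subgraphs but is only freed after the second one, so it stays cached exactly as long as it is needed.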

trace_subgraphs

trace_subgraphs(
    model: PreTrainedModel,
    sample_input: Dict[str, Any],
    sequential_targets: List[str],
    ignore: List[str],
) -> List[Subgraph]

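Trace a model to produce subgraphs, where each sequential target belongs to exactly one subgraph and where executing each subgraph in order is equivalent to executing the original model.

Parameters

  • model

    (PreTrainedModel) –

    model being traced

  • sample_input

    (Dict[str, Any]) –

    inputs whose values will change during execution but whose __len__, __bool__, and __contains__ values are assumed constant across batches

  • sequential_targets

    (List[str]) –

    list of patterns matching sequential targets

  • ignore

    (List[str]) –

    function and method names to skip during tracing

Returns

  • List[Subgraph]

    a list of Subgraphs in order of execution

Source code in llmcompressor/pipelines/sequential/helpers.py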
def trace_subgraphs(
    model: PreTrainedModel,
    sample_input: Dict[str, Any],
    sequential_targets: List[str],
    ignore: List[str],
) -> List[Subgraph]:
    """
    Trace a model to produce subgraphs, where each sequential target belongs to exactly
    one subgraph and where executing each subgraph in order is equivalent to executing
    the original model

    :param model: model being traced
    :param sample_input: inputs whose values will change during execution but whose
        __len__, __bool__, and __contains__ values are assumed constant across batches
    :param sequential_targets: list of patterns matching sequential targets
    :param ignore: function and method names to skip during tracing
    :return: a list of Subgraphs in order of execution
    """
    # find modules
    targets = match_modules(model, sequential_targets)
    ancestors = get_sequential_ancestors(model, targets)
    offloaded = set(m for m in model.modules() if has_offloaded_params(m))

    # initialize arguments
    tracer = SequentialTracer(ancestors, offloaded)
    concrete_args = populate_concrete_args(model, sample_input)

    with contextlib.ExitStack() as stack:
        # calibration context
        stack.enter_context(calibration_forward_context(model))
        stack.enter_context(HooksMixin.disable_hooks())

        # flags useful for tracing
        stack.enter_context(patch_attr(model.config, "_attn_implementation", "eager"))
        stack.enter_context(patch_attr(torch.compiler, "_is_compiling_flag", True))

        # autowrap forwards
        stack.enter_context(autowrap_forwards(ancestors, ignore))

        # avoid bug where pytorch cannot handle wrapped root functions
        unwrapped = inspect.unwrap(model.forward).__get__(model)
        stack.enter_context(patch_attr(model, "forward", unwrapped))
        stack.enter_context(patch_attr(type(model), "forward", unwrapped.__func__))
        assert isinstance(model.forward, MethodType)
        assert isinstance(type(model).forward, FunctionType)

        with append_autowrap_source_on_fail():
            graph = GraphModule(
                model,
                tracer.trace(
                    model,
                    dummy_inputs=sample_input,
                    concrete_args=concrete_args,
                    complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
                    # bug in trace throws an error for variadic
                    # args and kwargs in function signature
                ),
            )

    # copy metadata
    graph.config = model.config
    graph.class_for_deserialization = model.__class__
    graph.device = model.device

    # perform subgraph partition
    partitions = topological_partition(graph, targets)
    subgraphs = partition_graph(model, partitions)
    trace_consumed_names(subgraphs)

    # As currently implemented, `topological_partition` generates an extra subgraph at
    # the beginning which does not contain a target. This adds a little more runtime,
    # and could be folded into the first subgraph in the future
    if len(subgraphs) != len(targets) + 1:
        logger.warning(
            f"Expected {len(targets)} subgraphs, but only traced {len(subgraphs)}. "
            "This is likely due to having wrapped code which calls sequential targets"
        )

    return subgraphs
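trace_subgraphs stacks several temporary overrides on one contextlib.ExitStack, so every override is undone when tracing finishes, even if tracing raises. The sketch below illustrates that pattern with a hypothetical patch_attr_sketch and a SimpleNamespace standing in for model.config. The real patch_attr helper is not shown on this page and may differ in details, such as how it handles attributes that did not exist before patching.

```python
import contextlib
from types import SimpleNamespace
from typing import Any, Iterator

_MISSING = object()  # sentinel: attribute did not exist before patching


@contextlib.contextmanager
def patch_attr_sketch(obj: Any, name: str, value: Any) -> Iterator[None]:
    original = getattr(obj, name, _MISSING)
    setattr(obj, name, value)
    try:
        yield
    finally:
        # restore the previous state even if tracing raised
        if original is _MISSING:
            delattr(obj, name)
        else:
            setattr(obj, name, original)


config = SimpleNamespace(_attn_implementation="sdpa")
with contextlib.ExitStack() as stack:
    stack.enter_context(patch_attr_sketch(config, "_attn_implementation", "eager"))
    stack.enter_context(patch_attr_sketch(config, "use_cache", False))
    inside = (config._attn_implementation, config.use_cache)

after = (config._attn_implementation, hasattr(config, "use_cache"))
print(inside)  # ('eager', False)
print(after)  # ('sdpa', False)
```

ExitStack unwinds the contexts in reverse order of entry, so nested patches of the same attribute would also be restored correctly.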