跳到内容

llmcompressor.pipelines.sequential.ast_utils.auto_wrapper

  • AutoWrapper

    根据以下模式自动包装不可跟踪的代码

AutoWrapper

AutoWrapper(namespace: Dict[str, Any], ignore: List[str])

基类: NodeTransformer

根据以下模式自动包装不可跟踪的代码

以下模式将被自动包装 1. 条件无法静态评估的 If 语句 2. 被忽略的函数 (_update_causal_mask) 3. 星号元组解包 4. 星号参数解包

另请参阅:https://github.com/vllm-project/llm-compressor/pull/1411

方法

  • auto_wrap

    修改 AST,自动包装任何不可跟踪的代码段。要包装的段

  • visit_Call

    包装使用 (4) 可变参数或 (2) 匹配忽略列表的任何函数

  • visit_Delete

    self._local_names 中移除任何已删除的名称,

  • visit_FunctionDef

    移除阻止前向函数重新编译的装饰器

  • visit_If

    尝试静态评估 if 语句的条件。如果

  • visit_Name

    将任何新名称添加到 self._local_names 中,

  • visit_Tuple

    (3) 包装任何使用星号解包的元组

源代码在 llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def __init__(self, namespace: Dict[str, Any], ignore: List[str]):
    self.namespace = namespace
    self.ignore = ignore
    self._wrapper_fn_defs: List[ast.FunctionDef] = list()
    self._local_names = set()
    self._wrapped_counter = 0

auto_wrap

auto_wrap(tree: Module) -> ast.Module

修改 AST,自动包装任何不可跟踪的代码段。通过代码分析和基本模式匹配来确定要包装的代码段

参数

  • tree

    (Module) –

    包含原始前向函数定义的模块

返回

  • Module

    包含添加了包装器函数定义和函数调用的模块

源代码在 llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def auto_wrap(self, tree: ast.Module) -> ast.Module:
    """
    Modify ast by automatically wrapping any untraceable code segments. Segments to
    wrap are determined through analysis of the code and basic pattern matching

    :param tree: module containing a definition to an original forward function
    :return: module with added wrapper function definitions and function calls
    """
    tree = self.visit(tree)
    for fn_def in self._wrapper_fn_defs:
        tree.body.insert(0, fn_def)

    return ast.fix_missing_locations(tree)

visit_Call

visit_Call(node: Call) -> ast.Call

包装使用 (4) 可变参数或 (2) 匹配忽略列表的任何函数

源代码在 llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Call(self, node: ast.Call) -> ast.Call:
    """
    Wrap any functions which use (4) variadic arguments or (2) match the ignore list
    """
    # check for variadic starred
    if any(isinstance(elem, ast.Starred) for elem in node.args):
        return self._wrap_if_possible(node)

    # attempt to evaluate caller and check against ignore list
    try:
        caller = self._eval_expr(node.func)

    except Exception:
        caller = None

    finally:
        if (
            isinstance(caller, (FunctionType, MethodType))
            and caller.__name__ in self.ignore
        ):
            return self._wrap_if_possible(node)

    return super().generic_visit(node)

visit_Delete

visit_Delete(node: Delete)

self._local_names 中移除任何已删除的名称,这些名称用于确定函数包装器的参数

源代码在 llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Delete(self, node: ast.Delete):
    """
    Remove any deleted names from `self._local_names`,
    which are used to determine function wrapper arguments
    """
    ret = super().generic_visit(node)

    for target in node.targets:
        if isinstance(target, ast.Name):
            self._local_names.remove(target.id)

    return ret

visit_FunctionDef

visit_FunctionDef(node: FunctionDef) -> ast.FunctionDef

移除阻止前向函数重新编译的装饰器。例如,add_start_docstrings_to_model_forward

由于 _wrapper_fn_defs 是在 visit 完成后附加的,因此此函数不会影响包装器函数

参数

  • 节点

    (FunctionDef) –

    将要剥离装饰器的函数定义

返回

  • FunctionDef

    没有装饰器的函数定义

源代码在 llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
    """
    Remove decorators which prevent forward function recompilation
    For example, add_start_docstrings_to_model_forward

    Because `_wrapper_fn_defs` are appended after `visit` finishes, this function
    will not affect wrapper functions

    :param node: function definition whose decorators will be stripped
    :return: function definition without decorators
    """
    node.decorator_list = [
        decorator_name
        for decorator_name in node.decorator_list
        if isinstance(decorator_name, ast.Name)
        and decorator_name.id in ("can_return_tuple",)  # modifies func signature
    ]

    if node.name == "forward":
        for arg in node.args.args:
            self._local_names.add(arg.arg)
        for arg in node.args.posonlyargs:
            self._local_names.add(arg.arg)
        for arg in node.args.kwonlyargs:
            self._local_names.add(arg.arg)
        if node.args.vararg:
            self._local_names.add(node.args.vararg.arg)
        if node.args.kwarg:
            self._local_names.add(node.args.kwarg.arg)
    return super().generic_visit(node)

visit_If

visit_If(node: If) -> Union[ast.If, ast.Assign]

尝试静态评估 if 语句的条件。如果条件无法静态评估 (1),则尝试包装 if 语句

参数

  • 节点

    (If) –

    可能被包装的 if 语句

返回

  • Union[If, Assign]

    如果 if 语句无法静态评估,则返回将条件替换为 TrueFalseif 语句。否则,返回一个包装器函数调用

源代码在 llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_If(self, node: ast.If) -> Union[ast.If, ast.Assign]:
    """
    Attempt to statically evaluate the condition of the `if` statement. If the
    condition can not be statically evaluated (1), then attmept to wrap the `if`
    statement

    :param node: `if` statement which may be wrapped
    :return: if the `if` statement cannot be statically evaluated, return the
        `if` statement with the condition replaced by `True` or `False`.
        Otherwise, return a wrapper function call
    """
    try:
        value = bool(self._eval_expr(node.test))

        # force a wrap if any assignments occur within the if statement
        for expr in ast.walk(node):
            if isinstance(expr, ast.NamedExpr):
                raise Exception("If statement contains assignment")

    except Exception:
        return self._wrap_if_possible(node)

    else:
        node.test = ast.Constant(value=value)
        return super().generic_visit(node)

visit_Name

visit_Name(node: Name)

将任何新名称添加到 self._local_names 中,这些名称用于确定函数包装器的参数

源代码在 llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Name(self, node: ast.Name):
    """
    Add any new names in `self._local_names`,
    which are used to determine function wrapper arguments
    """
    if isinstance(node.ctx, ast.Store):
        self._local_names.add(node.id)

    return super().generic_visit(node)

visit_Tuple

visit_Tuple(node: Tuple) -> Union[ast.Tuple, ast.Call]

(3) 包装任何使用星号解包的元组

源代码在 llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Tuple(self, node: ast.Tuple) -> Union[ast.Tuple, ast.Call]:
    """
    (3) Wrap any tuples which use starred unpacking
    """
    if any(isinstance(elem, ast.Starred) for elem in node.elts):
        return self._wrap_if_possible(node)

    return super().generic_visit(node)