
llmcompressor.pipelines.sequential.ast_utils

Classes

  • AutoWrapper

    Automatically wraps untraceable code according to the patterns listed below

  • ControlFlowAnalyzer

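    Determines whether a piece of code can be wrapped into a function, i.e. whether it contains any return, continue, break, await, or yield statements outside their proper control flow context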

  • NameAnalyzer

    Determines the unbound, assigned, and conditionally assigned names associated with a piece of code

AutoWrapper

AutoWrapper(namespace: Dict[str, Any], ignore: List[str])

Bases: NodeTransformer

Automatically wraps untraceable code according to the patterns listed below

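The following patterns are automatically wrapped:

  1. If statements whose conditions cannot be statically evaluated
  2. Ignored functions (_update_causal_mask)
  3. Starred tuple unpacking
  4. Starred argument unpacking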

See also: https://github.com/vllm-project/llm-compressor/pull/1411
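To make pattern (1) concrete, here is a hand-written before/after sketch. The names forward_original, wrapped_0 and forward_wrapped are hypothetical, and the real AutoWrapper builds its wrappers from the AST rather than by hand; the sketch only shows the assumed shape of the transformation when a branch condition depends on a runtime value.

```python
# Hand-written illustration with hypothetical names; not actual AutoWrapper output.


def forward_original(x, scale):
    if x > 0:  # pattern (1): depends on a runtime value, cannot be folded statically
        y = x * scale
    else:
        y = -x
    return y


def wrapped_0(x, scale):
    # The untraceable `if` block is moved into its own function: names it reads
    # (x, scale) become parameters and names it assigns (y) become return values.
    if x > 0:
        y = x * scale
    else:
        y = -x
    return y


def forward_wrapped(x, scale):
    # The branch is now hidden behind a single function call.
    y = wrapped_0(x, scale)
    return y
```

Both forward functions compute the same result; only the location of the branch changes.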

Methods

  • auto_wrap

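    Modifies the ast by automatically wrapping any untraceable code segments. The segments to wrap are determined through code analysis and basic pattern matching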

  • visit_Call

    Wraps any function call that uses (4) variadic arguments or (2) matches the ignore list

  • visit_Delete

    Removes any deleted names from self._local_names, which are used to determine function wrapper arguments

  • visit_FunctionDef

    Removes decorators that prevent recompilation of the forward function

  • visit_If

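    Attempts to statically evaluate the condition of the if statement. If the condition cannot be statically evaluated (1), attempts to wrap the if statement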

  • visit_Name

    Adds any newly assigned names to self._local_names, which are used to determine function wrapper arguments

  • visit_Tuple

    (3) Wraps any tuples that use starred unpacking

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def __init__(self, namespace: Dict[str, Any], ignore: List[str]):
    self.namespace = namespace
    self.ignore = ignore
    self._wrapper_fn_defs: List[ast.FunctionDef] = list()
    self._local_names = set()
    self._wrapped_counter = 0

auto_wrap

auto_wrap(tree: Module) -> ast.Module

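Modifies the ast by automatically wrapping any untraceable code segments. The segments to wrap are determined through code analysis and basic pattern matching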

Parameters

  • tree

    (Module) –

    Module containing the definition of the original forward function

Returns

  • Module

    Module with the added wrapper function definitions and function calls

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def auto_wrap(self, tree: ast.Module) -> ast.Module:
    """
    Modify ast by automatically wrapping any untraceable code segments. Segments to
    wrap are determined through analysis of the code and basic pattern matching

    :param tree: module containing a definition to an original forward function
    :return: module with added wrapper function definitions and function calls
    """
    tree = self.visit(tree)
    for fn_def in self._wrapper_fn_defs:
        tree.body.insert(0, fn_def)

    return ast.fix_missing_locations(tree)
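auto_wrap prepends every generated wrapper definition to the module body and then calls ast.fix_missing_locations. The sketch below, using only the standard ast module, shows why the last step matters: nodes built by hand carry no lineno/col_offset, and compile() refuses a tree with missing locations. The wrapper wrapped_0 and the build_forward helper are hypothetical.

```python
import ast


def build_forward():
    # Original forward, parsed from source (all nodes carry line information).
    module = ast.parse("def forward(x):\n    return x\n")

    # Replace `return x` with `return wrapped_0(x)`. This Call node is built
    # by hand, so it has no lineno/col_offset yet.
    module.body[0].body[0].value = ast.Call(
        func=ast.Name(id="wrapped_0", ctx=ast.Load()),
        args=[ast.Name(id="x", ctx=ast.Load())],
        keywords=[],
    )

    # Hypothetical wrapper definitions, prepended the same way auto_wrap
    # does with `tree.body.insert(0, fn_def)`.
    wrapper_defs = [ast.parse("def wrapped_0(x):\n    return x + 1\n").body[0]]
    for fn_def in wrapper_defs:
        module.body.insert(0, fn_def)

    # Copy locations from parent nodes into the hand-built nodes; without
    # this, compile() raises because required location fields are missing.
    module = ast.fix_missing_locations(module)

    namespace = {}
    exec(compile(module, filename="<auto_wrapped>", mode="exec"), namespace)
    return namespace["forward"]
```

Note that inserting each definition at index 0 reverses their order relative to _wrapper_fn_defs; this is harmless because the wrappers are only looked up when forward runs.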

visit_Call

visit_Call(node: Call) -> ast.Call

Wraps any function call that uses (4) variadic arguments or (2) matches the ignore list

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Call(self, node: ast.Call) -> ast.Call:
    """
    Wrap any functions which use (4) variadic arguments or (2) match the ignore list
    """
    # check for variadic starred
    if any(isinstance(elem, ast.Starred) for elem in node.args):
        return self._wrap_if_possible(node)

    # attempt to evaluate caller and check against ignore list
    try:
        caller = self._eval_expr(node.func)

    except Exception:
        caller = None

    finally:
        if (
            isinstance(caller, (FunctionType, MethodType))
            and caller.__name__ in self.ignore
        ):
            return self._wrap_if_possible(node)

    return super().generic_visit(node)
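The decision made by visit_Call can be sketched as a standalone predicate. This is an assumption-laden approximation: should_wrap_call is a hypothetical helper, and plain eval() stands in for the class's private _eval_expr, whose exact behavior is not shown here.

```python
import ast
from types import FunctionType, MethodType


def should_wrap_call(node: ast.Call, namespace: dict, ignore: list) -> bool:
    """Hypothetical helper mirroring the two checks in visit_Call."""
    # Pattern (4): starred positional arguments, e.g. f(*args)
    if any(isinstance(arg, ast.Starred) for arg in node.args):
        return True

    # Pattern (2): resolve the callee in the namespace (eval() is an assumption
    # standing in for _eval_expr) and compare its __name__ to the ignore list.
    try:
        expression = ast.fix_missing_locations(ast.Expression(body=node.func))
        caller = eval(compile(expression, "<callee>", "eval"), dict(namespace))
    except Exception:
        return False  # the callee cannot be resolved statically

    return isinstance(caller, (FunctionType, MethodType)) and caller.__name__ in ignore
```

Builtins such as len are not FunctionType instances, so only Python-level functions and bound methods can match the ignore list.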

visit_Delete

visit_Delete(node: Delete)

Removes any deleted names from self._local_names, which are used to determine function wrapper arguments

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Delete(self, node: ast.Delete):
    """
    Remove any deleted names from `self._local_names`,
    which are used to determine function wrapper arguments
    """
    ret = super().generic_visit(node)

    for target in node.targets:
        if isinstance(target, ast.Name):
            self._local_names.remove(target.id)

    return ret

visit_FunctionDef

visit_FunctionDef(node: FunctionDef) -> ast.FunctionDef

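Removes decorators that prevent recompilation of the forward function, for example add_start_docstrings_to_model_forward. The can_return_tuple decorator is kept because it modifies the function signature. If the function is named forward, the names of all its arguments are also added to self._local_names.

Because _wrapper_fn_defs are appended after visit finishes, this function does not affect the wrapper functions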

Parameters

  • node

    (FunctionDef) –

    Function definition whose decorators will be stripped

Returns

  • FunctionDef

    Function definition without decorators

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
    """
    Remove decorators which prevent forward function recompilation
    For example, add_start_docstrings_to_model_forward

    Because `_wrapper_fn_defs` are appended after `visit` finishes, this function
    will not affect wrapper functions

    :param node: function definition whose decorators will be stripped
    :return: function definition without decorators
    """
    node.decorator_list = [
        decorator_name
        for decorator_name in node.decorator_list
        if isinstance(decorator_name, ast.Name)
        and decorator_name.id in ("can_return_tuple",)  # modifies func signature
    ]

    if node.name == "forward":
        for arg in node.args.args:
            self._local_names.add(arg.arg)
        for arg in node.args.posonlyargs:
            self._local_names.add(arg.arg)
        for arg in node.args.kwonlyargs:
            self._local_names.add(arg.arg)
        if node.args.vararg:
            self._local_names.add(node.args.vararg.arg)
        if node.args.kwarg:
            self._local_names.add(node.args.kwarg.arg)
    return super().generic_visit(node)
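A minimal standalone sketch of the same two steps, assuming only the standard ast module: filter the decorator list down to plain-name decorators in a keep list, and collect the argument names of forward. DecoratorStripper, KEEP_DECORATORS and forward_args are hypothetical names used for illustration.

```python
import ast

# Decorators that survive stripping (mirrors the tuple in visit_FunctionDef).
KEEP_DECORATORS = ("can_return_tuple",)


class DecoratorStripper(ast.NodeTransformer):
    def __init__(self):
        self.forward_args = set()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        # Keep only bare-name decorators on the keep list; calls such as
        # @add_start_docstrings_to_model_forward(...) are ast.Call, so they go.
        node.decorator_list = [
            decorator
            for decorator in node.decorator_list
            if isinstance(decorator, ast.Name) and decorator.id in KEEP_DECORATORS
        ]

        # Record every parameter name of forward, including *args / **kwargs.
        if node.name == "forward":
            arguments = node.args
            for arg in arguments.posonlyargs + arguments.args + arguments.kwonlyargs:
                self.forward_args.add(arg.arg)
            for arg in (arguments.vararg, arguments.kwarg):
                if arg is not None:
                    self.forward_args.add(arg.arg)

        return self.generic_visit(node)


source = (
    "@add_start_docstrings_to_model_forward('docs')\n"
    "@can_return_tuple\n"
    "def forward(self, input_ids, *args, use_cache=False, **kwargs):\n"
    "    return input_ids\n"
)
```

Running DecoratorStripper().visit(ast.parse(source)) leaves only @can_return_tuple on forward.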

visit_If

visit_If(node: If) -> Union[ast.If, ast.Assign]

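Attempts to statically evaluate the condition of the if statement. If the condition cannot be statically evaluated (1), attempts to wrap the if statement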

Parameters

  • node

    (If) –

    if statement that may be wrapped

Returns

  • Union[If, Assign]

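    If the if statement can be statically evaluated, returns the if statement with its condition replaced by True or False. Otherwise, returns a wrapper function call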

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_If(self, node: ast.If) -> Union[ast.If, ast.Assign]:
    """
    Attempt to statically evaluate the condition of the `if` statement. If the
    condition can not be statically evaluated (1), then attempt to wrap the `if`
    statement

    :param node: `if` statement which may be wrapped
    :return: if the `if` statement can be statically evaluated, return the
        `if` statement with the condition replaced by `True` or `False`.
        Otherwise, return a wrapper function call
    """
    try:
        value = bool(self._eval_expr(node.test))

        # force a wrap if any assignments occur within the if statement
        for expr in ast.walk(node):
            if isinstance(expr, ast.NamedExpr):
                raise Exception("If statement contains assignment")

    except Exception:
        return self._wrap_if_possible(node)

    else:
        node.test = ast.Constant(value=value)
        return super().generic_visit(node)
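The folding logic of visit_If can be approximated as follows. try_fold_if is a hypothetical helper, and evaluating the test with eval() against a namespace copy is an assumption standing in for the private _eval_expr. A return value of None means the statement would be handed to the wrapping step.

```python
import ast
from typing import Optional


def try_fold_if(node: ast.If, namespace: dict) -> Optional[ast.If]:
    """Return `node` with a constant test, or None when it should be wrapped."""
    try:
        expression = ast.Expression(body=node.test)
        # Evaluate against a copy so eval() does not add __builtins__ to namespace.
        value = bool(eval(compile(expression, "<if-test>", "eval"), dict(namespace)))
    except Exception:
        return None  # e.g. NameError for runtime-only values: pattern (1)

    # Mirror the original rule: force a wrap if the `if` contains a walrus assignment.
    if any(isinstance(child, ast.NamedExpr) for child in ast.walk(node)):
        return None

    node.test = ast.Constant(value=value)
    return node
```

A config flag such as use_cache folds to a constant, while a condition on a tensor that only exists at runtime raises during evaluation and is therefore wrapped.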

visit_Name

visit_Name(node: Name)

Adds any newly assigned names to self._local_names, which are used to determine function wrapper arguments

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Name(self, node: ast.Name):
    """
    Add any new names in `self._local_names`,
    which are used to determine function wrapper arguments
    """
    if isinstance(node.ctx, ast.Store):
        self._local_names.add(node.id)

    return super().generic_visit(node)
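visit_Name and visit_Delete together keep a running set of local names: a Name in Store context adds it, and a del statement removes it. A minimal standalone sketch, where LocalNameTracker is a hypothetical name and discard is used instead of the original's remove so that deleting an untracked name does not raise:

```python
import ast


class LocalNameTracker(ast.NodeVisitor):
    def __init__(self):
        self.local_names = set()

    def visit_Name(self, node: ast.Name):
        # Only assignment targets (Store context) introduce new local names.
        if isinstance(node.ctx, ast.Store):
            self.local_names.add(node.id)
        self.generic_visit(node)

    def visit_Delete(self, node: ast.Delete):
        # Targets of `del` are visited with Del context, which visit_Name ignores.
        self.generic_visit(node)
        for target in node.targets:
            if isinstance(target, ast.Name):
                self.local_names.discard(target.id)
```

Tracking "a = 1; b = a; c, d = b, a; del b" leaves a, c and d as local names.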

visit_Tuple

visit_Tuple(node: Tuple) -> Union[ast.Tuple, ast.Call]

(3) Wraps any tuples that use starred unpacking

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Tuple(self, node: ast.Tuple) -> Union[ast.Tuple, ast.Call]:
    """
    (3) Wrap any tuples which use starred unpacking
    """
    if any(isinstance(elem, ast.Starred) for elem in node.elts):
        return self._wrap_if_possible(node)

    return super().generic_visit(node)
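The structural check in visit_Tuple can be tried on plain source text. starred_tuple_lines is a hypothetical helper; note that an unpacking target such as first, *rest = y is also an ast.Tuple with a Starred element (in Store context), so a purely structural check like this one matches it as well.

```python
import ast
from typing import List


def starred_tuple_lines(source: str) -> List[int]:
    """Return the line numbers of tuples that contain starred elements."""
    tree = ast.parse(source)
    return sorted(
        node.lineno
        for node in ast.walk(tree)
        if isinstance(node, ast.Tuple)
        and any(isinstance(elt, ast.Starred) for elt in node.elts)
    )
```

For the three lines "x = (*a, b)", "y = (a, b)" and "first, *rest = y", it reports lines 1 and 3.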

ControlFlowAnalyzer

Bases: NodeVisitor

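Determines whether a piece of code can be wrapped into a function. Code cannot be wrapped if it contains any return, continue, break, await, or yield statements outside their proper context.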

For example, this code can be wrapped

while True:
    if some_condition:
        break

whereas the inner code cannot be wrapped without its while context

def wrapped():
    if some_condition:
        break  # this control statement is now invalid

while True:
    wrapped()
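The rule illustrated above can be sketched as a small visitor that tracks loop and function nesting. SimpleControlFlowChecker is hypothetical and simplified: it does not model class bodies, and it treats decorators and default values as part of the function. It is not the actual ControlFlowAnalyzer implementation.

```python
import ast


class SimpleControlFlowChecker(ast.NodeVisitor):
    """Simplified validity check; ignores class bodies and other edge cases."""

    def is_valid(self, node: ast.AST) -> bool:
        self._loop_depth = 0  # enclosing loops within the current function
        self._func_depth = 0  # enclosing function definitions
        self._valid = True
        self.visit(node)
        return self._valid

    def _visit_loop(self, node):
        # The loop header (target/iter/test) is evaluated outside the loop body.
        for field in ("target", "iter", "test"):
            child = getattr(node, field, None)
            if child is not None:
                self.visit(child)
        self._loop_depth += 1
        for stmt in node.body:
            self.visit(stmt)
        self._loop_depth -= 1
        # A loop's `else` clause does not count as inside the loop.
        for stmt in node.orelse:
            self.visit(stmt)

    visit_For = visit_AsyncFor = visit_While = _visit_loop

    def _visit_function(self, node):
        saved_loop_depth = self._loop_depth
        self._loop_depth = 0  # loops outside a def do not enclose its body
        self._func_depth += 1
        self.generic_visit(node)
        self._func_depth -= 1
        self._loop_depth = saved_loop_depth

    visit_FunctionDef = visit_AsyncFunctionDef = visit_Lambda = _visit_function

    def visit_Break(self, node):
        if self._loop_depth == 0:
            self._valid = False

    visit_Continue = visit_Break

    def _visit_function_only(self, node):
        # return / yield / await are only valid inside a function body.
        if self._func_depth == 0:
            self._valid = False
        self.generic_visit(node)

    visit_Return = visit_Yield = visit_YieldFrom = visit_Await = _visit_function_only
```

ast.parse accepts a stray break or return because those errors are only raised at compile time, which is what makes this kind of pre-check possible.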

Methods

  • is_valid

    Returns False if the node contains control statements that are not in their proper control flow context

is_valid

is_valid(node: AST) -> bool

Returns False if the node contains control statements that are not in their proper control flow context

Parameters

  • node

    (AST) –

    Code to analyze

Returns

  • bool

    True if and only if the code does not contain invalid control statements

Source code in llmcompressor/pipelines/sequential/ast_utils/control_flow_analyzer.py
def is_valid(self, node: ast.AST) -> bool:
    """
    Returns False if a node contains control statements that are not in their
    proper control flow context

    :param node: code to analyze
    :return: True iff the code does not contain invalid control statements
    """
    self._contexts = []
    self._is_valid = True
    self.visit(node)
    return self._is_valid

NameAnalyzer

NameAnalyzer(omit: Set[str])

Bases: NodeVisitor

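Determines the unbound, assigned, and conditionally assigned names in a piece of code. This information is used to determine the arguments and return values of the wrapper function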

For example, in the following code snippet

b = a + 1
if some_condition:
    c = 5

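  • a is unbound, meaning it must be an input of the wrapper function
  • b is assigned, meaning it must be an output of the wrapper function
  • c is conditionally assigned, meaning it must be an output of the wrapper function, and may also be an input if c already exists in the namespace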

Note that names which are assigned before being read are not considered unbound

a = 2  # no longer unbound
b = a + 1
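The three categories above can be reproduced with a small visitor. SimpleNameAnalyzer is a hypothetical, simplified sketch that only understands straight-line Assign and If statements (no loops, augmented assignment, or comprehensions); it is not the actual NameAnalyzer. Unlike the prose example, it also reports some_condition as unbound, since that name is read but never assigned.

```python
import ast
import builtins
from typing import Iterable, Set, Tuple


class SimpleNameAnalyzer(ast.NodeVisitor):
    """Simplified sketch: understands straight-line Assign and If only."""

    def __init__(self, omit: Iterable[str] = ()):
        # Builtins and any extra omitted names are never treated as inputs.
        self._omit = set(builtins.__dict__) | set(omit)

    def analyze(self, node: ast.AST) -> Tuple[Set[str], Set[str], Set[str]]:
        self._unbound: Set[str] = set()
        self._assigned: Set[str] = set()
        self._conditional: Set[str] = set()
        self._branch_depth = 0
        self.visit(node)
        return self._unbound, self._assigned, self._conditional

    def visit_Assign(self, node: ast.Assign):
        self.visit(node.value)  # the right-hand side is read before targets are bound
        for target in node.targets:
            self.visit(target)

    def visit_If(self, node: ast.If):
        self.visit(node.test)  # the condition is always evaluated
        self._branch_depth += 1
        for stmt in node.body + node.orelse:
            self.visit(stmt)
        self._branch_depth -= 1

    def visit_Name(self, node: ast.Name):
        name = node.id
        if name in self._omit:
            return
        if isinstance(node.ctx, ast.Load):
            if name not in self._assigned and name not in self._conditional:
                self._unbound.add(name)  # read before any assignment: an input
        elif isinstance(node.ctx, ast.Store):
            if self._branch_depth > 0:
                if name not in self._assigned:
                    self._conditional.add(name)  # bound only on some paths
            else:
                self._assigned.add(name)
                self._conditional.discard(name)  # now bound on every path
```

Analyzing the first example yields unbound {a, some_condition}, assigned {b} and conditionally assigned {c}; analyzing the second yields no unbound names.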

Methods

  • analyze

    Analyzes the use of names in the given piece of code

Source code in llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py
def __init__(self, omit: Set[str]):
    self._omit = builtins.__dict__.keys() | omit

analyze

analyze(node: AST) -> Tuple[Set[str], Set[str], Set[str]]

Analyzes the use of names in the given piece of code

Parameters

  • node

    (AST) –

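    Code to analyze

Returns

  • Tuple[Set[str], Set[str], Set[str]]

    Tuple of unbound names, assigned names, and conditionally assigned names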

Source code in llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py
def analyze(self, node: ast.AST) -> Tuple[Set[str], Set[str], Set[str]]:
    """
    Analyzes the use of names in the given piece of code

    :param node: code to analyze
    :return: tuple of unbound names, assigned names, and conditionally assigned names
    """
    self._unbound = set()
    self._assigned = set()
    self._conditionally_assigned = set()
    self.visit(node)

    return self._unbound, self._assigned, self._conditionally_assigned
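To connect the returned tuple to the wrapper signature described in the class docstring, here is a hedged sketch. wrapper_signature and render_wrapper are hypothetical helpers, not part of llm-compressor, and the input rule for conditionally assigned names follows the prose above (an input only when the name already exists in the namespace).

```python
from typing import List, Set, Tuple


def wrapper_signature(
    unbound: Set[str],
    assigned: Set[str],
    conditionally_assigned: Set[str],
    namespace_names: Set[str],
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: derive wrapper inputs and outputs from analyze()."""
    # Conditionally assigned names are inputs only if they already exist, so the
    # wrapper can hand back the old value when the branch is not taken.
    inputs = unbound | (conditionally_assigned & namespace_names)
    outputs = assigned | conditionally_assigned
    return sorted(inputs), sorted(outputs)


def render_wrapper(name: str, body: str, inputs: List[str], outputs: List[str]) -> str:
    """Hypothetical helper: emit the wrapper function as source text."""
    lines = [f"def {name}({', '.join(inputs)}):"]
    lines += ["    " + line for line in body.splitlines()]
    returned = f"({', '.join(outputs)},)" if outputs else "None"
    lines.append(f"    return {returned}")
    return "\n".join(lines)
```

For the NameAnalyzer example with c already defined, this produces def wrapped_0(a, c, some_condition) returning (b, c). If c were neither passed in nor assigned, returning it would raise UnboundLocalError when the branch is skipped, which is why the namespace check matters.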