跳到内容

llmcompressor.modifiers.utils.hooks

  • HooksMixin

    用于管理钩子注册、禁用和移除的 Mixin。

HooksMixin

基类:BaseModel

用于管理钩子注册、禁用和移除的 Mixin。Modifiers 应使用 self.register_hook(module, hook, hook_type) 进行钩子注册,并使用 self.remove_hooks() 进行移除。

实现钩子的 Modifiers 应使用 self.register_..._hook(module, hook) 进行注册,而不是通常的 module.register_..._hook(hook)。Modifiers 应使用 self.remove_hooks() 进行移除。

钩子可以应用于模块或参数。

典型示例

modifier.register_forward_hook(module, hook) with HooksMixin.disable_hooks(): model.forward(...) modifier.remove_hooks()

激活特定钩子子集的示例

hooks = [modifier.register_forward_hook(module, hook) for module in ...] with HooksMixin.disable_hooks(keep=hooks): model.forward(...) modifier.remove_hooks(hooks)

方法

  • disable_hooks

    禁用所有 Modifiers 中的所有钩子。组合多个上下文等同于

  • register_hook

    在指定的模块/参数上注册一个钩子,并可以选择禁用它。

  • remove_hooks

    移除此 Modifier 注册的钩子。

disable_hooks classmethod

disable_hooks(keep: Set[RemovableHandle] = frozenset())

禁用所有 Modifiers 中的所有钩子。组合多个上下文等同于 keep 参数的并集。

参数

  • keep

    (Set[RemovableHandle], default: frozenset() ) –

    可选的要保持启用的句柄集合。

源文件位于 llmcompressor/modifiers/utils/hooks.py
@classmethod
@contextlib.contextmanager
def disable_hooks(cls, keep: Set[RemovableHandle] = frozenset()):
    """
    Disable all hooks across all modifiers. Composing multiple contexts is
    equivalent to the union of `keep` arguments

    :param keep: optional set of handles to keep enabled
    """
    try:
        cls._HOOKS_DISABLED = True
        cls._HOOKS_KEEP_ENABLED |= keep
        yield
    finally:
        cls._HOOKS_DISABLED = False
        cls._HOOKS_KEEP_ENABLED -= keep

register_hook

register_hook(
    target: Union[Module, Parameter],
    hook: Callable[[Any], Any],
    hook_type: str,
    **kwargs,
) -> RemovableHandle

在指定的模块/参数上注册一个钩子,并可以选择使用 HooksMixin.disable_hooks() 禁用它。

参数

  • target

    (Union[Module, Parameter]) –

    应注册钩子的模块或参数。

  • hook

    (Callable[[Any], Any]) –

    要注册的钩子。

  • hook_type

    (str) –

    要注册的钩子类型,对应于 torch.nn.Module 上的 register_{hook_type}_hook 属性。例如:"forward", "forward_pre", "full_backward", "state_dict_post", ""

  • kwargs

    要传递给 register hook 方法的关键字参数。

源文件位于 llmcompressor/modifiers/utils/hooks.py
def register_hook(
    self,
    target: Union[torch.nn.Module, torch.nn.Parameter],
    hook: Callable[[Any], Any],
    hook_type: str,
    **kwargs,
) -> RemovableHandle:
    """
    Registers a hook on a specified module/parameter with the option to disable it
    with HooksMixin.disable_hooks()

    :param target: the module or parameter on which the hook should be registered
    :param hook: the hook to register
    :param hook_type: the type of hook to register corresponding to the
        `register_{hook_type}_hook` attribute on torch.nn.Module.
        Ex. "forward", "forward_pre", "full_backward", "state_dict_post", ""
    :param kwargs: keyword arguments to pass to register hook method
    """
    handle = None

    @wraps(hook)
    def wrapped_hook(*args, **kwargs):
        nonlocal handle

        if (
            HooksMixin._HOOKS_DISABLED
            and handle not in HooksMixin._HOOKS_KEEP_ENABLED
        ):
            return

        return hook(*args, **kwargs)

    register_function = self._get_register_function(target, hook_type)
    handle = register_function(wrapped_hook, **kwargs)
    self._hooks.add(handle)
    logger.debug(f"{self} added {handle}")

    return handle

remove_hooks

remove_hooks(
    handles: Optional[Set[RemovableHandle]] = None,
)

移除此 Modifier 注册的钩子。

参数

  • handles

    (Optional[Set[RemovableHandle]], default: None ) –

    可选的要移除的句柄列表,默认为此 Modifier 注册的所有钩子。

源文件位于 llmcompressor/modifiers/utils/hooks.py
def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
    """
    Removes hooks registered by this modifier

    :param handles: optional list of handles to remove, defaults to all hooks
        registerd by this modifier
    """
    if handles is None:
        handles = self._hooks

    for hook in handles:
        hook.remove()

    self._hooks -= handles