跳到内容

llmcompressor.modifiers

用于应用各种优化技术的压缩修改器。

提供核心修改器系统,用于将量化、剪枝、蒸馏和其他优化方法等压缩技术应用于神经网络。包括基础类、工厂模式和用于可扩展压缩工作流的接口。

模块

Modifier

基类: ModifierInterface, HooksMixin

所有修改器继承的基类。修改器用于修改模型的训练过程。定义了所有修改器可用的基本属性和方法。

生命周期: 1. 初始化 2. on_event -> * 如果 self.start <= event.current_index 则 on_start * 如果 self.end >= event.current_index 则 on_end 5. 最终化

参数

  • index

    修改器在模型修改器列表中的索引

  • group

    修改器的组名

  • start

    修改器的起始步长

  • 结束

    修改器的结束步长

  • 更新

    修改器的更新步长

方法

  • finalize

    为给定模型和状态最终化修改器。

  • initialize

    为给定模型和状态初始化修改器。

  • on_end

    on_end 在修改器结束时调用,必须实现。

  • on_event

    on_event 在事件触发时调用。

  • on_finalize

    on_finalize 在修改器最终化时调用。

  • on_initialize

    on_initialize 在修改器初始化时调用。

  • on_start

    on_start 在修改器开始时调用。

  • on_update

    on_update 在模型需要更新时调用。

  • should_end

    :param event: 检查修改器是否应该结束的事件

  • should_start

    :param event: 检查修改器是否应该开始的事件

  • update_event

    根据给定事件更新修改器。进而调用

属性

  • finalized (bool) –

    :return: 如果修饰符已最终确定,则为 True

  • initialized (bool) –

    :return: 如果修饰符已初始化,则为 True

finalized property

finalized: bool

返回

  • bool

    如果修饰符已最终确定,则为 True

initialized property

initialized: bool

返回

  • bool

    如果修饰符已初始化,则为 True

finalize

finalize(state: State, **kwargs)

为给定模型和状态最终化修改器。

参数

  • state

    (State) –

    模型的当前状态

  • kwargs

    最终化修改器的附加参数

引发

  • RuntimeError

    如果修改器未初始化

源代码位于 llmcompressor/modifiers/modifier.py
def finalize(self, state: State, **kwargs):
    """
    Finalize the modifier for the given model and state.

    :raises RuntimeError: if the modifier has not been initialized
    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    """
    if self.finalized_:
        raise RuntimeError("cannot finalize a modifier twice")

    if not self.initialized_:
        raise RuntimeError("cannot finalize an uninitialized modifier")

    # TODO: all finalization should succeed
    self.finalized_ = self.on_finalize(state=state, **kwargs)

initialize

initialize(state: State, **kwargs)

为给定模型和状态初始化修改器。

参数

  • state

    (State) –

    模型的当前状态

  • kwargs

    初始化修改器的附加参数

引发

  • RuntimeError

    如果修改器已最终化

源代码位于 llmcompressor/modifiers/modifier.py
def initialize(self, state: State, **kwargs):
    """
    Initialize the modifier for the given model and state.

    :raises RuntimeError: if the modifier has already been finalized
    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    """
    if self.initialized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been initialized"
        )

    if self.finalized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been finalized"
        )

    self.initialized_ = self.on_initialize(state=state, **kwargs)

    # trigger starts
    fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
    if self.should_start(fake_start_event):
        self.on_start(state, fake_start_event, **kwargs)
        self.started_ = True

on_end

on_end(state: State, event: Event, **kwargs)

on_end 在修改器结束时调用,必须由继承的修改器实现。

参数

  • state

    (State) –

    模型的当前状态

  • event

    (Event) –

    触发结束的事件

  • kwargs

    结束修改器的附加参数

源代码位于 llmcompressor/modifiers/modifier.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    on_end is called when the modifier ends and must be implemented
    by the inheriting modifier.

    :param state: The current state of the model
    :param event: The event that triggered the end
    :param kwargs: Additional arguments for ending the modifier
    """
    pass

on_event

on_event(state: State, event: Event, **kwargs)

on_event 在事件触发时调用。

参数

  • state

    (State) –

    模型的当前状态

  • event

    (Event) –

    触发更新的事件

  • kwargs

    更新模型的附加参数

源代码位于 llmcompressor/modifiers/modifier.py
def on_event(self, state: State, event: Event, **kwargs):
    """
    on_event is called whenever an event is triggered

    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    pass

on_finalize

on_finalize(state: State, **kwargs) -> bool

on_finalize 在修改器最终化时调用,必须由继承的修改器实现。

参数

  • state

    (State) –

    模型的当前状态

  • kwargs

    最终化修改器的附加参数

返回

  • bool

    如果修改器成功最终化则为 True,否则为 False

源代码位于 llmcompressor/modifiers/modifier.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    on_finalize is called on modifier finalization and
    must be implemented by the inheriting modifier.

    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    :return: True if the modifier was finalized successfully,
        False otherwise
    """
    return True

on_initialize abstractmethod

on_initialize(state: State, **kwargs) -> bool

on_initialize 在修改器初始化时调用,必须由继承的修改器实现。

参数

  • state

    (State) –

    模型的当前状态

  • kwargs

    初始化修改器的附加参数

返回

  • bool

    如果修改器成功初始化则为 True,否则为 False

源代码位于 llmcompressor/modifiers/modifier.py
@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    on_initialize is called on modifier initialization and
    must be implemented by the inheriting modifier.

    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    :return: True if the modifier was initialized successfully,
        False otherwise
    """
    raise NotImplementedError()

on_start

on_start(state: State, event: Event, **kwargs)

on_start 在修改器开始时调用,必须由继承的修改器实现。

参数

  • state

    (State) –

    模型的当前状态

  • event

    (Event) –

    触发开始的事件

  • kwargs

    开始修改器的附加参数

源代码位于 llmcompressor/modifiers/modifier.py
def on_start(self, state: State, event: Event, **kwargs):
    """
    on_start is called when the modifier starts and
    must be implemented by the inheriting modifier.

    :param state: The current state of the model
    :param event: The event that triggered the start
    :param kwargs: Additional arguments for starting the modifier
    """
    pass

on_update

on_update(state: State, event: Event, **kwargs)

on_update 在相关模型必须根据传入事件进行更新时调用。必须由继承的修改器实现。

参数

  • state

    (State) –

    模型的当前状态

  • event

    (Event) –

    触发更新的事件

  • kwargs

    更新模型的附加参数

源代码位于 llmcompressor/modifiers/modifier.py
def on_update(self, state: State, event: Event, **kwargs):
    """
    on_update is called when the model in question must be
    updated based on passed in event. Must be implemented by the
    inheriting modifier.

    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    pass

should_end

should_end(event: Event)

参数

  • event

    (Event) –

    检查修改器是否应该结束的事件

返回

  • 如果修改器应根据给定事件结束,则为 True

源代码位于 llmcompressor/modifiers/modifier.py
def should_end(self, event: Event):
    """
    :param event: The event to check if the modifier should end
    :return: True if the modifier should end based on the given event
    """
    current = event.current_index

    return self.end is not None and current >= self.end

should_start

should_start(event: Event) -> bool

参数

  • event

    (Event) –

    检查修改器是否应该开始的事件

返回

  • bool

    如果修改器应根据给定事件开始,则为 True

源代码位于 llmcompressor/modifiers/modifier.py
def should_start(self, event: Event) -> bool:
    """
    :param event: The event to check if the modifier should start
    :return: True if the modifier should start based on the given event
    """
    if self.start is None:
        return False

    current = event.current_index

    return self.start <= current and (self.end is None or current < self.end)

update_event

update_event(state: State, event: Event, **kwargs)

根据给定事件更新修改器。进而根据事件和修改器设置调用 on_start、on_update 和 on_end。如果修改器未初始化则立即返回。

参数

  • state

    (State) –

    稀疏化的当前状态

  • event

    (Event) –

    用于更新修饰符的事件

  • kwargs

    更新修改器的附加参数

引发

  • RuntimeError

    如果修改器已最终化

源代码位于 llmcompressor/modifiers/modifier.py
def update_event(self, state: State, event: Event, **kwargs):
    """
    Update modifier based on the given event. In turn calls
    on_start, on_update, and on_end based on the event and
    modifier settings. Returns immediately if the modifier is
    not initialized

    :raises RuntimeError: if the modifier has been finalized
    :param state: The current state of sparsification
    :param event: The event to update the modifier with
    :param kwargs: Additional arguments for updating the modifier
    """
    if not self.initialized_:
        raise RuntimeError("Cannot update an uninitialized modifier")

    if self.finalized_:
        raise RuntimeError("Cannot update a finalized modifier")

    self.on_event(state, event, **kwargs)

    # handle starting the modifier if needed
    if (
        event.type_ == EventType.BATCH_START
        and not self.started_
        and self.should_start(event)
    ):
        self.on_start(state, event, **kwargs)
        self.started_ = True
        self.on_update(state, event, **kwargs)

        return

    # handle ending the modifier if needed
    if (
        event.type_ == EventType.BATCH_END
        and not self.ended_
        and self.should_end(event)
    ):
        self.on_end(state, event, **kwargs)
        self.ended_ = True
        self.on_update(state, event, **kwargs)

        return

    if self.started_ and not self.ended_:
        self.on_update(state, event, **kwargs)

ModifierFactory

用于加载和注册修改器的工厂

方法

  • create

    从已注册的修改器中实例化给定类型的修改器。

  • load_from_package

    :param package_path: 要从中加载修改器的包的路径

  • refresh

    一个通过重新加载修改器来刷新工厂的方法

  • register

    注册一个修改器类供工厂使用。

create staticmethod

create(
    type_: str,
    allow_registered: bool,
    allow_experimental: bool,
    **kwargs,
) -> Modifier

从已注册的修改器中实例化给定类型的修改器。

参数

  • 类型_

    (str) –

    要创建的修改器的类型

  • framework

    修改器适用的框架

  • allow_registered

    (bool) –

    是否允许已注册的修改器

  • allow_experimental

    (bool) –

    是否允许实验性修改器

  • kwargs

    在实例化期间传递给修改器的其他关键字参数

返回

引发

  • ValueError

    如果找不到给定类型的修改器

源代码在 llmcompressor/modifiers/factory.py
@staticmethod
def create(
    type_: str,
    allow_registered: bool,
    allow_experimental: bool,
    **kwargs,
) -> Modifier:
    """
    Instantiate a modifier of the given type from registered modifiers.

    :raises ValueError: If no modifier of the given type is found
    :param type_: The type of modifier to create
    :param framework: The framework the modifier is for
    :param allow_registered: Whether or not to allow registered modifiers
    :param allow_experimental: Whether or not to allow experimental modifiers
    :param kwargs: Additional keyword arguments to pass to the modifier
        during instantiation
    :return: The instantiated modifier
    """
    if type_ in ModifierFactory._errors:
        raise ModifierFactory._errors[type_]

    if type_ in ModifierFactory._registered_registry:
        if allow_registered:
            return ModifierFactory._registered_registry[type_](**kwargs)
        else:
            # TODO: log warning that modifier was skipped
            pass

    if type_ in ModifierFactory._experimental_registry:
        if allow_experimental:
            return ModifierFactory._experimental_registry[type_](**kwargs)
        else:
            # TODO: log warning that modifier was skipped
            pass

    if type_ in ModifierFactory._main_registry:
        return ModifierFactory._main_registry[type_](**kwargs)

    raise ValueError(f"No modifier of type '{type_}' found.")

load_from_package staticmethod

load_from_package(
    package_path: str,
) -> Dict[str, Type[Modifier]]

参数

  • package_path

    (str) –

    要从中加载修改器的包的路径

返回

  • Dict[str, Type[Modifier]]

    已加载的修改器,作为名称到类的映射

源代码在 llmcompressor/modifiers/factory.py
@staticmethod
def load_from_package(package_path: str) -> Dict[str, Type[Modifier]]:
    """
    :param package_path: The path to the package to load modifiers from
    :return: The loaded modifiers, as a mapping of name to class
    """
    loaded = {}
    main_package = importlib.import_module(package_path)

    # exclude deprecated packages from registry so
    # their new location is used instead
    deprecated_packages = [
        "llmcompressor.modifiers.obcq",
        "llmcompressor.modifiers.obcq.sgpt_base",
    ]
    for _importer, modname, _is_pkg in pkgutil.walk_packages(
        main_package.__path__, package_path + "."
    ):
        if modname in deprecated_packages:
            continue
        try:
            module = importlib.import_module(modname)

            for attribute_name in dir(module):
                if not attribute_name.endswith("Modifier"):
                    continue

                try:
                    if attribute_name in loaded:
                        continue

                    attr = getattr(module, attribute_name)

                    if not isinstance(attr, type):
                        raise ValueError(
                            f"Attribute {attribute_name} is not a type"
                        )

                    if not issubclass(attr, Modifier):
                        raise ValueError(
                            f"Attribute {attribute_name} is not a Modifier"
                        )

                    loaded[attribute_name] = attr
                except Exception as err:
                    # TODO: log import error
                    ModifierFactory._errors[attribute_name] = err
        except Exception as module_err:
            # TODO: log import error
            print(module_err)

    return loaded

refresh staticmethod

refresh()

一个通过重新加载修改器来刷新工厂的方法。注意:这将清除所有先前注册的修改器。

源代码在 llmcompressor/modifiers/factory.py
@staticmethod
def refresh():
    """
    A method to refresh the factory by reloading the modifiers
    Note: this will clear any previously registered modifiers
    """
    ModifierFactory._main_registry = ModifierFactory.load_from_package(
        ModifierFactory._MAIN_PACKAGE_PATH
    )
    ModifierFactory._experimental_registry = ModifierFactory.load_from_package(
        ModifierFactory._EXPERIMENTAL_PACKAGE_PATH
    )
    ModifierFactory._loaded = True

register staticmethod

register(type_: str, modifier_class: Type[Modifier])

注册一个修改器类供工厂使用。

参数

  • 类型_

    (str) –

    要注册的修改器的类型

  • modifier_class

    (Type[Modifier]) –

    要注册的修改器的类,必须继承自 Modifier 基类

引发

  • ValueError

    如果提供的类未继承自 Modifier 基类或不是一个类型

源代码在 llmcompressor/modifiers/factory.py
@staticmethod
def register(type_: str, modifier_class: Type[Modifier]):
    """
    Register a modifier class to be used by the factory.

    :raises ValueError: If the provided class does not subclass the Modifier
        base class or is not a type
    :param type_: The type of modifier to register
    :param modifier_class: The class of the modifier to register, must subclass
        the Modifier base class
    """
    if not issubclass(modifier_class, Modifier):
        raise ValueError(
            "The provided class does not subclass the Modifier base class."
        )
    if not isinstance(modifier_class, type):
        raise ValueError("The provided class is not a type.")

    ModifierFactory._registered_registry[type_] = modifier_class

ModifierInterface

基类:ABC

定义所有修改器必须实现的契约

方法

属性

  • finalized (bool) –

    :return: 如果修饰符已最终确定,则为 True

  • initialized (bool) –

    :return: 如果修饰符已初始化,则为 True

finalized abstractmethod property

finalized: bool

返回

  • bool

    如果修饰符已最终确定,则为 True

initialized abstractmethod property

initialized: bool

返回

  • bool

    如果修饰符已初始化,则为 True

finalize abstractmethod

finalize(state: State, **kwargs)

完成修改器

参数

  • state

    (State) –

    模型的当前状态

  • kwargs

    用于修改器最终化的附加关键字参数

源代码在 llmcompressor/modifiers/interface.py
@abstractmethod
def finalize(self, state: State, **kwargs):
    """
    Finalize the modifier

    :param state: The current state of the model
    :param kwargs: Additional keyword arguments for
        modifier finalization
    """
    raise NotImplementedError()

initialize abstractmethod

initialize(state: State, **kwargs)

初始化修改器

参数

  • state

    (State) –

    模型的当前状态

  • kwargs

    用于修改器初始化的附加关键字参数

源代码在 llmcompressor/modifiers/interface.py
@abstractmethod
def initialize(self, state: State, **kwargs):
    """
    Initialize the modifier

    :param state: The current state of the model
    :param kwargs: Additional keyword arguments
        for modifier initialization
    """
    raise NotImplementedError()

update_event abstractmethod

update_event(state: State, event: Event, **kwargs)

根据事件更新修改器

参数

  • state

    (State) –

    模型的当前状态

  • event

    (Event) –

    用于更新修饰符的事件

  • kwargs

    用于修改器更新的附加关键字参数

源代码在 llmcompressor/modifiers/interface.py
@abstractmethod
def update_event(self, state: State, event: Event, **kwargs):
    """
    Update the modifier based on the event

    :param state: The current state of the model
    :param event: The event to update the modifier with
    :param kwargs: Additional keyword arguments for
        modifier update
    """
    raise NotImplementedError()