跳到内容

llmcompressor.modifiers.modifier

  • Modifier

    所有修改器继承的基类。

Modifier

Bases: ModifierInterface, HooksMixin

所有修改器继承的基类。修改器用于修改模型的训练过程。定义了所有修改器可用的基本属性和方法。

生命周期: 1. 初始化 2. on_event -> * 如果 self.start <= event.current_index 则 on_start * 如果 self.end >= event.current_index 则 on_end 5. 最终化

参数

  • index

    修改器在模型修改器列表中的索引

  • group

    修改器的组名

  • start

    修改器的起始步长

  • 结束

    修改器的结束步长

  • 更新

    修改器的更新步长

方法

  • finalize

    为给定模型和状态最终化修改器。

  • initialize

    为给定模型和状态初始化修改器。

  • on_end

    on_end 在修改器结束时调用,必须实现。

  • on_event

    on_event 在事件触发时调用。

  • on_finalize

    on_finalize 在修改器最终化时调用。

  • on_initialize

    on_initialize 在修改器初始化时调用。

  • on_start

    on_start 在修改器开始时调用。

  • on_update

    on_update 在模型需要更新时调用。

  • should_end

    :param event: 检查修改器是否应该结束的事件

  • should_start

    :param event: 检查修改器是否应该开始的事件

  • update_event

    根据给定事件更新修改器。进而调用

属性

  • finalized (bool) –

    :return: 如果修饰符已最终确定,则为 True

  • initialized (bool) –

    :return: 如果修饰符已初始化,则为 True

finalized property

finalized: bool

返回

  • bool

    如果修饰符已最终确定,则为 True

initialized property

initialized: bool

返回

  • bool

    如果修饰符已初始化,则为 True

finalize

finalize(state: State, **kwargs)

为给定模型和状态最终化修改器。

参数

  • state

    (State) –

    模型的当前状态

  • kwargs

    最终化修改器的附加参数

引发

  • RuntimeError

    如果修改器未初始化

源代码位于 llmcompressor/modifiers/modifier.py
def finalize(self, state: State, **kwargs):
    """
    Finalize the modifier for the given model and state.

    :raises RuntimeError: if the modifier has not been initialized
    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    """
    if self.finalized_:
        raise RuntimeError("cannot finalize a modifier twice")

    if not self.initialized_:
        raise RuntimeError("cannot finalize an uninitialized modifier")

    # TODO: all finalization should succeed
    self.finalized_ = self.on_finalize(state=state, **kwargs)

initialize

initialize(state: State, **kwargs)

为给定模型和状态初始化修改器。

参数

  • state

    (State) –

    模型的当前状态

  • kwargs

    初始化修改器的附加参数

引发

  • RuntimeError

    如果修改器已最终化

源代码位于 llmcompressor/modifiers/modifier.py
def initialize(self, state: State, **kwargs):
    """
    Initialize the modifier for the given model and state.

    :raises RuntimeError: if the modifier has already been finalized
    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    """
    if self.initialized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been initialized"
        )

    if self.finalized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been finalized"
        )

    self.initialized_ = self.on_initialize(state=state, **kwargs)

    # trigger starts
    fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
    if self.should_start(fake_start_event):
        self.on_start(state, fake_start_event, **kwargs)
        self.started_ = True

on_end

on_end(state: State, event: Event, **kwargs)

on_end 在修改器结束时调用,必须由继承的修改器实现。

参数

  • state

    (State) –

    模型的当前状态

  • event

    (Event) –

    触发结束的事件

  • kwargs

    结束修改器的附加参数

源代码位于 llmcompressor/modifiers/modifier.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    on_end is called when the modifier ends and must be implemented
    by the inheriting modifier.

    :param state: The current state of the model
    :param event: The event that triggered the end
    :param kwargs: Additional arguments for ending the modifier
    """
    pass

on_event

on_event(state: State, event: Event, **kwargs)

on_event 在事件触发时调用。

参数

  • state

    (State) –

    模型的当前状态

  • event

    (Event) –

    触发更新的事件

  • kwargs

    更新模型的附加参数

源代码位于 llmcompressor/modifiers/modifier.py
def on_event(self, state: State, event: Event, **kwargs):
    """
    on_event is called whenever an event is triggered

    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    pass

on_finalize

on_finalize(state: State, **kwargs) -> bool

on_finalize 在修改器最终化时调用,必须由继承的修改器实现。

参数

  • state

    (State) –

    模型的当前状态

  • kwargs

    最终化修改器的附加参数

返回

  • bool

    如果修改器成功最终化则为 True,否则为 False

源代码位于 llmcompressor/modifiers/modifier.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    on_finalize is called on modifier finalization and
    must be implemented by the inheriting modifier.

    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    :return: True if the modifier was finalized successfully,
        False otherwise
    """
    return True

on_initialize abstractmethod

on_initialize(state: State, **kwargs) -> bool

on_initialize 在修改器初始化时调用,必须由继承的修改器实现。

参数

  • state

    (State) –

    模型的当前状态

  • kwargs

    初始化修改器的附加参数

返回

  • bool

    如果修改器成功初始化则为 True,否则为 False

源代码位于 llmcompressor/modifiers/modifier.py
@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    on_initialize is called on modifier initialization and
    must be implemented by the inheriting modifier.

    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    :return: True if the modifier was initialized successfully,
        False otherwise
    """
    raise NotImplementedError()

on_start

on_start(state: State, event: Event, **kwargs)

on_start 在修改器开始时调用,必须由继承的修改器实现。

参数

  • state

    (State) –

    模型的当前状态

  • event

    (Event) –

    触发开始的事件

  • kwargs

    开始修改器的附加参数

源代码位于 llmcompressor/modifiers/modifier.py
def on_start(self, state: State, event: Event, **kwargs):
    """
    on_start is called when the modifier starts and
    must be implemented by the inheriting modifier.

    :param state: The current state of the model
    :param event: The event that triggered the start
    :param kwargs: Additional arguments for starting the modifier
    """
    pass

on_update

on_update(state: State, event: Event, **kwargs)

on_update 在相关模型必须根据传入事件进行更新时调用。必须由继承的修改器实现。

参数

  • state

    (State) –

    模型的当前状态

  • event

    (Event) –

    触发更新的事件

  • kwargs

    更新模型的附加参数

源代码位于 llmcompressor/modifiers/modifier.py
def on_update(self, state: State, event: Event, **kwargs):
    """
    on_update is called when the model in question must be
    updated based on passed in event. Must be implemented by the
    inheriting modifier.

    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    pass

should_end

should_end(event: Event)

参数

  • event

    (Event) –

    检查修改器是否应该结束的事件

返回

  • 如果修改器应根据给定事件结束,则为 True

源代码位于 llmcompressor/modifiers/modifier.py
def should_end(self, event: Event):
    """
    :param event: The event to check if the modifier should end
    :return: True if the modifier should end based on the given event
    """
    current = event.current_index

    return self.end is not None and current >= self.end

should_start

should_start(event: Event) -> bool

参数

  • event

    (Event) –

    检查修改器是否应该开始的事件

返回

  • bool

    如果修改器应根据给定事件开始,则为 True

源代码位于 llmcompressor/modifiers/modifier.py
def should_start(self, event: Event) -> bool:
    """
    :param event: The event to check if the modifier should start
    :return: True if the modifier should start based on the given event
    """
    if self.start is None:
        return False

    current = event.current_index

    return self.start <= current and (self.end is None or current < self.end)

update_event

update_event(state: State, event: Event, **kwargs)

根据给定事件更新修改器。进而根据事件和修改器设置调用 on_start、on_update 和 on_end。如果修改器未初始化则立即返回。

参数

  • state

    (State) –

    稀疏化的当前状态

  • event

    (Event) –

    用于更新修饰符的事件

  • kwargs

    更新修改器的附加参数

引发

  • RuntimeError

    如果修改器已最终化

源代码位于 llmcompressor/modifiers/modifier.py
def update_event(self, state: State, event: Event, **kwargs):
    """
    Update modifier based on the given event. In turn calls
    on_start, on_update, and on_end based on the event and
    modifier settings. Returns immediately if the modifier is
    not initialized

    :raises RuntimeError: if the modifier has been finalized
    :param state: The current state of sparsification
    :param event: The event to update the modifier with
    :param kwargs: Additional arguments for updating the modifier
    """
    if not self.initialized_:
        raise RuntimeError("Cannot update an uninitialized modifier")

    if self.finalized_:
        raise RuntimeError("Cannot update a finalized modifier")

    self.on_event(state, event, **kwargs)

    # handle starting the modifier if needed
    if (
        event.type_ == EventType.BATCH_START
        and not self.started_
        and self.should_start(event)
    ):
        self.on_start(state, event, **kwargs)
        self.started_ = True
        self.on_update(state, event, **kwargs)

        return

    # handle ending the modifier if needed
    if (
        event.type_ == EventType.BATCH_END
        and not self.ended_
        and self.should_end(event)
    ):
        self.on_end(state, event, **kwargs)
        self.ended_ = True
        self.on_update(state, event, **kwargs)

        return

    if self.started_ and not self.ended_:
        self.on_update(state, event, **kwargs)