Bases: ModifierInterface, HooksMixin
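The base class that all modifiers inherit from. Modifiers change how a model is trained. This class defines the basic attributes and methods available to every modifier.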
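Lifecycle:
1. initialize
2. update_event, once per event: on_event is always called first, then
   * on_start, once, on a BATCH_START event where self.start <= event.current_index (and event.current_index < self.end if end is set)
   * on_end, once, on a BATCH_END event where event.current_index >= self.end
   * on_update on the start and end events and on every event in between
3. finalize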
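Parameters
- index
- group
- start – event index at which the modifier starts (see should_start); if None, no event starts it
- end – event index at which the modifier ends (see should_end); if None, no event ends it
- update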
Methods
finalize
finalize(state: State, **kwargs)
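Finalize the modifier for the given model and state. The return value of on_finalize is stored in finalized_.
Parameters
- state – the current state of the model
- kwargs – additional arguments for finalizing the modifier
Raises
- RuntimeError – if the modifier has already been finalized or has not been initialized
Source code in llmcompressor/modifiers/modifier.py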
```python
def finalize(self, state: State, **kwargs):
    """
    Finalize the modifier for the given model and state.
    :raises RuntimeError: if the modifier has not been initialized
    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    """
    if self.finalized_:
        raise RuntimeError("cannot finalize a modifier twice")
    if not self.initialized_:
        raise RuntimeError("cannot finalize an uninitialized modifier")
    # TODO: all finalization should succeed
    self.finalized_ = self.on_finalize(state=state, **kwargs)
```
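The order of the two guards in finalize can be checked without installing llmcompressor. The sketch below uses a hypothetical GuardedLifecycle class (not a library API) that copies only the flag checks from initialize and finalize, with on_initialize and on_finalize stubbed to return True.

```python
class GuardedLifecycle:
    """Hypothetical stand-in that mirrors the flag checks of Modifier."""

    def __init__(self):
        self.initialized_ = False
        self.finalized_ = False

    def on_initialize(self, state, **kwargs) -> bool:
        return True  # stub: pretend initialization succeeded

    def on_finalize(self, state, **kwargs) -> bool:
        return True  # stub: pretend finalization succeeded

    def initialize(self, state, **kwargs):
        if self.initialized_:
            raise RuntimeError("Cannot initialize a modifier that has already been initialized")
        if self.finalized_:
            raise RuntimeError("Cannot initialize a modifier that has already been finalized")
        self.initialized_ = self.on_initialize(state=state, **kwargs)

    def finalize(self, state, **kwargs):
        # the "finalized twice" check runs before the "uninitialized" check
        if self.finalized_:
            raise RuntimeError("cannot finalize a modifier twice")
        if not self.initialized_:
            raise RuntimeError("cannot finalize an uninitialized modifier")
        self.finalized_ = self.on_finalize(state=state, **kwargs)


def try_call(fn):
    """Return the RuntimeError message, or None if fn() succeeded."""
    try:
        fn()
        return None
    except RuntimeError as err:
        return str(err)


m = GuardedLifecycle()
print(try_call(lambda: m.finalize(state=None)))  # cannot finalize an uninitialized modifier
m.initialize(state=None)
print(try_call(lambda: m.finalize(state=None)))  # None
print(try_call(lambda: m.finalize(state=None)))  # cannot finalize a modifier twice
```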
initialize
initialize(state: State, **kwargs)
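Initialize the modifier for the given model and state. After on_initialize runs, a synthetic BATCH_START event at global_step 0 is checked with should_start; if it passes, on_start is called and the modifier is marked as started.
Parameters
- state – the current state of the model
- kwargs – additional arguments for initializing the modifier
Raises
- RuntimeError – if the modifier has already been initialized or has already been finalized
Source code in llmcompressor/modifiers/modifier.py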
```python
def initialize(self, state: State, **kwargs):
    """
    Initialize the modifier for the given model and state.
    :raises RuntimeError: if the modifier has already been finalized
    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    """
    if self.initialized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been initialized"
        )
    if self.finalized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been finalized"
        )
    self.initialized_ = self.on_initialize(state=state, **kwargs)
    # trigger starts
    fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
    if self.should_start(fake_start_event):
        self.on_start(state, fake_start_event, **kwargs)
        self.started_ = True
```
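The effect of the synthetic start event can be sketched in isolation. Everything below is a hypothetical stand-in: EventType and FakeEvent replace the library's Event machinery, and FakeEvent.current_index is assumed to equal global_step, which may not match how the real Event computes its index. Only should_start and the fake-event logic from initialize are reproduced.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class EventType(Enum):
    # hypothetical stand-in for llmcompressor's EventType
    BATCH_START = auto()
    BATCH_END = auto()


@dataclass
class FakeEvent:
    type_: EventType
    global_step: int

    @property
    def current_index(self) -> int:
        # assumption: the index compared against start/end equals the global step
        return self.global_step


class StartOnInitSketch:
    """Hypothetical mirror of the start-triggering part of Modifier.initialize."""

    def __init__(self, start: Optional[float] = None, end: Optional[float] = None):
        self.start, self.end = start, end
        self.initialized_ = False
        self.started_ = False
        self.calls = []

    def should_start(self, event: FakeEvent) -> bool:
        if self.start is None:
            return False
        current = event.current_index
        return self.start <= current and (self.end is None or current < self.end)

    def on_initialize(self, state, **kwargs) -> bool:
        self.calls.append("on_initialize")
        return True

    def on_start(self, state, event, **kwargs):
        self.calls.append(f"on_start@{event.current_index}")

    def initialize(self, state, **kwargs):
        self.initialized_ = self.on_initialize(state=state, **kwargs)
        # synthetic BATCH_START at step 0: modifiers whose window includes 0 start now
        fake_start_event = FakeEvent(type_=EventType.BATCH_START, global_step=0)
        if self.should_start(fake_start_event):
            self.on_start(state, fake_start_event, **kwargs)
            self.started_ = True


eager = StartOnInitSketch(start=0)
eager.initialize(state=None)
print(eager.calls, eager.started_)  # ['on_initialize', 'on_start@0'] True

deferred = StartOnInitSketch(start=5)
deferred.initialize(state=None)
print(deferred.calls, deferred.started_)  # ['on_initialize'] False
```

Note that initialize only calls on_start, never on_update, so the first on_update for a modifier started this way comes from a later update_event call.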
on_end
on_end(state: State, event: Event, **kwargs)
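on_end is called when the modifier ends and is meant to be implemented by the inheriting modifier; the base implementation does nothing.
Parameters
- state – the current state of the model
- event – the event that triggered the end
- kwargs – additional arguments for ending the modifier
Source code in llmcompressor/modifiers/modifier.py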
```python
def on_end(self, state: State, event: Event, **kwargs):
    """
    on_end is called when the modifier ends and must be implemented
    by the inheriting modifier.
    :param state: The current state of the model
    :param event: The event that triggered the end
    :param kwargs: Additional arguments for ending the modifier
    """
    pass
```
on_event
on_event(state: State, event: Event, **kwargs)
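on_event is called for every event passed to update_event, before any start, update, or end handling. The base implementation does nothing.
Parameters
- state – the current state of the model
- event – the event that triggered the update
- kwargs – additional arguments for updating the model
Source code in llmcompressor/modifiers/modifier.py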
```python
def on_event(self, state: State, event: Event, **kwargs):
    """
    on_event is called whenever an event is triggered
    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    pass
```
on_finalize
on_finalize(state: State, **kwargs) -> bool
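on_finalize is called during finalization and is meant to be implemented by the inheriting modifier; the base implementation simply returns True. finalize stores its return value in finalized_.
Parameters
- state – the current state of the model
- kwargs – additional arguments for finalizing the modifier
Returns
- bool – True if the modifier was finalized successfully, False otherwise
Source code in llmcompressor/modifiers/modifier.py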
```python
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    on_finalize is called on modifier finalization and
    must be implemented by the inheriting modifier.
    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    :return: True if the modifier was finalized successfully,
        False otherwise
    """
    return True
```
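Because finalize assigns the return value of on_finalize to finalized_, returning False leaves the modifier un-finalized, so the "cannot finalize a modifier twice" guard does not block a later attempt. The hypothetical RetryableFinalize class below (not a library API) reproduces only finalize's flag handling, with initialization assumed to have already succeeded.

```python
class RetryableFinalize:
    """Hypothetical stand-in reproducing Modifier.finalize's flag handling."""

    def __init__(self, failures_before_success: int):
        self.initialized_ = True  # assume initialize() already succeeded
        self.finalized_ = False
        self.remaining_failures = failures_before_success

    def on_finalize(self, state, **kwargs) -> bool:
        # report failure until the configured number of attempts has been used up
        if self.remaining_failures > 0:
            self.remaining_failures -= 1
            return False
        return True

    def finalize(self, state, **kwargs):
        if self.finalized_:
            raise RuntimeError("cannot finalize a modifier twice")
        if not self.initialized_:
            raise RuntimeError("cannot finalize an uninitialized modifier")
        self.finalized_ = self.on_finalize(state=state, **kwargs)


m = RetryableFinalize(failures_before_success=1)
m.finalize(state=None)
print(m.finalized_)  # False: the "twice" guard will not fire on the next call
m.finalize(state=None)
print(m.finalized_)  # True
```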
on_initialize abstractmethod
on_initialize(state: State, **kwargs) -> bool
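on_initialize is called during initialization and must be implemented by the inheriting modifier (it is an abstract method). initialize stores its return value in initialized_.
Parameters
- state – the current state of the model
- kwargs – additional arguments for initializing the modifier
Returns
- bool – True if the modifier was initialized successfully, False otherwise
Source code in llmcompressor/modifiers/modifier.py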
```python
@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    on_initialize is called on modifier initialization and
    must be implemented by the inheriting modifier.
    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    :return: True if the modifier was initialized successfully,
        False otherwise
    """
    raise NotImplementedError()
```
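A minimal sketch of the subclassing contract, assuming the base class's metaclass enforces @abstractmethod the way abc.ABC does. ModifierSketch and RecordingModifier are hypothetical names, not llmcompressor classes: the base cannot be instantiated until on_initialize is overridden, while on_finalize keeps its default of returning True.

```python
from abc import ABC, abstractmethod


class ModifierSketch(ABC):
    """Hypothetical, trimmed-down base: only on_initialize is abstract."""

    def __init__(self):
        self.initialized_ = False

    @abstractmethod
    def on_initialize(self, state, **kwargs) -> bool:
        raise NotImplementedError()

    def on_finalize(self, state, **kwargs) -> bool:
        return True  # optional hook: the default simply reports success

    def initialize(self, state, **kwargs):
        self.initialized_ = self.on_initialize(state=state, **kwargs)


class RecordingModifier(ModifierSketch):
    """Hypothetical subclass: records the state it was initialized with."""

    def on_initialize(self, state, **kwargs) -> bool:
        self.seen_state = state
        return True


try:
    ModifierSketch()
except TypeError as err:
    print("abstract base cannot be instantiated:", err)

mod = RecordingModifier()
mod.initialize(state={"model": "dummy"})
print(mod.initialized_, mod.seen_state)  # True {'model': 'dummy'}
```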
on_start
on_start(state: State, event: Event, **kwargs)
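on_start is called once when the modifier starts, either from initialize or from the first qualifying BATCH_START event in update_event. It is meant to be implemented by the inheriting modifier; the base implementation does nothing.
Parameters
- state – the current state of the model
- event – the event that triggered the start
- kwargs – additional arguments for starting the modifier
Source code in llmcompressor/modifiers/modifier.py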
```python
def on_start(self, state: State, event: Event, **kwargs):
    """
    on_start is called when the modifier starts and
    must be implemented by the inheriting modifier.
    :param state: The current state of the model
    :param event: The event that triggered the start
    :param kwargs: Additional arguments for starting the modifier
    """
    pass
```
on_update
on_update(state: State, event: Event, **kwargs)
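on_update is called when the model must be updated based on the passed-in event. update_event calls it on the event that starts the modifier, on the event that ends it, and on every other event while the modifier has started but not yet ended. It is meant to be implemented by the inheriting modifier; the base implementation does nothing.
Parameters
- state – the current state of the model
- event – the event that triggered the update
- kwargs – additional arguments for updating the model
Source code in llmcompressor/modifiers/modifier.py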
```python
def on_update(self, state: State, event: Event, **kwargs):
    """
    on_update is called when the model in question must be
    updated based on passed in event. Must be implemented by the
    inheriting modifier.
    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    pass
```
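The three no-op hooks on_start, on_update, and on_end are where an inheriting modifier puts its behavior. As a purely hypothetical illustration (LinearRampSketch is not a llmcompressor modifier, and FakeEvent carries only current_index), the sketch below ramps a value from 0.0 to 1.0 across the window. The hooks are driven by hand in a simplified loop that keeps update_event's ordering (on_start or on_end before on_update) but ignores event types.

```python
from dataclasses import dataclass


@dataclass
class FakeEvent:
    # hypothetical stand-in: only the field the hooks below read
    current_index: int


class LinearRampSketch:
    """Hypothetical modifier that ramps `value` from 0.0 to 1.0 between start and end."""

    def __init__(self, start: int, end: int):
        self.start, self.end = start, end
        self.value = 0.0

    def on_start(self, state, event, **kwargs):
        self.value = 0.0  # reset at the beginning of the active window

    def on_update(self, state, event, **kwargs):
        # fraction of the [start, end] window elapsed, clamped to [0, 1]
        progress = (event.current_index - self.start) / (self.end - self.start)
        self.value = min(max(progress, 0.0), 1.0)

    def on_end(self, state, event, **kwargs):
        self.value = 1.0  # pin the final value once the window closes


ramp = LinearRampSketch(start=2, end=6)
history = []
for step in range(2, 7):
    event = FakeEvent(step)
    # simplified driver: on_start/on_end run before on_update, as in update_event
    if step == ramp.start:
        ramp.on_start(None, event)
    if step == ramp.end:
        ramp.on_end(None, event)
    ramp.on_update(None, event)
    history.append(ramp.value)
print(history)  # [0.0, 0.25, 0.5, 0.75, 1.0]
```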
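should_end
should_end(event: Event)
Parameters
- event – the event to check
Returns
- bool – True if the modifier should end based on the given event, i.e. end is set and event.current_index >= end
Source code in llmcompressor/modifiers/modifier.py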
```python
def should_end(self, event: Event):
    """
    :param event: The event to check if the modifier should end
    :return: True if the modifier should end based on the given event
    """
    current = event.current_index
    return self.end is not None and current >= self.end
```
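The end condition is inclusive: the modifier should end as soon as the index reaches end, and a None end never triggers. The hypothetical free function below (not a library API) copies the condition with plain values in place of an Event.

```python
from typing import Optional


def should_end(end: Optional[float], current: float) -> bool:
    # same condition as Modifier.should_end, with plain values instead of an Event
    return end is not None and current >= end


for end, current in [(None, 100), (5, 4), (5, 5), (5, 6)]:
    print(f"end={end}, current_index={current}: {should_end(end, current)}")
```

This prints False, False, True, True for the four cases.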
should_start
should_start(event: Event) -> bool
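Parameters
- event – the event to check
Returns
- bool – True if start is set, start <= event.current_index, and either end is None or event.current_index < end
Source code in llmcompressor/modifiers/modifier.py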
```python
def should_start(self, event: Event) -> bool:
    """
    :param event: The event to check if the modifier should start
    :return: True if the modifier should start based on the given event
    """
    if self.start is None:
        return False
    current = event.current_index
    return self.start <= current and (self.end is None or current < self.end)
```
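The start check accepts indices in the half-open window [start, end), and a None start never starts the modifier from an event. The hypothetical free function below (not a library API) copies the condition with plain values in place of an Event.

```python
from typing import Optional


def should_start(start: Optional[float], end: Optional[float], current: float) -> bool:
    # same condition as Modifier.should_start, with plain values instead of an Event
    if start is None:
        return False
    return start <= current and (end is None or current < end)


cases = [(None, None, 0), (2, None, 1), (2, None, 2), (2, 5, 4), (2, 5, 5)]
for start, end, current in cases:
    print(f"start={start}, end={end}, current_index={current}: "
          f"{should_start(start, end, current)}")
```

This prints False, False, True, True, False: an index equal to end is already outside the start window.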
update_event
update_event(state: State, event: Event, **kwargs)
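Update the modifier based on the given event. on_event is always called first; on_start, on_update, and on_end are then called according to the event type and the modifier's start and end settings. Raises RuntimeError if the modifier has not been initialized or has already been finalized.
Parameters
- state – the current state of sparsification
- event – the event to update the modifier with
- kwargs – additional arguments for updating the modifier
Raises
- RuntimeError – if the modifier has not been initialized or has already been finalized
Source code in llmcompressor/modifiers/modifier.py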
```python
def update_event(self, state: State, event: Event, **kwargs):
    """
    Update modifier based on the given event. In turn calls
    on_start, on_update, and on_end based on the event and
    modifier settings. Returns immediately if the modifier is
    not initialized
    :raises RuntimeError: if the modifier has been finalized
    :param state: The current state of sparsification
    :param event: The event to update the modifier with
    :param kwargs: Additional arguments for updating the modifier
    """
    if not self.initialized_:
        raise RuntimeError("Cannot update an uninitialized modifier")
    if self.finalized_:
        raise RuntimeError("Cannot update a finalized modifier")
    self.on_event(state, event, **kwargs)
    # handle starting the modifier if needed
    if (
        event.type_ == EventType.BATCH_START
        and not self.started_
        and self.should_start(event)
    ):
        self.on_start(state, event, **kwargs)
        self.started_ = True
        self.on_update(state, event, **kwargs)
        return
    # handle ending the modifier if needed
    if (
        event.type_ == EventType.BATCH_END
        and not self.ended_
        and self.should_end(event)
    ):
        self.on_end(state, event, **kwargs)
        self.ended_ = True
        self.on_update(state, event, **kwargs)
        return
    if self.started_ and not self.ended_:
        self.on_update(state, event, **kwargs)
```
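To see the dispatch order end to end, the sketch below copies update_event's control flow into a hypothetical DispatchSketch class whose hooks record their calls. EventType and FakeEvent are stand-ins for the library types (FakeEvent carries current_index directly), and the modifier is assumed to be initialized already. It feeds a BATCH_START and a BATCH_END event for each step 0 to 4 with start=1 and end=3.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional


class EventType(Enum):
    # hypothetical stand-in for llmcompressor's EventType
    BATCH_START = auto()
    BATCH_END = auto()


@dataclass
class FakeEvent:
    # hypothetical stand-in: carries only what update_event reads
    type_: EventType
    current_index: int


class DispatchSketch:
    """Hypothetical copy of Modifier.update_event's control flow with recording hooks."""

    def __init__(self, start: Optional[float], end: Optional[float]):
        self.start, self.end = start, end
        self.initialized_ = True  # assume initialize() already succeeded
        self.finalized_ = False
        self.started_ = False
        self.ended_ = False
        self.trace: List[str] = []

    def should_start(self, event: FakeEvent) -> bool:
        if self.start is None:
            return False
        current = event.current_index
        return self.start <= current and (self.end is None or current < self.end)

    def should_end(self, event: FakeEvent) -> bool:
        return self.end is not None and event.current_index >= self.end

    def on_event(self, state, event, **kwargs):
        pass  # called for every event; not recorded to keep the trace short

    def on_start(self, state, event, **kwargs):
        self.trace.append(f"start@{event.current_index}")

    def on_update(self, state, event, **kwargs):
        self.trace.append(f"update@{event.current_index}")

    def on_end(self, state, event, **kwargs):
        self.trace.append(f"end@{event.current_index}")

    def update_event(self, state, event: FakeEvent, **kwargs):
        if not self.initialized_:
            raise RuntimeError("Cannot update an uninitialized modifier")
        if self.finalized_:
            raise RuntimeError("Cannot update a finalized modifier")
        self.on_event(state, event, **kwargs)
        # start only on BATCH_START events
        if event.type_ == EventType.BATCH_START and not self.started_ and self.should_start(event):
            self.on_start(state, event, **kwargs)
            self.started_ = True
            self.on_update(state, event, **kwargs)
            return
        # end only on BATCH_END events
        if event.type_ == EventType.BATCH_END and not self.ended_ and self.should_end(event):
            self.on_end(state, event, **kwargs)
            self.ended_ = True
            self.on_update(state, event, **kwargs)
            return
        if self.started_ and not self.ended_:
            self.on_update(state, event, **kwargs)


mod = DispatchSketch(start=1, end=3)
for step in range(5):
    mod.update_event(None, FakeEvent(EventType.BATCH_START, step))
    mod.update_event(None, FakeEvent(EventType.BATCH_END, step))
print(mod.trace)
```

In this run on_update fires on both events of every step from 1 to 3, including the BATCH_START of step 3 (index equal to end), because the end check only applies to BATCH_END events. After on_end, only the no-op on_event runs.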