
llmcompressor.recipe.utils

Functions

append_recipe_dict

append_recipe_dict(d1: dict, d2: dict) -> dict

Merge two recipe dicts, renaming conflicting top-level stage keys to numbered versions.

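If both dicts have the same stage key (e.g. 'test_stage'), the result contains 'test_stage_0', 'test_stage_1', and so on. Keys that do not conflict are copied unchanged.

Once a conflict occurs, numbering always starts from 0, so the first occurrence is renamed as well ('test_stage' becomes 'test_stage_0').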
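
The renaming rule can be tried out with a standalone copy of the function. The sketch below mirrors the source listing for this function (so it runs without llmcompressor installed); the stage contents are placeholder dicts.

```python
import re


def append_recipe_dict(d1: dict, d2: dict) -> dict:
    # Standalone copy of the library logic, for experimentation only.
    result = dict(d1)
    for key, val in d2.items():
        if key not in result:
            result[key] = val
            continue
        base_key = re.sub(r"_\d+$", "", key)
        if key == base_key:
            # Bare key collides: number both entries starting from 0
            result[f"{base_key}_0"] = result.pop(key)
            result[f"{base_key}_1"] = val
        else:
            # Suffixed key collides: take the next free index
            i = 1
            while f"{base_key}_{i}" in result:
                i += 1
            result[f"{base_key}_{i}"] = val
    return result


d1 = {"test_stage": {"id": 1}, "other_stage": {"id": 2}}
d2 = {"test_stage": {"id": 3}}
merged = append_recipe_dict(d1, d2)
print(merged)
# {'other_stage': {'id': 2}, 'test_stage_0': {'id': 1}, 'test_stage_1': {'id': 3}}
```

Because the original entry is popped and re-inserted under its numbered name, it moves after the other keys copied from d1.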

Source code in llmcompressor/recipe/utils.py
def append_recipe_dict(d1: dict, d2: dict) -> dict:
    """
    Merge two recipe dicts by renaming top-level stage keys to numbered versions.

    If both have the same stage key (e.g. 'test_stage'), the result will contain:
        'test_stage_0', 'test_stage_1', etc.

    Always starts numbering from 0 even for the first occurrence.
    """
    result = dict(d1)
    for key, val in d2.items():
        if key not in result:
            result[key] = val
        else:
            # Stage key conflict — apply suffixes to both entries
            base_key = re.sub(r"_\d+$", "", key)

            # Rename original if not yet renamed
            if key == base_key:
                result[f"{base_key}_0"] = result.pop(key)
                result[f"{base_key}_1"] = val
            else:
                # Key was already suffixed, find next free index
                i = 1
                while f"{base_key}_{i}" in result:
                    i += 1
                result[f"{base_key}_{i}"] = val
    return result
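
When a colliding key from d2 already carries a numeric suffix, the else branch takes over and picks the next free index, starting from 1. The sketch below repeats the function logic so it runs on its own; the placeholder values 'A', 'B', 'C' are arbitrary. The second call shows a consequence of the code as written: a bare key is renamed only when it collides exactly, so merging 'test_stage' into a dict that holds only numbered stages leaves it unsuffixed.

```python
import re


def append_recipe_dict(d1: dict, d2: dict) -> dict:
    # Standalone copy of the library logic listed above.
    result = dict(d1)
    for key, val in d2.items():
        if key not in result:
            result[key] = val
            continue
        base_key = re.sub(r"_\d+$", "", key)
        if key == base_key:
            result[f"{base_key}_0"] = result.pop(key)
            result[f"{base_key}_1"] = val
        else:
            # Key already suffixed: scan upward from 1 for an unused index
            i = 1
            while f"{base_key}_{i}" in result:
                i += 1
            result[f"{base_key}_{i}"] = val
    return result


already_numbered = {"test_stage_0": "A", "test_stage_1": "B"}

# A colliding suffixed key gets the next free index.
suffixed = append_recipe_dict(already_numbered, {"test_stage_1": "C"})
print(suffixed)  # {'test_stage_0': 'A', 'test_stage_1': 'B', 'test_stage_2': 'C'}

# A bare key does not collide with 'test_stage_0'/'test_stage_1', so it is kept as-is.
bare = append_recipe_dict(already_numbered, {"test_stage": "C"})
print(bare)  # {'test_stage_0': 'A', 'test_stage_1': 'B', 'test_stage': 'C'}
```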

filter_dict

filter_dict(
    obj: dict, target_stage: Optional[str] = None
) -> dict

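Filter a recipe dictionary to only the keys that start with the target stage name. Because matching is by prefix, numbered stages such as 'test_stage_0' and 'test_stage_1' also match 'test_stage'.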

Parameters

  • obj

    (dict) –

    The recipe dictionary to filter.

  • target_stage

    (Optional[str], default: None) –

    The stage to filter by (e.g. 'test_stage').

Returns

  • dict

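    A new dictionary containing only the keys that start with target_stage. If target_stage is None or empty, obj itself is returned unchanged (not a copy).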
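
Because matching uses str.startswith, a single stage name selects all of its numbered variants. A minimal sketch, copying the function so it runs standalone, with a hypothetical recipe dict:

```python
from typing import Optional


def filter_dict(obj: dict, target_stage: Optional[str] = None) -> dict:
    # Standalone copy of the library logic.
    if not target_stage:
        return obj  # no filtering: the same object is returned
    return {k: v for k, v in obj.items() if k.startswith(target_stage)}


recipe = {
    "test_stage_0": {"id": 1},
    "test_stage_1": {"id": 2},
    "other_stage": {"id": 3},
}

subset = filter_dict(recipe, "test_stage")
print(subset)  # {'test_stage_0': {'id': 1}, 'test_stage_1': {'id': 2}}
print(filter_dict(recipe) is recipe)  # True
```

Since the unfiltered case returns obj itself, mutating that result also mutates the input recipe.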

Source code in llmcompressor/recipe/utils.py
def filter_dict(obj: dict, target_stage: Optional[str] = None) -> dict:
    """
    Filter a dictionary to only include keys that match the target stage.

    :param obj: The recipe dictionary to filter.
    :param target_stage: The stage to filter by (e.g., 'test_stage').
    :return: A new dictionary containing only the keys that match the target stage.
    """
    if not target_stage:
        return obj
    return {k: v for k, v in obj.items() if k.startswith(target_stage)}
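
The two helpers compose: merge recipes with append_recipe_dict, then pull one logical stage back out with filter_dict. This is a sketch using copies of both functions; the stage names 'calib_stage' and 'finetune_stage' are hypothetical, and the pairing is an illustration rather than a documented library workflow.

```python
import re
from typing import Optional


def append_recipe_dict(d1: dict, d2: dict) -> dict:
    # Mirrors the append_recipe_dict source listed earlier.
    result = dict(d1)
    for key, val in d2.items():
        if key not in result:
            result[key] = val
            continue
        base_key = re.sub(r"_\d+$", "", key)
        if key == base_key:
            result[f"{base_key}_0"] = result.pop(key)
            result[f"{base_key}_1"] = val
        else:
            i = 1
            while f"{base_key}_{i}" in result:
                i += 1
            result[f"{base_key}_{i}"] = val
    return result


def filter_dict(obj: dict, target_stage: Optional[str] = None) -> dict:
    # Mirrors the filter_dict source listed above.
    if not target_stage:
        return obj
    return {k: v for k, v in obj.items() if k.startswith(target_stage)}


base = {"calib_stage": {"id": 1}, "finetune_stage": {"id": 2}}
extra = {"calib_stage": {"id": 3}}

merged = append_recipe_dict(base, extra)      # calib_stage -> calib_stage_0 / _1
calib_only = filter_dict(merged, "calib_stage")
print(calib_only)  # {'calib_stage_0': {'id': 1}, 'calib_stage_1': {'id': 3}}
```

One caveat of prefix matching: any key that merely starts with the target string, such as a hypothetical 'calib_stage_extra', would be selected too.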

get_yaml_serializable_dict

get_yaml_serializable_dict(
    modifiers: List[Modifier], stage: str
) -> Dict[str, Any]

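Converts a list of modifiers into a nested dictionary that matches the recipe structure used during YAML serialization, where each stage, modifier group, and modifier is represented as a valid YAML dictionary. The result has a single top-level key, '<stage>_stage'. Its value maps each group name ('<group>_modifiers') to a dictionary whose keys are modifier types (class names) and whose values are the modifier arguments. Arguments whose value is None, whose name ends with '_', and the 'group' field itself are omitted.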

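Note: This function assumes that a modifier group never contains the same modifier type more than once; the Recipe.create_instance(...) method makes the same assumption. If a type is repeated within a group, the last modifier of that type overwrites the earlier ones in the output.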
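
The overwrite behavior follows from keying each group's dictionary by class name. The sketch below isolates that keying step only; DemoModifier and collect_group are hypothetical names, not part of llmcompressor.

```python
from typing import Any, Dict, List


class DemoModifier:
    """Hypothetical stand-in for a Modifier; its class name becomes the key."""

    def __init__(self, **args: Any):
        self.args = args


def collect_group(modifiers: List[DemoModifier]) -> Dict[str, Dict[str, Any]]:
    # Same keying scheme as the inner loop of get_yaml_serializable_dict:
    # one entry per modifier *type*, so a repeated type replaces the earlier one.
    group: Dict[str, Dict[str, Any]] = {}
    for m in modifiers:
        group[m.__class__.__name__] = m.args
    return group


group = collect_group([DemoModifier(sparsity=0.5), DemoModifier(sparsity=0.7)])
print(group)  # {'DemoModifier': {'sparsity': 0.7}}
```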

Parameters

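  • modifiers

    (List[Modifier]) –

    The list of Modifier instances to serialize.

  • stage

    (str) –

    The stage name. The top-level key of the result is '<stage>_stage', and modifiers without a group are placed in the '<stage>_modifiers' group.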

Returns

  • Dict[str, Any]

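    A dictionary with a single top-level key, '<stage>_stage', whose value maps each group name ('<group>_modifiers') to a dictionary that maps modifier type (class name) to modifier args.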

Source code in llmcompressor/recipe/utils.py
def get_yaml_serializable_dict(modifiers: List[Modifier], stage: str) -> Dict[str, Any]:
    """
    This function is used to convert a list of modifiers into a dictionary
    where the keys are the group names and the values are the modifiers
    which in turn are dictionaries with the modifier type as the key and
    the modifier args as the value.
    This is needed to conform to our recipe structure during yaml serialization
    where each stage, modifier_groups, and modifiers are represented as
    valid yaml dictionaries.

    Note: This function assumes that modifier groups do not contain the same
    modifier type more than once in a group. This assumption is also held by
    Recipe.create_instance(...) method.

    :param modifiers: A list of dictionaries where each dictionary
        holds all information about a modifier
    :return: A dictionary where the keys are the group names and the values
        are the modifiers which in turn are dictionaries with the modifier
        type as the key and the modifier args as the value.
    """

    stage_dict = {}
    stage_name = stage + "_stage"
    stage_dict[stage_name] = {}
    for modifier in modifiers:
        group = getattr(modifier, "group", stage) or stage
        group_name = f"{group}_modifiers"
        modifier_type = modifier.__class__.__name__

        args = {
            k: v
            for k, v in modifier.model_dump().items()
            if v is not None and not k.endswith("_") and k != "group"
        }

        if group_name not in stage_dict[stage_name]:
            stage_dict[stage_name][group_name] = {}

        stage_dict[stage_name][group_name][modifier_type] = args

    return stage_dict
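
To make the nesting concrete, here is a sketch that mirrors the function's logic and runs it on stand-in objects. DemoModifier, DemoQuantModifier, DemoPruneModifier, and their fields are hypothetical; they imitate only the group attribute and model_dump() method the function relies on, not llmcompressor's real pydantic Modifier classes.

```python
from typing import Any, Dict, List, Optional


class DemoModifier:
    """Hypothetical stand-in: mimics only `group` and `model_dump()`."""

    def __init__(self, group: Optional[str] = None, **fields: Any):
        self.group = group
        self._fields = fields

    def model_dump(self) -> Dict[str, Any]:
        return {"group": self.group, **self._fields}


# Hypothetical modifier types; the class name becomes the YAML key.
class DemoQuantModifier(DemoModifier):
    pass


class DemoPruneModifier(DemoModifier):
    pass


def get_yaml_serializable_dict(modifiers: List[DemoModifier], stage: str) -> Dict[str, Any]:
    # Mirrors the library logic shown above.
    stage_dict: Dict[str, Any] = {}
    stage_name = stage + "_stage"
    stage_dict[stage_name] = {}
    for modifier in modifiers:
        group = getattr(modifier, "group", stage) or stage
        group_name = f"{group}_modifiers"
        modifier_type = modifier.__class__.__name__
        # Drop None values, names ending in "_", and the group field itself
        args = {
            k: v
            for k, v in modifier.model_dump().items()
            if v is not None and not k.endswith("_") and k != "group"
        }
        stage_dict[stage_name].setdefault(group_name, {})[modifier_type] = args
    return stage_dict


result = get_yaml_serializable_dict(
    [
        DemoQuantModifier(scheme="W4A16", ignore=None, hooks_=[]),
        DemoPruneModifier(group="sparsity", sparsity=0.5),
    ],
    stage="test",
)
print(result)
# (wrapped here for readability)
# {'test_stage': {'test_modifiers': {'DemoQuantModifier': {'scheme': 'W4A16'}},
#                 'sparsity_modifiers': {'DemoPruneModifier': {'sparsity': 0.5}}}}
```

The ungrouped modifier lands in 'test_modifiers' because its group is None, and the None-valued 'ignore' and underscore-suffixed 'hooks_' fields are filtered out before serialization.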