llmcompressor.utils

General utility functions used in LLM Compressor.

Modules

  • NumpyArrayBatcher

    Batcher instance to handle taking in dictionaries of numpy arrays, appending multiple items to them to increase batch size, and then stacking them into a single batched numpy array for all keys in the dicts.

Functions

  • DisableQuantization

    Disable quantization during forward passes after applying a quantization config.

  • bucket_iterable

    Bucket an iterable into a subarray consisting of the first top percentage of elements.

  • calibration_forward_context

    Context in which all calibration forward passes should occur.

  • clean_path

    :param path: the directory or file path to clean.

  • convert_to_bool

    :param val: the value to be converted to a bool.

  • create_dirs

    :param path: the directory path to try and create.

  • create_parent_dirs

    :param path: the file path to try to create the parent directories for.

  • create_unique_dir

    :param path: the path to create a unique version of.

  • disable_cache

    Temporarily disable the key-value cache for transformer models. Used to prevent excess memory use in one-shot cases where the model only performs the prefill phase and not the generation phase.

  • disable_hf_kernels

    In transformers>=4.50.0, some module forward methods may be replaced by calls to hf hub kernels. This has the potential to bypass hooks added by LLM Compressor.

  • disable_lm_head

    Disable the lm_head of a model by moving it to the meta device. This function does not untie parameters, and restores proper loading of the model upon exit.

  • dispatch_for_generation

    Dispatch a model for autoregressive generation. Modules are dispatched evenly across available devices and kept onloaded if possible. Removes any HF hooks that may have existed previously.

  • eval_context

    Disable PyTorch training mode for the given module.

  • flatten_iterable

    :param li: a possibly nested iterable of items to be flattened.

  • get_embeddings

    Returns the input and output embeddings of a model. Returns None for either if get_input_embeddings/get_output_embeddings is not implemented.

  • getattr_chain

    Chain multiple getattr calls, separated by `.`.

  • import_from_path

    Import a module and a function/class name separated by `:`.

  • interpolate

    Note, caps values at their min of x0 and max of x1.

  • interpolate_list_linear

    Interpolate linearly for input values within a list of measurements.

  • interpolated_integral

    Calculate the interpolated integral for a group of measurements of the form [(x0, y0), (x1, y1), ...].

  • is_package_available

    A helper function to check whether a package is available.

  • is_url

    :param val: value to check if it is a url or not.

  • json_to_jsonl

    Converts a json list file to the jsonl file format (used for sharding efficiency).

  • load_labeled_data

    Load labels and data from disk or from memory and group them together. Assumes sorted ordering on disk. Will match when a file glob is passed for data and/or labels.

  • load_numpy

    Load a numpy file into either an ndarray or an OrderedDict representing what was in the npz file.

  • patch_attr

    Patch the value of an object attribute. The original value is restored upon exit.

  • patch_transformers_logger_level

    Context under which the transformers logger's level is modified.

  • path_file_count

    Return the number of files that match the given pattern under the given path.

  • path_file_size

    Return the total size, in bytes, for a path on the file system.

  • save_numpy

    Save a numpy array or collection of numpy arrays to disk.

  • skip_weights_download

    Context under which models are initialized without downloading the model weight files.

  • targets_embeddings

    Returns whether the given targets target the word embeddings of the model.

  • tensor_export

    :param tensor: tensor to export to a saved numpy array file.

  • tensors_export

    :param tensors: the tensors to export to saved numpy array files.

  • untie_word_embeddings

    Untie word embeddings, if possible. This function raises a warning if embeddings cannot be found in the model definition.

  • validate_str_iterable

    :param val: the value to validate, checking that it is a list (and flattening it).

NumpyArrayBatcher

NumpyArrayBatcher()

Bases: object

Batcher instance to handle taking in dictionaries of numpy arrays, appending multiple items to them to increase batch size, and then stacking them into a single batched numpy array for all keys in the dicts.

Methods

  • append

    Append a new item into the current batch.

  • stack

    Stack the current items into a batch along a new, zeroed dimension.

Source code in llmcompressor/utils/helpers.py
def __init__(self):
    self._items = OrderedDict()  # type: Dict[str, List[numpy.ndarray]]

append

append(item: Union[ndarray, Dict[str, ndarray]])

Append a new item into the current batch. All keys and shapes must match the current state.

Parameters

  • item

    (Union[ndarray, Dict[str, ndarray]]) –

    the item to add for batching.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def append(self, item: Union[numpy.ndarray, Dict[str, numpy.ndarray]]):
    """
    Append a new item into the current batch.
    All keys and shapes must match the current state.

    :param item: the item to add for batching
    """
    if len(self) < 1 and isinstance(item, numpy.ndarray):
        self._items[NDARRAY_KEY] = [item]
    elif len(self) < 1:
        for key, val in item.items():
            self._items[key] = [val]
    elif isinstance(item, numpy.ndarray):
        if NDARRAY_KEY not in self._items:
            raise ValueError(
                "numpy ndarray passed for item, but prev_batch does not contain one"
            )

        if item.shape != self._items[NDARRAY_KEY][0].shape:
            raise ValueError(
                (
                    "item of numpy ndarray of shape {} does not "
                    "match the current batch shape of {}".format(
                        item.shape, self._items[NDARRAY_KEY][0].shape
                    )
                )
            )

        self._items[NDARRAY_KEY].append(item)
    else:
        diff_keys = list(set(item.keys()) - set(self._items.keys()))

        if len(diff_keys) > 0:
            raise ValueError(
                (
                    "numpy dict passed for item, not all keys match "
                    "with the prev_batch. difference: {}"
                ).format(diff_keys)
            )

        for key, val in item.items():
            if val.shape != self._items[key][0].shape:
                raise ValueError(
                    (
                        "item with key {} of shape {} does not "
                        "match the current batch shape of {}".format(
                            key, val.shape, self._items[key][0].shape
                        )
                    )
                )

            self._items[key].append(val)

stack

stack() -> Dict[str, numpy.ndarray]

Stack the current items into a batch along a new, zeroed dimension.

Returns

  • Dict[str, ndarray]

    the stacked items.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def stack(self) -> Dict[str, numpy.ndarray]:
    """
    Stack the current items into a batch along a new, zeroed dimension

    :return: the stacked items
    """
    batch_dict = OrderedDict()

    for key, val in self._items.items():
        batch_dict[key] = numpy.stack(self._items[key])

    return batch_dict
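
A minimal usage sketch of the batcher (note both methods are marked deprecated). When plain arrays are appended, the stacked result is keyed by the module-level NDARRAY_KEY constant:

import numpy

batcher = NumpyArrayBatcher()
batcher.append(numpy.zeros((3,)))  # first item sets the expected shape
batcher.append(numpy.ones((3,)))   # subsequent items must match that shape
batch = batcher.stack()            # dict mapping NDARRAY_KEY to an array of shape (2, 3)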

DisableQuantization

DisableQuantization(module: Module)

Disable quantization during forward passes after applying a quantization config.

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def DisableQuantization(module: torch.nn.Module):
    """
    Disable quantization during forward passes after applying a quantization config
    """
    try:
        module.apply(disable_quantization)
        yield
    finally:
        module.apply(enable_quantization)
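
A brief usage sketch, assuming model is a module with a quantization config already applied and calibration_batch is a prepared input dict:

with DisableQuantization(model):
    # forward passes in this block run without quantization
    outputs = model(**calibration_batch)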

bucket_iterable

bucket_iterable(
    val: Iterable[Any],
    num_buckets: int = 3,
    edge_percent: float = 0.05,
    sort_highest: bool = True,
    sort_key: Callable[[Any], Any] = None,
) -> List[Tuple[int, Any]]

Bucket an iterable into a subarray consisting of the first top percentage of elements, followed by the rest of the iterable sliced into equally sized groups.

Parameters

  • val

    (Iterable[Any]) –

    the iterable to bucket.

  • num_buckets

    (int, default: 3 ) –

    the number of buckets to group the iterable into, not including the top bucket.

  • edge_percent

    (float, default: 0.05 ) –

    group the first percent of elements into their own bucket. If sort_highest is True, this is the top percent, otherwise the bottom percent. If <= 0, no edge bucket is created.

  • sort_highest

    (bool, default: True ) –

    True to sort in descending order so the highest elements come first and buckets are created in descending order. False to sort in ascending order so the lowest elements come first and buckets are created in ascending order.

  • sort_key

    (Callable[[Any], Any], default: None ) –

    the sort_key, if any, to use for sorting the iterable after converting it to a list.

Returns

  • List[Tuple[int, Any]]

    a list of each value mapped to the bucket it was sorted into.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def bucket_iterable(
    val: Iterable[Any],
    num_buckets: int = 3,
    edge_percent: float = 0.05,
    sort_highest: bool = True,
    sort_key: Callable[[Any], Any] = None,
) -> List[Tuple[int, Any]]:
    """
    Bucket iterable into subarray consisting of the first top percentage
    followed by the rest of the iterable sliced into equal sliced groups.

    :param val: The iterable to bucket
    :param num_buckets: The number of buckets to group the iterable into,
        does not include the top bucket
    :param edge_percent: Group the first percent into its own bucket.
        If sort_highest, then this is the top percent, else bottom percent.
        If <= 0, then will not create an edge bucket
    :param sort_highest: True to sort such that the highest percent is first
        and will create buckets in descending order.
        False to sort so lowest is first and create buckets in ascending order.
    :param sort_key: The sort_key, if any, to use for sorting the iterable
        after converting it to a list
    :return: a list of each value mapped to the bucket it was sorted into
    """

    val_list = [v for v in val]
    val_list.sort(key=sort_key, reverse=sort_highest)
    bucketed_values = []
    edge_count = round(edge_percent * len(val_list))

    if edge_count > 0:
        bucketed_values.extend([(-1, val) for val in val_list[:edge_count]])
        val_list = val_list[edge_count:]

    buckets_count = round(len(val_list) / float(num_buckets))

    for bucket in range(num_buckets):
        add_vals = val_list[:buckets_count] if bucket < num_buckets - 1 else val_list
        val_list = val_list[buckets_count:] if bucket < num_buckets - 1 else []
        bucketed_values.extend([(bucket, val) for val in add_vals])

    return bucketed_values
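
A worked example based on the source above: with ten values and edge_percent=0.1, the single highest value lands in the edge bucket (index -1) and the remaining nine values are split evenly across three buckets:

values = [9, 1, 7, 3, 5, 2, 8, 4, 6, 0]
bucket_iterable(values, num_buckets=3, edge_percent=0.1)
# [(-1, 9), (0, 8), (0, 7), (0, 6), (1, 5), (1, 4), (1, 3), (2, 2), (2, 1), (2, 0)]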

calibration_forward_context

calibration_forward_context(model: Module)

Context in which all calibration forward passes should occur.

  • Remove gradient calculations.
  • Disable the KV cache.
  • Disable train mode and enable eval mode.
  • Disable hf kernels, which could bypass hooks.
  • Disable the lm head (inputs and weights can still be calibrated; outputs will be meta).
Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def calibration_forward_context(model: torch.nn.Module):
    """
    Context in which all calibration forward passes should occur.

    - Remove gradient calculations
    - Disable the KV cache
    - Disable train mode and enable eval mode
    - Disable hf kernels which could bypass hooks
    - Disable lm head (input and weights can still be calibrated, output will be meta)
    """
    with contextlib.ExitStack() as stack:
        stack.enter_context(torch.no_grad())
        stack.enter_context(disable_cache(model))
        stack.enter_context(eval_context(model))
        stack.enter_context(disable_hf_kernels(model))
        stack.enter_context(disable_lm_head(model))
        yield
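
A typical calibration-loop sketch; model and calibration_loader are placeholders for your own objects:

with calibration_forward_context(model):
    for batch in calibration_loader:
        model(**batch)  # runs without gradients, KV cache, train mode, or lm_head outputs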

clean_path

clean_path(path: str) -> str

Parameters

  • path

    (str) –

    the directory or file path to clean.

Returns

  • str

    a cleaned version that expands the user path and creates an absolute path.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def clean_path(path: str) -> str:
    """
    :param path: the directory or file path to clean
    :return: a cleaned version that expands the user path and creates an absolute path
    """
    return os.path.abspath(os.path.expanduser(path))

convert_to_bool

convert_to_bool(val: Any)

Parameters

  • val

    (Any) –

    the value to be converted to a bool; supports logical values as strings, e.g. True, t, false, 0.

Returns

  • the boolean representation of the value; if it can't be determined, falls back on returning True.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def convert_to_bool(val: Any):
    """
    :param val: the value to be converted to a bool,
        supports logical values as strings ie True, t, false, 0
    :return: the boolean representation of the value, if it can't be determined,
        falls back on returning True
    """
    return (
        bool(val)
        if not isinstance(val, str)
        else bool(val) and "f" not in val.lower() and "0" not in val.lower()
    )
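
A few illustrative calls, following the string logic in the source (a string containing "f" or "0" is treated as False; other non-empty strings fall back to True):

convert_to_bool(1)        # True
convert_to_bool("t")      # True
convert_to_bool("false")  # False ("f" present)
convert_to_bool("0")      # False ("0" present)
convert_to_bool("yes")    # True (fallback for unrecognized strings)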

create_dirs

create_dirs(path: str)

Parameters

  • path

    (str) –

    the directory path to try and create.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def create_dirs(path: str):
    """
    :param path: the directory path to try and create
    """
    path = clean_path(path)

    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno == errno.EEXIST:
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

create_parent_dirs

create_parent_dirs(path: str)

Parameters

  • path

    (str) –

    the file path to try to create the parent directories for.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def create_parent_dirs(path: str):
    """
    :param path: the file path to try to create the parent directories for
    """
    parent = os.path.dirname(path)
    create_dirs(parent)

create_unique_dir

create_unique_dir(path: str, check_number: int = 0) -> str

Parameters

  • path

    (str) –

    the file path to create a unique version of (numbers are appended until one doesn't exist).

  • check_number

    (int, default: 0 ) –

    the number to begin checking for unique versions at.

Returns

  • str

    the unique directory path.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def create_unique_dir(path: str, check_number: int = 0) -> str:
    """
    :param path: the file path to create a unique version of
        (append numbers until one doesn't exist)
    :param check_number: the number to begin checking for unique versions at
    :return: the unique directory path
    """
    check_path = clean_path("{}-{:04d}".format(path, check_number))

    if not os.path.exists(check_path):
        return check_path

    return create_unique_dir(path, check_number + 1)

disable_cache

disable_cache(module: Module)

Temporarily disable the key-value cache for transformer models. Used to prevent excess memory use in one-shot cases where the model only performs the prefill phase and not the generation phase.

Example:

>>> model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
>>> input = torch.randint(0, 32, size=(1, 32))
>>> with disable_cache(model):
...     output = model(input)

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def disable_cache(module: torch.nn.Module):
    """
    Temporarily disable the key-value cache for transformer models. Used to prevent
    excess memory use in one-shot cases where the model only performs the prefill
    phase and not the generation phase.

    Example:
    >>> model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    >>> input = torch.randint(0, 32, size=(1, 32))
    >>> with disable_cache(model):
    ...     output = model(input)
    """

    if isinstance(module, PreTrainedModel):
        config = module.config
        config = getattr(config, "text_config", config)
        with patch_attr(config, "use_cache", False):
            yield

    else:
        yield

disable_hf_kernels

disable_hf_kernels(module: Module)

In transformers>=4.50.0, some module forward methods may be replaced by calls to hf hub kernels. This has the potential to bypass hooks added by LLM Compressor.

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def disable_hf_kernels(module: torch.nn.Module):
    """
    In transformers>=4.50.0, some module forward methods may be
    replaced by calls to hf hub kernels. This has the potential
    to bypass hooks added by LLM Compressor
    """
    if isinstance(module, PreTrainedModel):
        with patch_attr(module.config, "disable_custom_kernels", True):
            yield

    else:
        yield

disable_lm_head

disable_lm_head(model: Module)

Disable the lm_head of a model by moving it to the meta device. This function does not untie parameters, and restores proper loading of the model upon exit.

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def disable_lm_head(model: torch.nn.Module):
    """
    Disable the lm_head of a model by moving it to the meta device. This function
    does not untie parameters and restores the model proper loading upon exit
    """
    _, lm_head = get_embeddings(model)
    if lm_head is None:
        logger.warning(
            f"Attempted to disable lm_head of instance {model.__class__.__name__}, "
            "but was unable to to find lm_head. This may lead to unexpected OOM."
        )
        yield
        return

    elif not isinstance(lm_head, torch.nn.Linear):
        logger.warning(f"Cannot disable LM head of type {lm_head.__class__.__name__}")
        yield
        return

    else:
        dummy_weight = lm_head.weight.to("meta")

        def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
            return input.to("meta") @ dummy_weight.T

        with contextlib.ExitStack() as stack:
            lm_head_forward = dummy_forward.__get__(lm_head)
            stack.enter_context(patch_attr(lm_head, "forward", lm_head_forward))

            if hasattr(model, "_hf_hook"):
                stack.enter_context(patch_attr(model._hf_hook, "io_same_device", False))

            yield

dispatch_for_generation

dispatch_for_generation(
    model: PreTrainedModel,
) -> PreTrainedModel

Dispatch a model for autoregressive generation. Modules are dispatched evenly across available devices and kept onloaded if possible. Removes any HF hooks that may have existed previously.

Parameters

  • model

    (PreTrainedModel) –

    the model to dispatch.

Returns

  • PreTrainedModel

    the dispatched model.

Source code in llmcompressor/utils/dev.py
def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
    """
    Dispatch a model for autoregressive generation. This means that modules are dispatched
    evenly across available devices and kept onloaded if possible. Removes any HF hooks
    that may have existed previously.

    :param model: model to dispatch
    :return: model which is dispatched
    """
    remove_dispatch(model)

    no_split_module_classes = model._get_no_split_modules("auto")
    max_memory = get_balanced_memory(
        model,
        dtype=model.dtype,
        no_split_module_classes=no_split_module_classes,
    )
    device_map = infer_auto_device_map(
        model,
        dtype=model.dtype,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
    )

    return dispatch_model(model, device_map=device_map)
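
A usage sketch; inputs is assumed to be a tokenized batch prepared separately:

model = dispatch_for_generation(model)  # modules spread evenly across available devices
output = model.generate(**inputs, max_new_tokens=16)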

eval_context

eval_context(module: Module)

Disable PyTorch training mode for the given module.

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def eval_context(module: torch.nn.Module):
    """
    Disable pytorch training mode for the given module
    """
    restore_value = module.training
    try:
        module.train(False)  # equivalent to eval()
        yield

    finally:
        module.train(restore_value)

flatten_iterable

flatten_iterable(li: Iterable)

Parameters

  • li

    (Iterable) –

    a possibly nested iterable of items to be flattened.

Returns

  • a flattened version of the iterable where all elements are in a single list, flattened in a depth-first pattern.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def flatten_iterable(li: Iterable):
    """
    :param li: a possibly nested iterable of items to be flattened
    :return: a flattened version of the list where all elements are in a single list
             flattened in a depth first pattern
    """

    def _flatten_gen(_li):
        for el in _li:
            if isinstance(el, Iterable) and not isinstance(el, (str, bytes)):
                yield from _flatten_gen(el)
            else:
                yield el

    return list(_flatten_gen(li))
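
Illustrative calls; note that strings and bytes are treated as atoms rather than expanded:

flatten_iterable([1, [2, [3, 4]], 5])  # [1, 2, 3, 4, 5]
flatten_iterable([["a", "b"], "cd"])   # ["a", "b", "cd"]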

get_embeddings

get_embeddings(
    model: PreTrainedModel,
) -> tuple[torch.nn.Module | None, torch.nn.Module | None]

Returns the input and output embeddings of a model. If get_input_embeddings/ get_output_embeddings is not implemented on the model, then None will be returned instead.

Parameters

  • model

    (PreTrainedModel) –

    the model to get embeddings from.

Returns

  • tuple[Module | None, Module | None]

    a tuple containing the embedding modules or None.

Source code in llmcompressor/utils/transformers.py
def get_embeddings(
    model: PreTrainedModel,
) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
    """
    Returns input and output embeddings of a model. If `get_input_embeddings`/
    `get_output_embeddings` is not implemented on the model, then None will be returned
    instead.

    :param model: model to get embeddings from
    :return: tuple containing the embedding modules or None
    """
    try:
        input_embed = model.get_input_embeddings()

    except (AttributeError, NotImplementedError):
        input_embed = None

    try:
        output_embed = model.get_output_embeddings()
    except (AttributeError, NotImplementedError):
        output_embed = None

    return input_embed, output_embed

getattr_chain

getattr_chain(
    obj: Any, chain_str: str, *args, **kwargs
) -> Any

Chain multiple getattr calls, separated by `.`.

Parameters

  • obj

    (Any) –

    the base object whose attributes are being retrieved.

  • chain_str

    (str) –

    attribute names separated by `.`.

  • default

    the default value to return if an attribute is missing; if not provided, an error is raised.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
    """
    Chain multiple getattr calls, separated by `.`

    :param obj: base object whose attributes are being retrieved
    :param chain_str: attribute names separated by `.`
    :param default: default value, throw error otherwise

    """
    if len(args) >= 1:
        has_default = True
        default = args[0]
    elif "default" in kwargs:
        has_default = True
        default = kwargs["default"]
    else:
        has_default = False

    attr_names = chain_str.split(".")

    res = obj
    for attr_name in attr_names:
        if not hasattr(res, attr_name):
            if has_default:
                return default
            else:
                raise AttributeError(f"{res} object has no attribute {attr_name}")
        res = getattr(res, attr_name)

    return res
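
Illustrative calls using a simple namespace object:

from types import SimpleNamespace

obj = SimpleNamespace(a=SimpleNamespace(b=1))
getattr_chain(obj, "a.b")           # 1
getattr_chain(obj, "a.missing", 0)  # 0 (default returned instead of raising)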

import_from_path

import_from_path(path: str) -> str

Import the module and the name of the function/class separated by `:`. Examples: path = "/path/to/file.py:func_or_class_name", path = "/path/to/file:focn", path = "path.to.file:focn".

Parameters

  • path

    (str) –

    the path including the file path and object name.

Source code in llmcompressor/utils/helpers.py
def import_from_path(path: str) -> str:
    """
    Import the module and the name of the function/class separated by :
    Examples:
      path = "/path/to/file.py:func_or_class_name"
      path = "/path/to/file:focn"
      path = "path.to.file:focn"
    :param path: path including the file path and object name
    :return Function or class object
    """
    original_path, class_name = path.split(":")
    _path = original_path

    path = original_path.split(".py")[0]
    path = re.sub(r"/+", ".", path)
    try:
        module = importlib.import_module(path)
    except ImportError:
        raise ImportError(f"Cannot find module with path {_path}")

    try:
        return getattr(module, class_name)
    except AttributeError:
        raise AttributeError(f"Cannot find {class_name} in {_path}")
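
A usage sketch importing a function from this module by its dotted path:

get_chain = import_from_path("llmcompressor.utils.helpers:getattr_chain")
# equivalent to: from llmcompressor.utils.helpers import getattr_chain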

interpolate

interpolate(
    x_cur: float,
    x0: float,
    x1: float,
    y0: Any,
    y1: Any,
    inter_func: str = "linear",
) -> Any

Note, caps values at their min of x0 and max of x1; designed not to work outside of that range for implementation reasons.

Parameters

  • x_cur

    (float) –

    the current value for x; should be between x0 and x1.

  • x0

    (float) –

    the minimum for x to interpolate between.

  • x1

    (float) –

    the maximum for x to interpolate between.

  • y0

    (Any) –

    the minimum for y to interpolate between.

  • y1

    (Any) –

    the maximum for y to interpolate between.

  • inter_func

    (str, default: 'linear' ) –

    the type of function to interpolate with: linear, cubic, inverse_cubic.

Returns

  • Any

    the interpolated value projecting x into y for the given interpolation function.

Source code in llmcompressor/utils/helpers.py
@deprecated(future_name="torch.lerp")
def interpolate(
    x_cur: float, x0: float, x1: float, y0: Any, y1: Any, inter_func: str = "linear"
) -> Any:
    """
    note, caps values at their min of x0 and max x1,
    designed to not work outside of that range for implementation reasons

    :param x_cur: the current value for x, should be between x0 and x1
    :param x0: the minimum for x to interpolate between
    :param x1: the maximum for x to interpolate between
    :param y0: the minimum for y to interpolate between
    :param y1: the maximum for y to interpolate between
    :param inter_func: the type of function to interpolate with:
        linear, cubic, inverse_cubic
    :return: the interpolated value projecting x into y for the given
        interpolation function
    """
    if inter_func not in INTERPOLATION_FUNCS:
        raise ValueError(
            "unsupported inter_func given of {} must be one of {}".format(
                inter_func, INTERPOLATION_FUNCS
            )
        )

    # convert our x to 0-1 range since equations are designed to fit in
    # (0,0)-(1,1) space
    x_per = (x_cur - x0) / (x1 - x0)

    # map x to y using the desired function in (0,0)-(1,1) space
    if inter_func == "linear":
        y_per = x_per
    elif inter_func == "cubic":
        # https://www.wolframalpha.com/input/?i=1-(1-x)%5E3+from+0+to+1
        y_per = 1 - (1 - x_per) ** 3
    elif inter_func == "inverse_cubic":
        # https://www.wolframalpha.com/input/?i=1-(1-x)%5E(1%2F3)+from+0+to+1
        y_per = 1 - (1 - x_per) ** (1 / 3)
    else:
        raise ValueError(
            "unsupported inter_func given of {} in interpolate".format(inter_func)
        )

    if y_per <= 0.0 + sys.float_info.epsilon:
        return y0

    if y_per >= 1.0 - sys.float_info.epsilon:
        return y1

    # scale the threshold based on what we want the current to be
    return y_per * (y1 - y0) + y0
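
Worked examples following the source above:

interpolate(5.0, 0.0, 10.0, 0.0, 100.0)           # 50.0 (linear)
interpolate(5.0, 0.0, 10.0, 0.0, 100.0, "cubic")  # 87.5, from 1 - (1 - 0.5) ** 3
interpolate(12.0, 0.0, 10.0, 0.0, 100.0)          # 100.0 (capped at y1)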

interpolate_list_linear

interpolate_list_linear(
    measurements: List[Tuple[float, float]],
    x_val: Union[float, List[float]],
) -> List[Tuple[float, float]]

Interpolate linearly for input values within a list of measurements.

Parameters

  • measurements

    (List[Tuple[float, float]]) –

    the measurements to interpolate between.

  • x_val

    (Union[float, List[float]]) –

    the target values to interpolate to in the second dimension.

Returns

  • List[Tuple[float, float]]

    a list of tuples containing the target values and interpolated values.

Source code in llmcompressor/utils/helpers.py
@deprecated(future_name="torch.lerp")
def interpolate_list_linear(
    measurements: List[Tuple[float, float]], x_val: Union[float, List[float]]
) -> List[Tuple[float, float]]:
    """
    interpolate for input values within a list of measurements linearly

    :param measurements: the measurements to interpolate the output value between
    :param x_val: the target values to interpolate to the second dimension
    :return: a list of tuples containing the target values, interpolated values
    """
    assert len(measurements) > 1
    measurements.sort(key=lambda v: v[0])

    x_vals = [x_val] if isinstance(x_val, float) else x_val
    x_vals.sort()

    interpolated = []
    lower_index = 0
    higher_index = 1

    for x_val in x_vals:
        while (
            x_val > measurements[higher_index][0]
            and higher_index < len(measurements) - 1
        ):
            lower_index += 1
            higher_index += 1

        x0, y0 = measurements[lower_index]
        x1, y1 = measurements[higher_index]
        y_val = y0 + (x_val - x0) * ((y1 - y0) / (x1 - x0))
        interpolated.append((x_val, y_val))

    return interpolated

interpolated_integral

interpolated_integral(
    measurements: List[Tuple[float, float]],
)

Calculate the interpolated integral for a group of measurements of the form [(x0, y0), (x1, y1), ...].

Parameters

  • measurements

    (List[Tuple[float, float]]) –

    the measurements to calculate the integral for.

Returns

  • the integral, or area under the curve, for the given measurements.

Source code in llmcompressor/utils/helpers.py
@deprecated(future_name="torch.lerp")
def interpolated_integral(measurements: List[Tuple[float, float]]):
    """
    Calculate the interpolated integral for a group of measurements of the form
    [(x0, y0), (x1, y1), ...]

    :param measurements: the measurements to calculate the integral for
    :return: the integral or area under the curve for the measurements given
    """
    if len(measurements) < 1:
        return 0.0

    if len(measurements) == 1:
        return measurements[0][1]

    measurements.sort(key=lambda v: v[0])
    integral = 0.0

    for index, (x_val, y_val) in enumerate(measurements):
        if index >= len(measurements) - 1:
            continue

        x_next, y_next = measurements[index + 1]
        x_dist = x_next - x_val
        area = y_val * x_dist + (y_next - y_val) * x_dist / 2.0
        integral += area

    return integral
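
Worked examples; each segment contributes its trapezoidal area:

interpolated_integral([(0.0, 0.0), (1.0, 1.0)])              # 0.5 (triangle)
interpolated_integral([(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)])  # 2.0 (unit-height rectangle)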

is_package_available

is_package_available(
    package_name: str, return_version: bool = False
) -> Union[Tuple[bool, str], bool]

A helper function to check whether a package is available, and optionally return its version. This function enforces a check that the package is actually available, not just a directory/file with the same name as the package.

Inspired by: https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/utils/import_utils.py#L41

Parameters

  • package_name

    (str) –

    the package name to check for.

  • return_version

    (bool, default: False ) –

    True to also return the version of the package if available.

Returns

  • Union[Tuple[bool, str], bool]

    True if the package is available, False otherwise; or a tuple of (bool, version) if return_version is True.

Source code in llmcompressor/utils/helpers.py
def is_package_available(
    package_name: str,
    return_version: bool = False,
) -> Union[Tuple[bool, str], bool]:
    """
    A helper function to check if a package is available
    and optionally return its version. This function enforces
    a check that the package is available and is not
    just a directory/file with the same name as the package.

    inspired from:
    https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/utils/import_utils.py#L41

    :param package_name: The package name to check for
    :param return_version: True to return the version of
        the package if available
    :return: True if the package is available, False otherwise or a tuple of
        (bool, version) if return_version is True
    """

    package_exists = importlib.util.find_spec(package_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            package_version = importlib.metadata.version(package_name)
            package_exists = True
        except importlib.metadata.PackageNotFoundError:
            package_exists = False
        logger.debug(f"Detected {package_name} version {package_version}")
    if return_version:
        return package_exists, package_version
    else:
        return package_exists
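
Illustrative usage, guarding an optional dependency:

if is_package_available("transformers"):
    import transformers

available, version = is_package_available("transformers", return_version=True)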

is_url

is_url(val: str)

Parameters

  • val

    (str) –

    the value to check if it is a url or not.

Returns

  • True if the value is a URL, False otherwise.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def is_url(val: str):
    """
    :param val: value to check if it is a url or not
    :return: True if value is a URL, False otherwise
    """

    try:
        result = urlparse(val)

        return all([result.scheme, result.netloc])
    except ValueError:
        return False

json_to_jsonl

json_to_jsonl(json_file_path: str, overwrite: bool = True)

Converts a json list file to the jsonl file format (used for sharding efficiency). For example, [{"a": 1}, {"a": 1}] converts to:

{"a": 1}
{"a": 1}

Parameters

  • json_file_path

    (str) –

    the file path to a json file containing a json list of objects.

  • overwrite

    (bool, default: True ) –

    if True, the existing json file is overwritten; if False, the file will have the same name but with a .jsonl extension.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def json_to_jsonl(json_file_path: str, overwrite: bool = True):
    """
    Converts a json list file to jsonl file format (used for sharding efficiency)
        e.x.
            [{"a": 1}, {"a": 1}]
        would convert to:
            {"a": 1}
            {"a": 1}
    :param json_file_path: file path to a json file path containing a json list
        of objects
    :param overwrite: If True, the existing json file will be overwritten, if False,
        the file will have the same name but with a .jsonl extension
    """
    if not json_file_path.endswith(".json"):
        raise ValueError("json file must have .json extension")
    with open(json_file_path) as json_file:
        json_data = json.load(json_file)

    if not isinstance(json_data, List):
        raise ValueError(
            "Json data must be a list to conver to jsonl format. "
            f"found {type(json_data)}"
        )

    jsonl_file_path = json_file_path + ("" if overwrite else "l")
    with open(jsonl_file_path, "w") as jsonl_file:
        for json_line in json_data:
            json.dump(json_line, jsonl_file)  # append json line
            jsonl_file.write("\n")  # newline

load_labeled_data

load_labeled_data(
    data: Union[
        str,
        Iterable[Union[str, ndarray, Dict[str, ndarray]]],
    ],
    labels: Union[
        None,
        str,
        Iterable[Union[str, ndarray, Dict[str, ndarray]]],
    ],
    raise_on_error: bool = True,
) -> List[
    Tuple[
        Union[numpy.ndarray, Dict[str, numpy.ndarray]],
        Union[
            None, numpy.ndarray, Dict[str, numpy.ndarray]
        ],
    ]
]

Load labels and data from disk or from memory and group them together. Assumes sorted ordering on disk. Will match when a file glob is passed for data and/or labels.

Parameters

  • data

    (Union[str, Iterable[Union[str, ndarray, Dict[str, ndarray]]]]) –

    the file glob, file path to a numpy data tar ball, or list of arrays to use for data.

  • labels

    (Union[None, str, Iterable[Union[str, ndarray, Dict[str, ndarray]]]]) –

    the file glob, file path to a numpy data tar ball, or list of arrays to use for labels, if any.

  • raise_on_error

    (bool, default: True ) –

    True to raise on any error that occurs; False to log a warning, ignore, and continue.

Returns

  • List[Tuple[Union[ndarray, Dict[str, ndarray]], Union[None, ndarray, Dict[str, ndarray]]]]

    a list containing tuples of (data, labels). If labels was passed in as None, the second index in each tuple will be None.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def load_labeled_data(
    data: Union[str, Iterable[Union[str, numpy.ndarray, Dict[str, numpy.ndarray]]]],
    labels: Union[
        None, str, Iterable[Union[str, numpy.ndarray, Dict[str, numpy.ndarray]]]
    ],
    raise_on_error: bool = True,
) -> List[
    Tuple[
        Union[numpy.ndarray, Dict[str, numpy.ndarray]],
        Union[None, numpy.ndarray, Dict[str, numpy.ndarray]],
    ]
]:
    """
    Load labels and data from disk or from memory and group them together.
    Assumes sorted ordering for data on disk. Will match when a file glob is passed
    for either data and/or labels.

    :param data: the file glob, file path to numpy data tar ball, or list of arrays to
        use for data
    :param labels: the file glob, file path to numpy data tar ball, or list of arrays
        to use for labels, if any
    :param raise_on_error: True to raise on any error that occurs;
        False to log a warning, ignore, and continue
    :return: a list containing tuples of the data, labels. If labels was passed in
        as None, will now contain a None for the second index in each tuple
    """
    if isinstance(data, str):
        data = load_numpy_list(data)

    if labels is None:
        labels = [None for _ in range(len(data))]
    elif isinstance(labels, str):
        labels = load_numpy_list(labels)

    if len(data) != len(labels) and labels:
        # always raise this error, lengths must match
        raise ValueError(
            "len(data) given of {} does not match len(labels) given of {}".format(
                len(data), len(labels)
            )
        )

    labeled_data = []

    for dat, lab in zip(data, labels):
        try:
            if isinstance(dat, str):
                dat = load_numpy(dat)

            if lab is not None and isinstance(lab, str):
                lab = load_numpy(lab)

            labeled_data.append((dat, lab))
        except Exception as err:
            if raise_on_error:
                raise err
            else:
                logger.error("Error creating labeled data: {}".format(err))

    return labeled_data

load_numpy

load_numpy(
    file_path: str,
) -> Union[numpy.ndarray, Dict[str, numpy.ndarray]]

Load a numpy file into either an ndarray or an OrderedDict representing what was in the npz file.

Parameters

  • file_path

    (str) –

    the file path to load.

Returns

  • Union[ndarray, Dict[str, ndarray]]

    the loaded values from the file.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def load_numpy(file_path: str) -> Union[numpy.ndarray, Dict[str, numpy.ndarray]]:
    """
    Load a numpy file into either an ndarray or an OrderedDict representing what
    was in the npz file

    :param file_path: the file_path to load
    :return: the loaded values from the file
    """
    file_path = clean_path(file_path)
    array = numpy.load(file_path)

    if not isinstance(array, numpy.ndarray):
        tmp_array = array
        array = OrderedDict()
        for key, val in tmp_array.items():
            array[key] = val

    return array

patch_attr

patch_attr(base: object, attr: str, value: Any)

Patch the value of an object attribute. The original value is restored upon exit.

Parameters

  • base

    (object) –

    the object which has the attribute to patch.

  • attr

    (str) –

    the name of the attribute to patch.

  • value

    (Any) –

    the value used to replace the original value.

Usage:

>>> from types import SimpleNamespace
>>> obj = SimpleNamespace()
>>> with patch_attr(obj, "attribute", "value"):
...     assert obj.attribute == "value"
>>> assert not hasattr(obj, "attribute")

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def patch_attr(base: object, attr: str, value: Any):
    """
    Patch the value of an object attribute. Original value is restored upon exit

    :param base: object which has the attribute to patch
    :param attr: name of the attribute to patch
    :param value: used to replace original value

    Usage:
    >>> from types import SimpleNamespace
    >>> obj = SimpleNamespace()
    >>> with patch_attr(obj, "attribute", "value"):
    ...     assert obj.attribute == "value"
    >>> assert not hasattr(obj, "attribute")
    """
    _sentinel = object()
    original_value = getattr(base, attr, _sentinel)

    setattr(base, attr, value)
    try:
        yield
    finally:
        if original_value is not _sentinel:
            setattr(base, attr, original_value)
        else:
            delattr(base, attr)

patch_transformers_logger_level

patch_transformers_logger_level(level: int = logging.ERROR)

Context under which the transformers logger's level is modified.

This can be used with skip_weights_download to squelch warnings related to missing parameters in the checkpoint.

Parameters

  • level

    (int, default: ERROR ) –

    the new logging level for the transformers logger. Logs whose level is below this level will not be logged.

Source code in llmcompressor/utils/dev.py
@contextlib.contextmanager
def patch_transformers_logger_level(level: int = logging.ERROR):
    """
    Context under which the transformers logger's level is modified

    This can be used with `skip_weights_download` to squelch warnings related to
    missing parameters in the checkpoint

    :param level: new logging level for transformers logger. Logs whose level is below
        this level will not be logged
    """
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    restore_log_level = transformers_logger.getEffectiveLevel()

    transformers_logger.setLevel(level=level)
    yield
    transformers_logger.setLevel(level=restore_log_level)

path_file_count

path_file_count(path: str, pattern: str = '*') -> int

Return the number of files that match the given pattern under the given path.

Parameters

  • path

    (str) –

    the path to the directory to look for files under.

  • pattern

    (str, default: '*' ) –

    the pattern the files must match to be counted.

Returns

  • int

    the number of files matching the pattern under the directory.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def path_file_count(path: str, pattern: str = "*") -> int:
    """
    Return the number of files that match the given pattern under the given path

    :param path: the path to the directory to look for files under
    :param pattern: the pattern the files must match to be counted
    :return: the number of files matching the pattern under the directory
    """
    path = clean_path(path)

    return len(fnmatch.filter(os.listdir(path), pattern))

path_file_size

path_file_size(path: str) -> int

Return the total size, in bytes, for a path on the file system.

Parameters

  • path

    (str) –

    the path (directory or file) to get the size for.

Returns

  • int

    the size of the path, in bytes, as stored on disk.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def path_file_size(path: str) -> int:
    """
    Return the total size, in bytes, for a path on the file system

    :param path: the path (directory or file) to get the size for
    :return: the size of the path, in bytes, as stored on disk
    """

    if not os.path.isdir(path):
        stat = os.stat(path)

        return stat.st_size

    total_size = 0
    seen = {}

    for dir_path, dir_names, filenames in os.walk(path):
        for file in filenames:
            file_path = os.path.join(dir_path, file)

            try:
                stat = os.stat(file_path)
            except OSError:
                continue

            try:
                seen[stat.st_ino]
            except KeyError:
                seen[stat.st_ino] = True
            else:
                continue

            total_size += stat.st_size

    return total_size

save_numpy

save_numpy(
    array: Union[
        ndarray, Dict[str, ndarray], Iterable[ndarray]
    ],
    export_dir: str,
    name: str,
    npz: bool = True,
)

Save a numpy array or collection of numpy arrays to disk.

Parameters

  • array

    (Union[ndarray, Dict[str, ndarray], Iterable[ndarray]]) –

    the array or collection of arrays to save.

  • export_dir

    (str) –

    the directory to export the numpy file into.

  • name

    (str) –

    the name of the file to export to (without extension).

  • npz

    (bool, default: True ) –

    True to save as an npz compressed file, False for standard npy. Note, npy can only be used for single numpy arrays.

Returns

  • the saved path.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def save_numpy(
    array: Union[numpy.ndarray, Dict[str, numpy.ndarray], Iterable[numpy.ndarray]],
    export_dir: str,
    name: str,
    npz: bool = True,
):
    """
    Save a numpy array or collection of numpy arrays to disk

    :param array: the array or collection of arrays to save
    :param export_dir: the directory to export the numpy file into
    :param name: the name of the file to export to (without extension)
    :param npz: True to save as an npz compressed file, False for standard npy.
        Note, npy can only be used for single numpy arrays
    :return: the saved path
    """
    create_dirs(export_dir)
    export_path = os.path.join(
        export_dir, "{}.{}".format(name, "npz" if npz else "npy")
    )

    if isinstance(array, numpy.ndarray) and npz:
        numpy.savez_compressed(export_path, array)
    elif isinstance(array, numpy.ndarray):
        numpy.save(export_path, array)
    elif isinstance(array, Dict) and npz:
        numpy.savez_compressed(export_path, **array)
    elif isinstance(array, Dict):
        raise ValueError("Dict can only be exported to an npz file")
    elif isinstance(array, Iterable) and npz:
        numpy.savez_compressed(export_path, *[val for val in array])
    elif isinstance(array, Iterable):
        raise ValueError("Iterable can only be exported to an npz file")
    else:
        raise ValueError("Unrecognized type given for array {}".format(array))

    return export_path
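
A brief sketch; the export directory is illustrative:

import numpy

path = save_numpy(numpy.arange(4), export_dir="/tmp/arrays", name="example")
# path == "/tmp/arrays/example.npz"
save_numpy({"a": numpy.zeros(2), "b": numpy.ones(2)}, "/tmp/arrays", "pair")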

skip_weights_download

skip_weights_download(
    model_class: Type[
        PreTrainedModel
    ] = AutoModelForCausalLM,
)

Context manager under which models are initialized without having to download the model weight files. This differs from init_empty_weights in that weights are allocated on the assigned devices with random values, as opposed to being on the meta device.

Parameters

  • model_class

    (Type[PreTrainedModel], default: AutoModelForCausalLM ) –

    the class to patch; defaults to AutoModelForCausalLM.

Source code in llmcompressor/utils/dev.py
@contextlib.contextmanager
def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM):
    """
    Context manager under which models are initialized without having to download
    the model weight files. This differs from `init_empty_weights` in that weights are
    allocated on to assigned devices with random values, as opposed to being on the meta
    device

    :param model_class: class to patch, defaults to `AutoModelForCausalLM`
    """
    original_fn = model_class.from_pretrained
    weights_files = [
        "*.bin",
        "*.safetensors",
        "*.pth",
        SAFE_WEIGHTS_INDEX_NAME,
        WEIGHTS_INDEX_NAME,
        "*.msgpack",
        "*.pt",
    ]

    @classmethod
    def patched(cls, *args, **kwargs):
        nonlocal tmp_dir

        # intercept model stub
        model_stub = args[0] if args else kwargs.pop("pretrained_model_name_or_path")

        # download files into tmp dir
        os.makedirs(tmp_dir, exist_ok=True)
        snapshot_download(
            repo_id=model_stub, local_dir=tmp_dir, ignore_patterns=weights_files
        )

        # make an empty weights file to avoid errors
        weights_file_path = os.path.join(tmp_dir, "model.safetensors")
        save_file({}, weights_file_path, metadata={"format": "pt"})

        # load from tmp dir
        model = original_fn(tmp_dir, **kwargs)

        # replace model_path
        model.name_or_path = model_stub
        model.config._name_or_path = model_stub

        return model

    with tempfile.TemporaryDirectory() as tmp_dir, patch_attr(
        model_class, "from_pretrained", patched
    ), skip_weights_initialize(), patch_transformers_logger_level():
        yield
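
A usage sketch; the model stub is illustrative. Config and tokenizer files are fetched, but weight files are skipped and parameters are left with arbitrary (non-pretrained) values:

with skip_weights_download():
    model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")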

targets_embeddings

targets_embeddings(
    model: PreTrainedModel,
    targets: NamedModules,
    check_input: bool = True,
    check_output: bool = True,
) -> bool

Returns whether the given targets target the word embeddings of the model.

Parameters

  • model

    (PreTrainedModel) –

    the model containing word embeddings.

  • targets

    (NamedModules) –

    the named modules to check.

  • check_input

    (bool, default: True ) –

    whether to check if input embeddings are targeted.

  • check_output

    (bool, default: True ) –

    whether to check if output embeddings are targeted.

Returns

  • bool

    True if the embeddings are targeted, False otherwise.

Source code in llmcompressor/utils/transformers.py
def targets_embeddings(
    model: PreTrainedModel,
    targets: NamedModules,
    check_input: bool = True,
    check_output: bool = True,
) -> bool:
    """
    Returns True if the given targets target the word embeddings of the model

    :param model: containing word embeddings
    :param targets: named modules to check
    :param check_input: whether to check if input embeddings are targeted
    :param check_output: whether to check if output embeddings are targeted
    :return: True if embeddings are targeted, False otherwise
    """
    input_embed, output_embed = get_embeddings(model)
    if (check_input and input_embed is None) or (check_output and output_embed is None):
        logger.warning(
            "Cannot check embeddings. If this model has word embeddings, please "
            "implement `get_input_embeddings` and `get_output_embeddings`"
        )
        return False

    targets = set(module for _, module in targets)
    return (check_input and input_embed in targets) or (
        check_output and output_embed in targets
    )

tensor_export

tensor_export(
    tensor: Union[
        ndarray, Dict[str, ndarray], Iterable[ndarray]
    ],
    export_dir: str,
    name: str,
    npz: bool = True,
) -> str

Parameters

  • tensor

    (Union[ndarray, Dict[str, ndarray], Iterable[ndarray]]) –

    the tensor to export to a saved numpy array file.

  • export_dir

    (str) –

    the directory to export the file in.

  • name

    (str) –

    the name of the file; .npy will be appended to it.

  • npz

    (bool, default: True ) –

    True to export as an npz file, False otherwise.

Returns

  • str

    the path of the numpy file the tensor was exported to.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def tensor_export(
    tensor: Union[numpy.ndarray, Dict[str, numpy.ndarray], Iterable[numpy.ndarray]],
    export_dir: str,
    name: str,
    npz: bool = True,
) -> str:
    """
    :param tensor: tensor to export to a saved numpy array file
    :param export_dir: the directory to export the file in
    :param name: the name of the file, .npy will be appended to it
    :param npz: True to export as an npz file, False otherwise
    :return: the path of the numpy file the tensor was exported to
    """
    create_dirs(export_dir)
    export_path = os.path.join(
        export_dir, "{}.{}".format(name, "npz" if npz else "npy")
    )

    if isinstance(tensor, numpy.ndarray) and npz:
        numpy.savez_compressed(export_path, tensor)
    elif isinstance(tensor, numpy.ndarray):
        numpy.save(export_path, tensor)
    elif isinstance(tensor, Dict) and npz:
        numpy.savez_compressed(export_path, **tensor)
    elif isinstance(tensor, Dict):
        raise ValueError("tensor dictionaries can only be saved as npz")
    elif isinstance(tensor, Iterable) and npz:
        numpy.savez_compressed(export_path, *tensor)
    elif isinstance(tensor, Iterable):
        raise ValueError("tensor iterables can only be saved as npz")
    else:
        raise ValueError("unknown type give for tensor {}".format(tensor))

    return export_path

tensors_export

tensors_export(
    tensors: Union[
        ndarray, Dict[str, ndarray], Iterable[ndarray]
    ],
    export_dir: str,
    name_prefix: str,
    counter: int = 0,
    break_batch: bool = False,
) -> List[str]

Parameters

  • tensors

    (Union[ndarray, Dict[str, ndarray], Iterable[ndarray]]) –

    the tensors to export to saved numpy array files.

  • export_dir

    (str) –

    the directory to export the files in.

  • name_prefix

    (str) –

    the prefix name to save the tensors as; info about the position of the tensor in a list or dict is appended, in addition to the .npy file format.

  • counter

    (int, default: 0 ) –

    the current counter to save the tensor at.

  • break_batch

    (bool, default: False ) –

    treat the tensor as a batch and break it apart into multiple tensors.

Returns

  • List[str]

    the exported paths.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def tensors_export(
    tensors: Union[numpy.ndarray, Dict[str, numpy.ndarray], Iterable[numpy.ndarray]],
    export_dir: str,
    name_prefix: str,
    counter: int = 0,
    break_batch: bool = False,
) -> List[str]:
    """
    :param tensors: the tensors to export to a saved numpy array file
    :param export_dir: the directory to export the files in
    :param name_prefix: the prefix name for the tensors to save as, will append
        info about the position of the tensor in a list or dict in addition
        to the .npy file format
    :param counter: the current counter to save the tensor at
    :param break_batch: treat the tensor as a batch and break apart into
        multiple tensors
    :return: the exported paths
    """
    create_dirs(export_dir)
    exported_paths = []

    if break_batch:
        _tensors_export_batch(tensors, export_dir, name_prefix, counter, exported_paths)
    else:
        _tensors_export_recursive(
            tensors, export_dir, name_prefix, counter, exported_paths
        )

    return exported_paths

untie_word_embeddings

untie_word_embeddings(model: PreTrainedModel)

Untie word embeddings, if possible. This function raises a warning if embeddings cannot be found in the model definition.

The model config is updated to reflect that the embeddings are now untied.

Parameters

  • model

    (PreTrainedModel) –

    the transformers model containing word embeddings.

Source code in llmcompressor/utils/transformers.py
def untie_word_embeddings(model: PreTrainedModel):
    """
    Untie word embeddings, if possible. This function raises a warning if
    embeddings cannot be found in the model definition.

    The model config will be updated to reflect that embeddings are now untied

    :param model: transformers model containing word embeddings
    """
    input_embed, output_embed = get_embeddings(model)
    if input_embed is None or output_embed is None:
        logger.warning(
            "Cannot untie embeddings. If this model has word embeddings, please "
            "implement `get_input_embeddings` and `get_output_embeddings`"
        )
        return

    # clone data to untie
    for module in (input_embed, output_embed):
        if not has_offloaded_params(module):
            data = module.weight.data
        else:
            data = module._hf_hook.weights_map["weight"]

        requires_grad = module.weight.requires_grad
        untied_param = Parameter(data.clone(), requires_grad=requires_grad)
        register_offload_parameter(module, "weight", untied_param)

    # modify model config
    if hasattr(model.config, "tie_word_embeddings"):
        model.config.tie_word_embeddings = False
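
A usage sketch, assuming a model whose embeddings are discoverable via get_input_embeddings/get_output_embeddings:

untie_word_embeddings(model)
# input and output embeddings now hold independent (cloned) weight parameters,
# and model.config.tie_word_embeddings is set to False if that attribute exists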

validate_str_iterable

validate_str_iterable(
    val: Union[str, Iterable[str]], error_desc: str = ""
) -> Union[str, Iterable[str]]

Parameters

  • val

    (Union[str, Iterable[str]]) –

    the value to validate; checks that it is a list (and flattens it), otherwise checks that it's an __ALL__ or __ALL_PRUNABLE__ string, otherwise raises a ValueError.

  • error_desc

    (str, default: '' ) –

    the description to raise an error with in the event that val wasn't valid.

Returns

  • Union[str, Iterable[str]]

    the validated version of the param.

Source code in llmcompressor/utils/helpers.py
@deprecated()
def validate_str_iterable(
    val: Union[str, Iterable[str]], error_desc: str = ""
) -> Union[str, Iterable[str]]:
    """
    :param val: the value to validate, check that it is a list (and flattens it),
        otherwise checks that it's an __ALL__ or __ALL_PRUNABLE__ string,
        otherwise raises a ValueError
    :param error_desc: the description to raise an error with in the event that
        the val wasn't valid
    :return: the validated version of the param
    """
    if isinstance(val, str):
        if val.upper() != ALL_TOKEN and val.upper() != ALL_PRUNABLE_TOKEN:
            raise ValueError(
                "unsupported string ({}) given in {}".format(val, error_desc)
            )

        return val.upper()

    if isinstance(val, Iterable):
        return flatten_iterable(val)

    raise ValueError("unsupported type ({}) given in {}".format(val, error_desc))
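
Illustrative calls; the string case assumes ALL_TOKEN is the __ALL__ string referenced in the docstring:

validate_str_iterable([["conv1", "conv2"], ["fc1"]])  # ["conv1", "conv2", "fc1"]
validate_str_iterable("__all__")                      # "__ALL__" (uppercased)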