llmcompressor.entrypoints.model_free.reindex_fused_weights

Functions

  • reindex_fused_weights

    Script used to reindex the safetensors files of a model so that all fused modules (gate_up, qkv) are in the same safetensors file.

reindex_fused_weights

reindex_fused_weights(
    model_stub: str,
    save_directory: str,
    num_workers: int = 5,
)

Script used to reindex the safetensors files of a model so that all fused modules (gate_up, qkv) are in the same safetensors file. This is required by model_free_ptq for microscale schemes (NVFP4A16, MXFP4A16).
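
To make the carry-over behavior concrete, here is a hypothetical two-shard checkpoint whose layer-5 attention projections straddle a shard boundary; all file and tensor names below are placeholders:

# Before: the qkv set of layer 5 is split across two shards.
weight_map_before = {
    "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.5.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
}

# After reindexing: the incomplete set at the end of shard 1 is carried
# forward into shard 2, where the remaining v_proj weight lives.
weight_map_after = {
    "model.layers.5.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
    "model.layers.5.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
    "model.layers.5.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
}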

This script assumes weight locality; if a set of fused weights is not fully contained in one file, then 1. the incomplete set is the last set of weights in the file (sorted alphabetically), and 2. the remainder of the incomplete set is in the next file (sorted alphabetically).

This assumption holds for most model checkpoints, even in the common case where weights are sorted alphabetically rather than numerically.
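
For example, alphabetical ordering interleaves layer indices, but the members of a single fused set share a common prefix and therefore remain adjacent either way (placeholder tensor names):

names = [
    "model.layers.2.self_attn.q_proj.weight",
    "model.layers.10.self_attn.q_proj.weight",
    "model.layers.2.self_attn.k_proj.weight",
    "model.layers.10.self_attn.k_proj.weight",
    "model.layers.2.self_attn.v_proj.weight",
    "model.layers.10.self_attn.v_proj.weight",
]
print(sorted(names))
# layer 10 sorts before layer 2, yet each layer's k/q/v stay contiguous,
# so a shard boundary can split at most one fused set at a time:
# ['model.layers.10.self_attn.k_proj.weight',
#  'model.layers.10.self_attn.q_proj.weight',
#  'model.layers.10.self_attn.v_proj.weight',
#  'model.layers.2.self_attn.k_proj.weight',
#  'model.layers.2.self_attn.q_proj.weight',
#  'model.layers.2.self_attn.v_proj.weight']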

Parameters

  • model_stub

    (str) –

    huggingface model hub stub or path to local weights files

  • save_directory

    (str) –

    output directory for the reindexed weights files

  • num_workers

    (int, default: 5) –

    number of worker threads used to save files
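
A minimal usage sketch; the model stub and output directory below are placeholders:

from llmcompressor.entrypoints.model_free.reindex_fused_weights import (
    reindex_fused_weights,
)

# Rewrite the shards so every fused (qkv, gate_up) set lives in one file,
# then point model_free_ptq at the reindexed copy.
reindex_fused_weights(
    model_stub="org/some-model",         # placeholder hub stub or local path
    save_directory="./reindexed-model",  # placeholder output directory
    num_workers=5,
)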

Source code in llmcompressor/entrypoints/model_free/reindex_fused_weights.py
def reindex_fused_weights(
    model_stub: str,
    save_directory: str,
    num_workers: int = 5,
):
    """
    Script used to reindex the safetensors files of a model such that all fused modules
    (gate_up, qkv) are in the same safetensors file. This is required by model_free_ptq
    for microscale schemes (NVFP4A16, MXFP4A16)

    This script assumes weight locality; if a set of fused weights is not fully
    contained in one file,
    1. the incomplete set is the last set of weights in the file (sorted
       alphabetically)
    2. the remainder of the incomplete set is in the next file (sorted alphabetically)

    This assumption holds true for most model checkpoints, even in the common case where
    weights are sorted alphabetically and not numerically.

    :param model_stub: huggingface model hub or path to local weights files
    :param save_directory: output directory for reindexed weights files
    :param num_workers: number of worker threads to save files with
    """

    # read files
    model_files = get_checkpoint_files(model_stub)
    index_file = find_safetensors_index_file(model_files)
    if index_file is None:
        raise ValueError(
            "This script is used to modify safetensor file shards, but was "
            "unable to find safetenors index file. No reindexing is required."
        )

    # copy all non-safetensors files (configs, tokenizer files, etc.) verbatim
    for file_path, resolved_path in model_files.items():
        save_path = Path(save_directory) / file_path

        if file_path.endswith("safetensors"):
            continue
        else:
            if is_weights_file(file_path):
                logger.warning(f"Skip processing for weights file {file_path}")
            save_path.parent.mkdir(parents=True, exist_ok=True)
            logger.debug(f"Copying {file_path} {save_path}")
            shutil.copyfile(resolved_path, save_path)

    # read index file
    with open(index_file, "r") as file:
        index_file_data = json.load(file)

    weight_map: dict[str, str] = index_file_data["weight_map"]
    final_weight_map: dict[str, str] = {}

    # set up copy executor and carry over
    writers = ThreadPoolExecutor(max_workers=num_workers)
    carry_over_tensors: dict[str, torch.Tensor] = {}

    # iterate in alphabetical order on assumption of weight-file locality
    file_names = sorted(invert_mapping(weight_map))
    progress = tqdm.tqdm(total=len(file_names))
    for file_name in file_names:
        file_path = model_files[file_name]
        save_path = os.path.join(save_directory, file_name)
        tensors = load_file(file_path)

        if len(carry_over_tensors) > 0:
            # add carryover
            tensors.update(carry_over_tensors)
            logger.info(f"Moved {list(carry_over_tensors.keys())} into {file_name}")
            carry_over_tensors = {}

        tensor_names = sorted(tensors.keys())
        _matches, unmatched_sets = get_fused_names(tensor_names)
        for unmatched in unmatched_sets:
            # move to carry over
            unmatched_tensors = {
                key: tensors[key] for key in unmatched.values() if key is not None
            }
            carry_over_tensors.update(unmatched_tensors)

            # delete from current file
            for key in unmatched_tensors:
                tensor_names.remove(key)
                del tensors[key]

        # save tensors after modification
        writers.submit(_with_progress, save_file, tensors, save_path, progress=progress)
        final_weight_map.update({name: file_name for name in tensor_names})

    total_size = index_file_data["metadata"]["total_size"]
    update_safetensors_index(save_directory, total_size, final_weight_map)
    writers.shutdown(wait=True)
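
invert_mapping is not shown on this page; judging from its usage above, it plausibly groups the tensor-to-file weight_map by file, roughly as in the sketch below (an assumption about the helper's behavior, not the library's actual implementation):

from collections import defaultdict

def invert_mapping(weight_map: dict[str, str]) -> dict[str, list[str]]:
    # assumed behavior: turn {tensor_name: file_name} into
    # {file_name: [tensor_name, ...]}; reindex_fused_weights only sorts
    # the keys to obtain the alphabetical shard order
    inverted: dict[str, list[str]] = defaultdict(list)
    for tensor_name, file_name in weight_map.items():
        inverted[file_name].append(tensor_name)
    return dict(inverted)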