reindex_fused_weights(
model_stub: str,
save_directory: str,
num_workers: int = 5,
)
Script used to reindex a model's safetensors files so that all fused modules (gate_up, qkv) live in the same safetensors file. This is required by model_free_ptq for microscale schemes (NVFP4A16, MXFP4A16).
This script assumes weight locality; if a set of fused weights is not contained in a single file, then (1) the incomplete set is the last set of weights in that file (sorted alphabetically), and (2) the remainder of the incomplete set is at the start of the next file (sorted alphabetically).
This assumption holds for most model checkpoints, even in the common case where weights are sorted alphabetically rather than numerically.
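For example, under this assumption a fused qkv set that straddles a shard boundary looks like the following (file and tensor names are illustrative):

model-00001.safetensors: ..., layers.0.q_proj.weight, layers.0.k_proj.weight
model-00002.safetensors: layers.0.v_proj.weight, layers.1.q_proj.weight, ...

After reindexing, the incomplete (q, k) pair is carried forward so that the complete qkv set lives together in model-00002.safetensors.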
Parameters
- model_stub (str) – huggingface model hub stub or path to local weights files
- save_directory (str) – output directory for the reindexed weights files
- num_workers (int, default: 5) – number of worker threads to save files with
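A minimal usage sketch, assuming the function is imported from the source module shown below; the model stub and output directory are placeholders:

from llmcompressor.entrypoints.model_free.reindex_fused_weights import (
    reindex_fused_weights,  # import path assumed from the source file location
)

reindex_fused_weights(
    model_stub="meta-llama/Llama-3.1-8B-Instruct",  # hub stub or local path (placeholder)
    save_directory="./reindexed-checkpoint",  # receives the reindexed shards
    num_workers=8,  # parallel file-writing threads
)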
Source code in llmcompressor/entrypoints/model_free/reindex_fused_weights.py
# Module-level imports (not shown in the rendered source view). The helper
# functions used below (get_checkpoint_files, find_safetensors_index_file,
# is_weights_file, invert_mapping, get_fused_names, update_safetensors_index,
# _with_progress) are defined alongside this function in the
# llmcompressor.entrypoints.model_free package.
import json
import os
import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import torch
import tqdm
from loguru import logger
from safetensors.torch import load_file, save_file


def reindex_fused_weights(
    model_stub: str,
    save_directory: str,
    num_workers: int = 5,
):
    """
    Script used to reindex the safetensors files of a model such that all fused modules
    (gate_up, qkv) are in the same safetensors file. This is required by model_free_ptq
    for microscale schemes (NVFP4A16, MXFP4A16)

    This script assumes weight locality; if a set of fused weights is not in a file,
    1. the incomplete set is the last set of weights (sorted alphabetically)
    2. the remainder of the incomplete set is the next file (sorted alphabetically)

    This assumption holds true for most model checkpoints, even in the common case where
    weights are sorted alphabetically and not numerically.

    :param model_stub: huggingface model hub or path to local weights files
    :param save_directory: output directory for reindexed weights files
    :param num_workers: number of worker threads to save files with
    """
    # read files
    model_files = get_checkpoint_files(model_stub)
    index_file = find_safetensors_index_file(model_files)
    if index_file is None:
        raise ValueError(
            "This script is used to modify safetensor file shards, but was "
            "unable to find safetensors index file. No reindexing is required."
        )

    # copy non-weight files
    for file_path, resolved_path in model_files.items():
        save_path = Path(save_directory) / file_path
        if file_path.endswith("safetensors"):
            continue
        else:
            if is_weights_file(file_path):
                logger.warning(f"Skip processing for weights file {file_path}")

        save_path.parent.mkdir(parents=True, exist_ok=True)
        logger.debug(f"Copying {file_path} {save_path}")
        shutil.copyfile(resolved_path, save_path)

    # read index file
    with open(index_file, "r") as file:
        index_file_data = json.load(file)
    weight_map: dict[str, str] = index_file_data["weight_map"]
    final_weight_map: dict[str, str] = {}

    # set up copy executor and carry over
    writers = ThreadPoolExecutor(max_workers=num_workers)
    carry_over_tensors: dict[str, torch.Tensor] = {}

    # iterate in alphabetical order on assumption of weight-file locality
    file_map = invert_mapping(weight_map)
    file_map = sorted(file_map)
    progress = tqdm.tqdm(total=len(file_map))
    for file_name in file_map:
        file_path = model_files[file_name]
        save_path = os.path.join(save_directory, file_name)
        tensors = load_file(file_path)

        if len(carry_over_tensors) > 0:
            # add carryover
            tensors.update(carry_over_tensors)
            logger.info(f"Moved {list(carry_over_tensors.keys())} into {file_name}")
            carry_over_tensors = {}

        tensor_names = sorted(list(tensors.keys()))
        _matches, unmatched_sets = get_fused_names(tensor_names)
        for unmatched in unmatched_sets:
            # move to carry over
            unmatched_tensors = {
                key: tensors[key] for key in unmatched.values() if key is not None
            }
            carry_over_tensors.update(unmatched_tensors)

            # delete from current file
            for key in unmatched_tensors:
                tensor_names.remove(key)
                del tensors[key]

        # save tensors after modification
        writers.submit(_with_progress, save_file, tensors, save_path, progress=progress)
        final_weight_map.update({name: file_name for name in tensor_names})

    total_size = index_file_data["metadata"]["total_size"]
    update_safetensors_index(save_directory, total_size, final_weight_map)
    writers.shutdown(wait=True)
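To make the carry-over pass concrete, below is a self-contained toy sketch of the same idea applied to a hand-written weight map. The prefix-based grouping here stands in for get_fused_names, whose exact matching rules are not reproduced; tensor and shard names are illustrative.

from collections import defaultdict

# Hypothetical weight map in which the layers.0 qkv set straddles two shards.
weight_map = {
    "layers.0.q_proj.weight": "shard-0.safetensors",
    "layers.0.k_proj.weight": "shard-0.safetensors",
    "layers.0.v_proj.weight": "shard-1.safetensors",
    "layers.1.q_proj.weight": "shard-1.safetensors",
    "layers.1.k_proj.weight": "shard-1.safetensors",
    "layers.1.v_proj.weight": "shard-1.safetensors",
}

# Group tensors by their fused prefix (everything before q/k/v_proj).
groups: dict[str, dict[str, str]] = defaultdict(dict)
for name, shard in weight_map.items():
    prefix, proj, _ = name.rsplit(".", 2)
    groups[prefix][proj] = shard

# Any group spanning multiple shards is moved whole into the last shard
# (alphabetically), mirroring how the script carries incomplete sets forward.
reindexed = dict(weight_map)
for prefix, members in groups.items():
    shards = sorted(set(members.values()))
    if len(shards) > 1:
        target = shards[-1]
        for proj in members:
            reindexed[f"{prefix}.{proj}.weight"] = target

print(reindexed["layers.0.q_proj.weight"])  # -> shard-1.safetensors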