
llmcompressor.transformers.data.data_helpers

Functions

get_custom_datasets_from_path

get_custom_datasets_from_path(
    path: str, ext: str = "json"
) -> dict[str, str]

Get a dictionary of custom datasets from a directory path. Supports HF's load_dataset for local folder datasets: https://hugging-face.cn/docs/datasets/loading

This function scans the specified directory path for files with a particular extension (default '.json'). It builds a dictionary where the keys are either subdirectory names or direct dataset names (depending on the directory structure), and the values are either a single file path (if only one file exists with that name) or a list of file paths (if multiple files exist).

Parameters

  • path

    (str) –

The path to the directory containing the dataset files.

  • ext

(str, default: 'json') –

    The file extension to filter files by. Defaults to 'json'.

Returns

  • dict[str, str]

    A dictionary mapping dataset names to their file paths or lists of file paths.

    Example: dataset = get_custom_datasets_from_path("/path/to/dataset/directory", "json")

    Note: if datasets are organized in subdirectories, the function builds the dictionary with lists of file paths; if datasets are found directly in the main directory, they are included under their respective names.

    Accepts:

      - path/
          train.json
          test.json
          val.json

      - path/
          train/
              data1.json
              data2.json
              ...
          test/
              ...
          val/
              ...

Source code in llmcompressor/transformers/data/data_helpers.py
def get_custom_datasets_from_path(path: str, ext: str = "json") -> dict[str, str]:
    """
    Get a dictionary of custom datasets from a directory path. Support HF's load_dataset
     for local folder datasets https://hugging-face.cn/docs/datasets/loading

    This function scans the specified directory path for files with a
     specific extension (default is '.json').
    It constructs a dictionary where the keys are either subdirectory names or
     direct dataset names (depending on the directory structure)
    and the values are either file paths (if only one file exists with that name) or
     lists of file paths (if multiple files exist).

    :param path: The path to the directory containing the dataset files.
    :param ext: The file extension to filter files by. Default is 'json'.

    :return: A dictionary mapping dataset names to their file paths or lists of
     file paths.

    Example:
        dataset = get_custom_datasets_from_path("/path/to/dataset/directory", "json")

    Note:
        If datasets are organized in subdirectories, the function constructs the
         dictionary with lists of file paths.
        If datasets are found directly in the main directory, they are included with
         their respective names.

    Accepts:
        - path\
            train.json
            test.json
            val.json

        - path\
            train\
                data1.json
                data2.json
                ...
            test\
                ...
            val\
                ...

    """
    data_files = {}

    if any(filename.endswith(ext) for filename in os.listdir(path)):
        # If there are files with the given extension in the path
        for filename in os.listdir(path):
            if filename.endswith(ext):
                name, _ = os.path.splitext(filename)
                data_files[name] = os.path.join(path, filename)
    else:
        # If datasets are organized in subdirectories
        for root, dirs, files in os.walk(path):
            for dir_name in dirs:
                dir_path = os.path.join(root, dir_name)
                dir_dataset = []
                for filename in os.listdir(dir_path):
                    if filename.endswith(ext):
                        file_path = os.path.join(dir_path, filename)
                        dir_dataset.append(file_path)
                if dir_dataset:
                    data_files[dir_name] = dir_dataset

    return transform_dataset_keys(data_files)
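
A minimal usage sketch (hypothetical directory and file names, matching the flat layout from the Accepts section above):

# Hypothetical layout:
#   /data/my_dataset/train.json
#   /data/my_dataset/val.json
#   /data/my_dataset/test.json
data_files = get_custom_datasets_from_path("/data/my_dataset", "json")
# Expected result for that layout:
# {"train": "/data/my_dataset/train.json",
#  "val": "/data/my_dataset/val.json",
#  "test": "/data/my_dataset/test.json"}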

get_raw_dataset

get_raw_dataset(
    dataset_args,
    cache_dir: str | None = None,
    streaming: bool | None = False,
    **kwargs,
) -> Dataset

Load the raw dataset from Hugging Face, using a cached copy if available.

Parameters

  • cache_dir

    (str | None, default: None) –

    disk location to search for cached dataset

  • streaming

(bool | None, default: False) –

    True to stream data from Hugging Face, otherwise download

Returns

  • Dataset

    the requested dataset

Source code in llmcompressor/transformers/data/data_helpers.py
def get_raw_dataset(
    dataset_args,
    cache_dir: str | None = None,
    streaming: bool | None = False,
    **kwargs,
) -> Dataset:
    """
    Load the raw dataset from Hugging Face, using cached copy if available

    :param cache_dir: disk location to search for cached dataset
    :param streaming: True to stream data from Hugging Face, otherwise download
    :return: the requested dataset

    """
    raw_datasets = load_dataset(
        dataset_args.dataset,
        dataset_args.dataset_config_name,
        cache_dir=cache_dir,
        streaming=streaming,
        **kwargs,
    )
    return raw_datasets
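
A minimal sketch of calling get_raw_dataset. The function only reads dataset_args.dataset and dataset_args.dataset_config_name, so a SimpleNamespace stands in here (a hypothetical arguments object; in practice this is the library's dataset-arguments object):

from types import SimpleNamespace

# Hypothetical stand-in for the dataset arguments; any object with
# .dataset and .dataset_config_name attributes works for this function.
dataset_args = SimpleNamespace(
    dataset="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
)
raw = get_raw_dataset(dataset_args, cache_dir=None, streaming=False)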

transform_dataset_keys

transform_dataset_keys(data_files: dict[str, Any])

Transform the keys of the given input dict to `train`, `val`, or `test` if matches exist with the existing keys. Note that there can only be one matching file name. Ex: Folder(train_foo.json) -> Folder(train.json); Folder(train1.json, train2.json) -> unchanged.

Parameters

  • data_files

    (dict[str, Any]) –

    The dict whose keys will be transformed

Source code in llmcompressor/transformers/data/data_helpers.py
def transform_dataset_keys(data_files: dict[str, Any]):
    """
    Transform dict keys to `train`, `val` or `test` for the given input dict
    if matches exist with the existing keys. Note that there can only be one
    matching file name.
    Ex. Folder(train_foo.json)           -> Folder(train.json)
        Folder(train1.json, train2.json) -> Same

    :param data_files: The dict where keys will be transformed
    """
    keys = set(data_files.keys())

    def transform_dataset_key(candidate: str) -> None:
        for key in keys:
            if candidate in key:
                if key == candidate:
                    return
                val = data_files.pop(key)
                data_files[candidate] = val

    def do_transform(candidate: str) -> bool:
        return sum(candidate in key for key in keys) == 1

    dataset_keys = ("train", "val", "test")
    for dataset_key in dataset_keys:
        if do_transform(dataset_key):
            transform_dataset_key(dataset_key)

    return data_files
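
A short sketch of the key normalization with hypothetical file-name keys:

# Exactly one key contains "train", so it is renamed to the canonical key.
transform_dataset_keys({"train_foo": "a.json", "test": "b.json"})
# -> {"train": "a.json", "test": "b.json"}

# Two keys contain "train": no unique match, so the keys stay unchanged.
transform_dataset_keys({"train1": "a.json", "train2": "b.json"})
# -> {"train1": "a.json", "train2": "b.json"}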