跳到内容

llmcompressor.transformers.data.base

为文本生成数据集处理和提供基础支持。

此模块提供了基础的 TextGenerationDataset 类,并支持不同数据集类型的注册。它负责数据集的加载、标记化、预处理以及针对文本生成微调工作流的特定格式化。

TextGenerationDataset

TextGenerationDataset(
    dataset_args: DatasetArguments,
    split: str,
    processor: Processor,
)

继承自:RegistryMixin

文本数据集的基类。应用以下转换到数据集,以准备数据集供数据加载器加载。

  1. 从 huggingface 或本地缓存加载数据集
  2. 根据预处理函数或聊天/数据集模板预处理数据集
  3. 使用模型分词器/处理器标记化数据集
  4. 应用后处理,例如文本分组和/或为微调添加标签

参数

  • dataset_args

    (DatasetArguments) –

    数据集加载的配置设置

  • split

    (str) –

    从数据集中加载的拆分,例如 testtrain[:5%]

  • processor

    (Processor) –

    要在数据集上使用的处理器或分词器

方法

  • load_dataset

    从 Hugging Face 加载原始数据集,如果可用则使用缓存副本。

  • map

    Dataset.map 和 IterableDataset.map 的包装函数。

属性

  • preprocess (Callable[[LazyRow], Any] | None) –

    该函数必须返回对应于处理器/分词器 kwargs 的键。

源代码在 llmcompressor/transformers/data/base.py
def __init__(
    self,
    dataset_args: DatasetArguments,
    split: str,
    processor: Processor,
):
    self.dataset_args = dataset_args
    self.split = split
    self.processor = processor

    # get tokenizer
    self.tokenizer = getattr(self.processor, "tokenizer", self.processor)

    if self.tokenizer is not None:
        # fill in pad token
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # configure sequence length
        max_seq_length = dataset_args.max_seq_length
        if dataset_args.max_seq_length > self.tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({max_seq_length}) is larger than "
                f"maximum length for model ({self.tokenizer.model_max_length}). "
                f"Using max_seq_length={self.tokenizer.model_max_length}."
            )
        self.max_seq_length = min(
            dataset_args.max_seq_length, self.tokenizer.model_max_length
        )

        # configure padding
        self.padding = (
            False
            if self.dataset_args.concatenate_data
            else "max_length"
            if self.dataset_args.pad_to_max_length
            else False
        )

    else:
        self.max_seq_length = None
        self.padding = False

preprocess 缓存 属性

preprocess: Callable[[LazyRow], Any] | None

该函数必须返回对应于处理器/分词器 kwargs 的键,可选地包括 PROMPT_KEY。

load_dataset

load_dataset()

从 Hugging Face 加载原始数据集,如果可用则使用缓存副本。

参数

  • cache_dir

    用于搜索缓存数据集的磁盘位置。

返回

  • 请求的数据集。

源代码在 llmcompressor/transformers/data/base.py
def load_dataset(self):
    """
    Load the raw dataset from Hugging Face, using cached copy if available

    :param cache_dir: disk location to search for cached dataset
    :return: the requested dataset
    """
    if self.dataset_args.dataset_path is not None:
        if self.dataset_args.dvc_data_repository is not None:
            self.dataset_args.raw_kwargs["storage_options"] = {
                "url": self.dataset_args.dvc_data_repository
            }
            self.dataset_args.raw_kwargs["data_files"] = (
                self.dataset_args.dataset_path
            )
        else:
            self.dataset_args.raw_kwargs["data_files"] = (
                get_custom_datasets_from_path(
                    self.dataset_args.dataset_path,
                    self.dataset_args.dataset
                    if hasattr(self.dataset_args, "dataset")
                    else self.dataset_args.dataset_name,
                )
            )

    logger.debug(f"Loading dataset {self.dataset_args.dataset}")
    return get_raw_dataset(
        self.dataset_args,
        cache_dir=None,
        split=self.split,
        streaming=self.dataset_args.streaming,
        **self.dataset_args.raw_kwargs,
    )

map

map(
    dataset: Dataset | IterableDataset,
    function: Callable[[Any], Any],
    **kwargs,
) -> Dataset | IterableDataset

Dataset.map 和 IterableDataset.map 的包装函数。

如果数据集是流式的(在 IterableDataset 的情况下),则会忽略不适用的参数,并解析数据集的特征。

源代码在 llmcompressor/transformers/data/base.py
def map(
    self,
    dataset: Dataset | IterableDataset,
    function: Callable[[Any], Any],
    **kwargs,
) -> Dataset | IterableDataset:
    """
    Wrapper function around Dataset.map and IterableDataset.map.

    If the dataset is streaming (in the case of IterableDataset), non-applicable
    arguments are ignored and the dataset features are resolved
    """
    if isinstance(dataset, IterableDataset):
        # remove arguments that don't apply to streaming
        kwargs.pop("num_proc", None)
        kwargs.pop("load_from_cache_file", None)
        kwargs.pop("desc", None)
        kwargs.pop("keep_in_memory", None)

    dataset = dataset.map(function, **kwargs)

    if isinstance(dataset, IterableDataset):
        dataset = dataset._resolve_features()

    return dataset