跳到内容

llmcompressor.transformers.data.custom

JSON 和 CSV 数据源的自定义数据集实现。

本模块提供了一个 CustomDataset 类,用于加载和处理用于文本生成微调的本地 JSON 和 CSV 文件。支持灵活的数据格式和用户提供的数据集的自定义预处理管道。

  • CustomDataset

    支持加载的自定义本地数据集的子文本生成类

CustomDataset

CustomDataset(
    dataset_args: DatasetArguments,
    split: str,
    processor: Processor,
)

基类:TextGenerationDataset

支持加载 csv 和 json 的自定义本地数据集的子文本生成类

参数

  • dataset_args

    (DatasetArguments) –

    数据集加载的配置设置

  • split

    (str) –

    从数据集中拆分以加载,例如 testtrain[:5%]。也可以设置为 None 来加载所有拆分。

  • processor

    (Processor) –

    要在数据集上使用的处理器或分词器

源代码位于 llmcompressor/transformers/data/base.py
def __init__(
    self,
    dataset_args: DatasetArguments,
    split: str,
    processor: Processor,
):
    self.dataset_args = dataset_args
    self.split = split
    self.processor = processor

    # get tokenizer
    self.tokenizer = getattr(self.processor, "tokenizer", self.processor)

    if self.tokenizer is not None:
        # fill in pad token
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # configure sequence length
        max_seq_length = dataset_args.max_seq_length
        if dataset_args.max_seq_length > self.tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({max_seq_length}) is larger than "
                f"maximum length for model ({self.tokenizer.model_max_length}). "
                f"Using max_seq_length={self.tokenizer.model_max_length}."
            )
        self.max_seq_length = min(
            dataset_args.max_seq_length, self.tokenizer.model_max_length
        )

        # configure padding
        self.padding = (
            False
            if self.dataset_args.concatenate_data
            else "max_length"
            if self.dataset_args.pad_to_max_length
            else False
        )

    else:
        self.max_seq_length = None
        self.padding = False