JSON 和 CSV 数据源的自定义数据集实现。
本模块提供了一个 CustomDataset 类,用于加载和处理用于文本生成微调的本地 JSON 和 CSV 文件。支持灵活的数据格式和用户提供的数据集的自定义预处理管道。
类
CustomDataset(
dataset_args: DatasetArguments,
split: str,
processor: Processor,
)
基类:TextGenerationDataset
支持加载 csv 和 json 的自定义本地数据集的子文本生成类
参数
- (
DatasetArguments) – - (
str) – 从数据集中拆分以加载,例如 test 或 train[:5%]。也可以设置为 None 来加载所有拆分。
- (
Processor) –
源代码位于 llmcompressor/transformers/data/base.py
| def __init__(
self,
dataset_args: DatasetArguments,
split: str,
processor: Processor,
):
self.dataset_args = dataset_args
self.split = split
self.processor = processor
# get tokenizer
self.tokenizer = getattr(self.processor, "tokenizer", self.processor)
if self.tokenizer is not None:
# fill in pad token
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
# configure sequence length
max_seq_length = dataset_args.max_seq_length
if dataset_args.max_seq_length > self.tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({max_seq_length}) is larger than "
f"maximum length for model ({self.tokenizer.model_max_length}). "
f"Using max_seq_length={self.tokenizer.model_max_length}."
)
self.max_seq_length = min(
dataset_args.max_seq_length, self.tokenizer.model_max_length
)
# configure padding
self.padding = (
False
if self.dataset_args.concatenate_data
else "max_length"
if self.dataset_args.pad_to_max_length
else False
)
else:
self.max_seq_length = None
self.padding = False
|