
speculators.data_generation.preprocessing

Functions

build_eagle3_dataset

build_eagle3_dataset(
    dataset: HFDataset,
    tokenizer: PreTrainedTokenizer,
    max_length: int = 2048,
    num_proc: int = 8,
    assistant_pattern: str | Pattern[str] | None = None,
    turn_dropout: bool = False,
) -> HFDataset

Build an EAGLE3 dataset by tokenizing conversations and creating loss masks.

Uses the tokenizer's built-in chat template via apply_chat_template.

Parameters:

dataset: Raw dataset with conversations
tokenizer: Tokenizer with chat template support
max_length: Maximum sequence length
num_proc: Number of processes for parallel processing
assistant_pattern: Optional custom regex pattern for matching assistant responses. If None, the pattern is auto-detected from the chat template.
turn_dropout: If True, randomly keeps the first N consecutive turns of each conversation

Source code in speculators/data_generation/preprocessing.py
def build_eagle3_dataset(
    dataset: HFDataset,
    tokenizer: PreTrainedTokenizer,
    max_length: int = 2048,
    num_proc: int = 8,
    assistant_pattern: str | Pattern[str] | None = None,
    turn_dropout: bool = False,
) -> HFDataset:
    """Build EAGLE3 dataset by tokenizing conversations and creating loss masks.

    Uses the tokenizer's built-in chat template via apply_chat_template.

    Args:
        dataset: Raw dataset with conversations
        tokenizer: Tokenizer with chat template support
        max_length: Maximum sequence length
        num_proc: Number of processes for parallel processing
        assistant_pattern: Optional custom regex pattern for matching assistant
                          responses. If None, pattern will be auto-detected from
                          chat template.
        turn_dropout: If True, randomly keeps first N consecutive turns per
                     conversation
    """
    # Detect and use provided assistant message pattern
    if assistant_pattern is not None:
        log.info(f"Using custom assistant pattern: {str(assistant_pattern)[:80]}...")
    elif _supports_assistant_mask(tokenizer):
        assistant_pattern = None  # Signal to use HF mask in _preprocess_batch
        log.info("Using HF assistant token mask for loss masking")
    else:
        assistant_pattern = _detect_assistant_pattern(tokenizer)
        log.info(f"Detected assistant pattern: {str(assistant_pattern)[:80]}...")

    original_cols = dataset.column_names

    dataset = dataset.map(
        lambda examples: _preprocess_batch(
            examples, tokenizer, max_length, assistant_pattern, turn_dropout
        ),
        batched=True,
        num_proc=num_proc,
        batch_size=1000,
        remove_columns=original_cols,
        load_from_cache_file=True,
    )

    dataset.set_format(type="torch")
    return dataset
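
A minimal usage sketch of build_eagle3_dataset. The model ID ("meta-llama/Llama-3.1-8B-Instruct") and local JSONL file ("conversations.jsonl") are assumptions for illustration, and the commented assistant_pattern regex is hypothetical; by default the pattern is auto-detected from the chat template.

from datasets import load_dataset
from transformers import AutoTokenizer

from speculators.data_generation.preprocessing import build_eagle3_dataset

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
raw = load_dataset("json", data_files="conversations.jsonl", split="train")

eagle3_ds = build_eagle3_dataset(
    dataset=raw,
    tokenizer=tokenizer,
    max_length=2048,
    num_proc=8,
    # Hypothetical override; omit to auto-detect the pattern from the chat template
    # assistant_pattern=r"<\|start_header_id\|>assistant<\|end_header_id\|>(.*?)<\|eot_id\|>",
)

# The returned dataset is torch-formatted via set_format(type="torch")
print(eagle3_ds)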

load_and_preprocess_dataset

load_and_preprocess_dataset(
    target_model_path: str,
    train_data_path: str,
    seq_length: int,
    build_dataset_num_proc: int = 8,
    seed: int = 0,
    max_samples: int | None = None,
    token_freq_path: str = "./token_freq.pt",
    cache_dir: str | None = None,
    assistant_pattern: str | None = None,
    turn_dropout: bool = False,
) -> tuple[HFDataset, PreTrainedTokenizer]

Load, tokenize, and preprocess a dataset for EAGLE3 training.

Uses the tokenizer's built-in chat template via apply_chat_template. Caching is handled automatically by HuggingFace datasets.

Parameters:

target_model_path: HuggingFace model ID or local path
train_data_path: Dataset name or path to a JSON/JSONL file
seq_length: Maximum sequence length
build_dataset_num_proc: Number of processes for dataset building
seed: Random seed for shuffling
max_samples: Optional limit on the number of samples
token_freq_path: Path to save the token frequency distribution
cache_dir: Directory to cache HuggingFace datasets (optional)
assistant_pattern: Optional custom regex pattern for matching assistant responses. If None, the pattern is auto-detected from the chat template.
turn_dropout: If True, randomly keeps the first N consecutive turns of each conversation

Returns: Tuple of (preprocessed_dataset, tokenizer)

Source code in speculators/data_generation/preprocessing.py
def load_and_preprocess_dataset(
    target_model_path: str,
    train_data_path: str,
    seq_length: int,
    build_dataset_num_proc: int = 8,
    seed: int = 0,
    max_samples: int | None = None,
    token_freq_path: str = "./token_freq.pt",  # noqa: S107
    cache_dir: str | None = None,
    assistant_pattern: str | None = None,
    turn_dropout: bool = False,
) -> tuple[HFDataset, PreTrainedTokenizer]:
    """Load, tokenize, and preprocess a dataset for EAGLE3 training.

    Uses the tokenizer's built-in chat template via apply_chat_template.
    Caching is handled automatically by HuggingFace datasets.

    Args:
        target_model_path: HuggingFace model ID or local path
        train_data_path: Dataset name or path to JSON/JSONL file
        seq_length: Maximum sequence length
        build_dataset_num_proc: Number of processes for dataset building
        seed: Random seed for shuffling
        max_samples: Optional limit on number of samples
        token_freq_path: Path to save token frequency distribution
        cache_dir: Directory to cache HuggingFace datasets (optional)
        assistant_pattern: Optional custom regex pattern for matching assistant
                          responses. If None, pattern will be auto-detected from
                          chat template.
        turn_dropout: If True, randomly keeps first N consecutive turns per
                     conversation

    Returns:
        Tuple of (preprocessed_dataset, tokenizer)
    """
    log.section("Starting dataset preprocessing")

    log.subsection("Loading tokenizer and dataset")
    tokenizer = AutoTokenizer.from_pretrained(target_model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None:
        raise ValueError(
            f"Tokenizer for {target_model_path} does not support chat templates. "
            "Please use a model with a pre-configured chat template."
        )

    raw_dataset = load_raw_dataset(
        train_data_path, num_proc=build_dataset_num_proc, cache_dir=cache_dir
    )
    raw_dataset = raw_dataset.shuffle(seed=seed)

    if max_samples is not None and len(raw_dataset) > max_samples:
        raw_dataset = raw_dataset.select(range(max_samples))

    log.info(f"Loaded {len(raw_dataset)} samples")

    log.subsection("Tokenizing and building dataset")
    if cache_dir:
        log.info(f"Preprocessed data will be cached at: {cache_dir}")
    if turn_dropout:
        log.info("Turn dropout enabled: randomly keeping N consecutive turns")

    preprocessed_dataset = build_eagle3_dataset(
        dataset=raw_dataset,
        tokenizer=tokenizer,
        max_length=seq_length,
        num_proc=build_dataset_num_proc,
        assistant_pattern=assistant_pattern,
        turn_dropout=turn_dropout,
    )

    log.subsection("Computing token frequency distribution")
    save_token_frequency_distribution(
        dataset=preprocessed_dataset,
        output_path=token_freq_path,
    )

    log.subsection("Visualizing sample")
    _visualize_sample(raw_dataset, preprocessed_dataset, tokenizer, idx=0)

    log.section("Dataset preprocessing complete")

    return preprocessed_dataset, tokenizer
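
An end-to-end sketch of load_and_preprocess_dataset; the model ID, data file, sample cap, and cache directory below are assumptions for illustration.

from speculators.data_generation.preprocessing import load_and_preprocess_dataset

preprocessed, tokenizer = load_and_preprocess_dataset(
    target_model_path="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model ID
    train_data_path="conversations.jsonl",                 # local JSON/JSONL path
    seq_length=2048,
    build_dataset_num_proc=8,
    seed=0,
    max_samples=50_000,                                     # optional sample cap
    token_freq_path="./token_freq.pt",
    cache_dir="./hf_cache",                                  # optional HF datasets cache
    turn_dropout=True,
)

print(len(preprocessed), tokenizer.pad_token)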

load_raw_dataset

load_raw_dataset(
    train_data_path: str,
    num_proc: int = 8,
    cache_dir: str | None = None,
) -> HFDataset

Load a raw dataset from a local file or HuggingFace.

Source code in speculators/data_generation/preprocessing.py
def load_raw_dataset(
    train_data_path: str, num_proc: int = 8, cache_dir: str | None = None
) -> HFDataset:
    """Load raw dataset from local file or HuggingFace."""
    if train_data_path.endswith((".jsonl", ".json")):
        return load_dataset(
            "json", data_files=train_data_path, split="train", cache_dir=cache_dir
        )

    if train_data_path not in DATASET_CONFIGS:
        raise ValueError(
            f"Unsupported dataset: {train_data_path}. "
            f"Supported: local .json/.jsonl files or {list(DATASET_CONFIGS.keys())}"
        )

    config = DATASET_CONFIGS[train_data_path]
    raw_dataset = load_dataset(config.hf_path, split=config.split, cache_dir=cache_dir)

    if config.normalize_fn is not None:
        raw_dataset = raw_dataset.map(config.normalize_fn, num_proc=num_proc)

    return raw_dataset
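
A short sketch of the two accepted inputs: a local .json/.jsonl path, or a dataset name registered in DATASET_CONFIGS. The name "sharegpt" is an assumption; unknown names raise a ValueError listing the supported keys.

from speculators.data_generation.preprocessing import load_raw_dataset

# Local file: routed through load_dataset("json", ...)
local_ds = load_raw_dataset("conversations.jsonl", num_proc=8)

# Registered dataset name (hypothetical key); unknown names raise ValueError
try:
    hub_ds = load_raw_dataset("sharegpt", num_proc=8, cache_dir="./hf_cache")
except ValueError as err:
    print(err)  # lists supported DATASET_CONFIGS keys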