vllm_gaudi.models.utils

_merge_multimodal_embeddings

_merge_multimodal_embeddings(
    inputs_embeds: Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: Tensor,
) -> Tensor

Merge multimodal_embeddings into inputs_embeds by overwriting the positions in inputs_embeds corresponding to placeholder tokens in input_ids.
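
The is_multimodal mask is typically built from input_ids by marking the positions that hold the model's multimodal placeholder token. A minimal sketch of that step (IMAGE_TOKEN_ID is a hypothetical value chosen for illustration, not something this module defines):

import torch

# Hypothetical placeholder id; real models define their own image token id.
IMAGE_TOKEN_ID = 32000

input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 5, 6])
is_multimodal = input_ids == IMAGE_TOKEN_ID  # one boolean per token position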

Note

This updates inputs_embeds in place.

Source code in vllm_gaudi/models/utils.py
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    import habana_frameworks.torch.core as htcore
    htcore.mark_step()

    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
    input_dtype = inputs_embeds.dtype

    try:
        # For debugging
        # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
        # htcore.mark_step()
        # NOTE: This can avoid D2H sync (#22105), but fails to
        # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
        # inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
        #                               mm_embeds_flat.to(dtype=input_dtype))

        multimodal_positions = torch.where(is_multimodal)[0][:mm_embeds_flat.shape[0]]
        inputs_embeds[0, multimodal_positions] = mm_embeds_flat.to(dtype=input_dtype)

    except RuntimeError as e:
        num_actual_tokens = len(mm_embeds_flat)
        num_expected_tokens = is_multimodal.sum().item()

        if num_actual_tokens != num_expected_tokens:
            expr = _embedding_count_expression(multimodal_embeddings)

            raise ValueError(f"Attempted to assign {expr} = {num_actual_tokens} "
                             f"multimodal tokens to {num_expected_tokens} placeholders") from e

        raise ValueError("Error during masked scatter operation") from e

    return inputs_embeds
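
For reference, the overwrite performed in the try block can be reproduced in plain PyTorch without the Habana-specific mark_step() call. This is only an illustrative sketch; the shapes are assumptions implied by the indexing above (inputs_embeds of shape (1, seq_len, hidden) and a flat per-token boolean mask):

import torch

hidden = 4
inputs_embeds = torch.zeros(1, 6, hidden)  # text embeddings for a 6-token prompt
is_multimodal = torch.tensor([False, True, True, True, False, False])

# Two multimodal items contributing 2 + 1 embedding rows, already flattened
# the way _flatten_embeddings would flatten them.
mm_embeds_flat = torch.cat([torch.ones(2, hidden), 2 * torch.ones(1, hidden)])

# Same pattern as above: gather the placeholder positions, then overwrite
# them with the (dtype-cast) multimodal embeddings.
positions = torch.where(is_multimodal)[0][:mm_embeds_flat.shape[0]]
inputs_embeds[0, positions] = mm_embeds_flat.to(dtype=inputs_embeds.dtype)

print(inputs_embeds[0])  # rows 1-2 are ones, row 3 is twos, the rest stay zero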