Merge multimodal_embeddings into inputs_embeds by overwriting the positions in inputs_embeds that correspond to placeholder tokens in input_ids.

Note

This operation updates inputs_embeds in place.

Source code in vllm_gaudi/models/utils.py
```python
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    import habana_frameworks.torch.core as htcore
    htcore.mark_step()

    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
    input_dtype = inputs_embeds.dtype
    try:
        # For debugging
        # inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
        # htcore.mark_step()

        # NOTE: This can avoid D2H sync (#22105), but fails to
        # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
        # inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
        #                               mm_embeds_flat.to(dtype=input_dtype))

        multimodal_positions = torch.where(
            is_multimodal)[0][:mm_embeds_flat.shape[0]]
        inputs_embeds[0, multimodal_positions] = mm_embeds_flat.to(
            dtype=input_dtype)
    except RuntimeError as e:
        num_actual_tokens = len(mm_embeds_flat)
        num_expected_tokens = is_multimodal.sum().item()
        if num_actual_tokens != num_expected_tokens:
            expr = _embedding_count_expression(multimodal_embeddings)
            raise ValueError(
                f"Attempted to assign {expr} = {num_actual_tokens} "
                f"multimodal tokens to {num_expected_tokens} placeholders"
            ) from e
        raise ValueError("Error during masked scatter operation") from e

    return inputs_embeds
```
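A minimal usage sketch is shown below. The shapes, the placeholder positions, and the single-image embedding are illustrative assumptions, not taken from the library; it also assumes a Gaudi/HPU environment, since the function imports habana_frameworks.torch.core and calls mark_step() internally.

```python
import torch

from vllm_gaudi.models.utils import _merge_multimodal_embeddings

# Hypothetical layout: batch of 1, 8 token positions, hidden size 4.
# Positions 2-4 are assumed to hold image placeholder tokens.
hidden = 4
inputs_embeds = torch.zeros(1, 8, hidden, dtype=torch.bfloat16)
is_multimodal = torch.tensor(
    [False, False, True, True, True, False, False, False])

# One image whose encoder produced 3 embedding vectors.
image_embeds = [torch.randn(3, hidden)]

out = _merge_multimodal_embeddings(inputs_embeds, image_embeds, is_multimodal)

# Rows 2-4 of out[0] now contain the image embeddings (cast to bfloat16);
# the call mutates inputs_embeds in place and returns the same tensor.
assert out is inputs_embeds
```

If the number of flattened multimodal embeddings does not match the number of True entries in is_multimodal and the assignment fails, the function re-raises the error as a ValueError that reports the expected versus actual token counts.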