
vllm_gaudi.models.gemma3_mm

HpuGemma3ForConditionalGeneration

Bases: Gemma3ForConditionalGeneration

Source code in vllm_gaudi/models/gemma3_mm.py
@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
                                        info=Gemma3ProcessingInfo,
                                        dummy_inputs=Gemma3DummyInputsBuilder)
class HpuGemma3ForConditionalGeneration(Gemma3ForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    # HPU optimization: run the vision tower over bucketed sub-batches to reduce recipe recompilation overhead
    def _process_image_input(self, image_input: Gemma3ImageInputs) -> list[torch.Tensor]:
        assert self.vision_tower is not None
        pixel_values = image_input["pixel_values"]
        num_patches = image_input["num_patches"]

        if hasattr(self, 'vision_bucket_manager'):
            batch_breakdown = self.vision_bucket_manager.greedy_plan(pixel_values.shape[0],
                                                                     self.vision_bucket_manager.multimodal_buckets)
            start_idx = 0
            image_embeds_multibatches = []

            for i in batch_breakdown:
                end_idx = start_idx + i
                indices = torch.arange(start_idx, end_idx).to(pixel_values.device)
                batch_sliced_pixel_values = torch.index_select(pixel_values, dim=0, index=indices)

                image_features = self._image_pixels_to_features(
                    self.vision_tower,
                    batch_sliced_pixel_values,
                )
                image_embeds = self.multi_modal_projector(image_features)
                image_embeds_multibatches += [image_embeds.clone()]
                start_idx = end_idx
            image_embeds = torch.cat(image_embeds_multibatches, dim=0)
        else:
            image_features = self._image_pixels_to_features(
                self.vision_tower,
                pixel_values,
            )
            image_embeds = self.multi_modal_projector(image_features)

        return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]
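
The call to greedy_plan above decomposes the incoming image batch size into a sequence of pre-compiled bucket sizes, so the vision tower only ever executes with a small, fixed set of batch shapes and compiled HPU recipes can be reused across requests. The names greedy_plan and multimodal_buckets come from the source above; the implementation below is a minimal sketch of the idea, not the actual vision bucket manager.

def greedy_plan(n: int, buckets: list[int]) -> list[int]:
    # Cover a batch of n images with the largest buckets first, e.g.
    # n=13 with buckets [1, 2, 4, 8] -> [8, 4, 1]. This sketch assumes
    # 1 is among the buckets (or that n decomposes exactly); the real
    # manager may instead pad the tail up to the nearest bucket.
    plan: list[int] = []
    for bucket in sorted(buckets, reverse=True):
        while n >= bucket:
            plan.append(bucket)
            n -= bucket
    return plan

Note that the chunk sizes returned by the plan must sum to pixel_values.shape[0] for the index_select slicing above to cover the whole batch.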

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm_gaudi/models/gemma3_mm.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__(vllm_config=vllm_config, prefix=prefix)

_process_image_input

_process_image_input(
    image_input: Gemma3ImageInputs,
) -> list[Tensor]
Source code in vllm_gaudi/models/gemma3_mm.py
def _process_image_input(self, image_input: Gemma3ImageInputs) -> list[torch.Tensor]:
    assert self.vision_tower is not None
    pixel_values = image_input["pixel_values"]
    num_patches = image_input["num_patches"]

    if hasattr(self, 'vision_bucket_manager'):
        batch_breakdown = self.vision_bucket_manager.greedy_plan(pixel_values.shape[0],
                                                                 self.vision_bucket_manager.multimodal_buckets)
        start_idx = 0
        image_embeds_multibatches = []

        for i in batch_breakdown:
            end_idx = start_idx + i
            indices = torch.arange(start_idx, end_idx).to(pixel_values.device)
            batch_sliced_pixel_values = torch.index_select(pixel_values, dim=0, index=indices)

            image_features = self._image_pixels_to_features(
                self.vision_tower,
                batch_sliced_pixel_values,
            )
            image_embeds = self.multi_modal_projector(image_features)
            image_embeds_multibatches += [image_embeds.clone()]
            start_idx = end_idx
        image_embeds = torch.cat(image_embeds_multibatches, dim=0)
    else:
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )
        image_embeds = self.multi_modal_projector(image_features)

    return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]
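
For intuition on the final return line: image_embeds is stacked along dim 0 with one row per image patch, and split(num_patches.tolist()) regroups those rows per input image before flatten(0, 1) merges the patch and token dimensions into a single embedding sequence per image. A shape-only illustration with made-up dimensions:

import torch

# Hypothetical: 3 images contributing 2, 1 and 3 patches each,
# 4 tokens per patch, hidden size 8.
image_embeds = torch.randn(6, 4, 8)  # (total_patches, tokens_per_patch, hidden)
num_patches = torch.tensor([2, 1, 3])

per_image = [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]
print([t.shape for t in per_image])
# [torch.Size([8, 8]), torch.Size([4, 8]), torch.Size([12, 8])]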