
vllm_gaudi.attention.backends.hpu_attn

logger module-attribute

logger = logger()

HPUAttentionBackend

Bases: AttentionBackend

Source code in vllm_gaudi/attention/backends/hpu_attn.py
class HPUAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        raise NotImplementedError()

    @staticmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        raise NotImplementedError()

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        raise NotImplementedError()

    @staticmethod
    def get_builder_cls() -> type[HPUPagedAttentionMetadataBuilder]:
        return HPUPagedAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dsts: torch.Tensor,
    ) -> None:
        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dsts: torch.Tensor,
    ) -> None:
        HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)

copy_blocks staticmethod

copy_blocks(
    kv_caches: list[Tensor], src_to_dsts: Tensor
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def copy_blocks(
    kv_caches: list[torch.Tensor],
    src_to_dsts: torch.Tensor,
) -> None:
    HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)
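
The exact layout of src_to_dsts is defined by HPUPagedAttention.copy_blocks rather than here. The sketch below is a CPU-only illustration of the copy semantics, assuming src_to_dsts rows are (source, destination) block indices and using an invented cache layout:

import torch

# CPU-only illustration of block copying; the real path is HPUPagedAttention.copy_blocks.
# The cache layout below is made up for the example.
kv_cache = torch.randn(8, 4, 2, 16)           # [num_blocks, block_size, num_kv_heads, head_size]
src_to_dsts = torch.tensor([[0, 5], [1, 6]])  # assumed (src, dst) pairs: block 0 -> 5, block 1 -> 6

kv_cache[src_to_dsts[:, 1]] = kv_cache[src_to_dsts[:, 0]]
assert torch.equal(kv_cache[5], kv_cache[0])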

get_builder_cls staticmethod

get_builder_cls() -> type[HPUPagedAttentionMetadataBuilder]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_builder_cls() -> type[HPUPagedAttentionMetadataBuilder]:
    return HPUPagedAttentionMetadataBuilder

get_impl_cls staticmethod

get_impl_cls() -> type[AttentionImpl]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
    raise NotImplementedError()

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size)
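
A hypothetical sizing sketch, assuming a Gaudi environment with vllm_gaudi installed; the sizes are illustrative, and the tuple layout itself comes from HPUPagedAttention:

import torch
from vllm_gaudi.attention.backends.hpu_attn import HPUAttentionBackend

# Illustrative sizes only; the returned tuple layout is defined by HPUPagedAttention.
shape = HPUAttentionBackend.get_kv_cache_shape(num_blocks=16, block_size=128,
                                               num_kv_heads=8, head_size=128)
kv_cache = torch.zeros(*shape, dtype=torch.bfloat16)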

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    raise NotImplementedError()

get_name staticmethod

get_name() -> str
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_name() -> str:
    raise NotImplementedError()

swap_blocks staticmethod

swap_blocks(
    src_kv_cache: Tensor,
    dst_kv_cache: Tensor,
    src_to_dsts: Tensor,
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def swap_blocks(
    src_kv_cache: torch.Tensor,
    dst_kv_cache: torch.Tensor,
    src_to_dsts: torch.Tensor,
) -> None:
    HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)
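
As with copy_blocks, the mapping layout is owned by HPUPagedAttention; a CPU-only sketch of the swap direction, again assuming (src, dst) index pairs and an invented cache layout:

import torch

# Move selected blocks from one cache tensor into another (e.g. device cache -> swap space).
# Purely illustrative; the real kernel is HPUPagedAttention.swap_blocks.
src_kv_cache = torch.randn(8, 4, 2, 16)       # [num_blocks, block_size, num_kv_heads, head_size]
dst_kv_cache = torch.zeros_like(src_kv_cache)
src_to_dsts = torch.tensor([[2, 0], [3, 1]])  # assumed (src, dst) pairs

dst_kv_cache[src_to_dsts[:, 1]] = src_kv_cache[src_to_dsts[:, 0]]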

HPUAttentionImpl

Bases: AttentionImpl, Module

If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding.

The prompts might have different lengths, while the generation tokens always have length 1.
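
A toy illustration of the packing described above (plain torch, no HPU required; the lengths are arbitrary examples):

import torch

hidden_size = 4
prompts = [torch.randn(n, hidden_size) for n in (3, 5, 2)]   # prefill_0 .. prefill_2
prefill_tokens = torch.cat(prompts, dim=0)                   # |<-- num_prefill_tokens -->|
assert prefill_tokens.shape == (3 + 5 + 2, hidden_size)

decode_tokens = torch.randn(4, hidden_size)                  # one token per running sequence
padding = torch.zeros(2, hidden_size)                        # padding appended under graph capture
decode_batch = torch.cat([decode_tokens, padding], dim=0)    # |<-- num_decode_tokens -->|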

Source code in vllm_gaudi/attention/backends/hpu_attn.py
class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        super(AttentionImpl, self).__init__()
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not currently supported on HPU.")
        if use_irope:
            logger.warning_once("Using irope in HPU is not supported yet, it will fall back "
                                "to global attention for long context.")
        self.enable_fp8_attn = kv_cache_dtype == 'fp8_inc' and os.environ.get('QUANT_CONFIG', None) is None
        self.kv_cache_dtype = kv_cache_dtype
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.matmul_qk = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.batch2block_matmul = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.block2batch_matmul = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.k_cache = VLLMKVCache() if not self.enable_fp8_attn \
            else VLLMFP8KVCache()
        self.v_cache = VLLMKVCache() if not self.enable_fp8_attn \
            else VLLMFP8KVCache()
        HPUFusedSDPA = kernels.fsdpa()
        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
            else ModuleFusedSDPA(HPUFusedSDPA)
        self.prefill_impl = get_config().prompt_attn_impl
        self.use_contiguous_pa = get_config().use_contiguous_pa
        self.use_merged_prefill = get_config().merged_prefill
        if alibi_slopes is not None:
            assert self.prefill_impl != 'flex_impl', \
                'Prefill with Flex Attention not supported with alibi slopes!'
            assert self.prefill_impl != 'fsdpa_impl', \
                'Prefill with FusedSDPA not supported with alibi slopes!'
            assert self.use_contiguous_pa, \
                'Non-contiguous PA not supported with alibi slopes!'

        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        self.prompt_position_bias = None
        self.prev_attn = None
        self.alibi_slopes = None
        if alibi_slopes is not None:
            slope_tensor_dtype = torch.float32 if \
                get_config().fp32_alibi_biases else torch.bfloat16
            alibi_slopes_tensor = torch.tensor(alibi_slopes, dtype=slope_tensor_dtype)
            self.alibi_slopes = alibi_slopes_tensor

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(f"Head size {head_size} is not supported by PagedAttention. "
                             f"Supported head sizes are: {supported_head_sizes}.")

        self.attn_type = attn_type
        if (self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_DECODER
                and self.attn_type != AttentionType.ENCODER_ONLY):
            raise NotImplementedError("Encoder self-attention "
                                      "is not implemented for "
                                      "HPUAttentionImpl")

        self.is_chunked_attention = False

    def _maybe_init_alibi_biases(
        self,
        max_seq_len,
        prev_attn: Optional[torch.nn.Module] = None,
    ) -> None:
        self.max_seq_len = max_seq_len
        self.prev_attn = None if prev_attn is None else prev_attn.impl
        if self.alibi_slopes is not None:
            if self.prev_attn is not None:
                self.alibi_slopes = self.prev_attn.alibi_slopes
                self.prompt_position_bias = self.prev_attn.prompt_position_bias
            else:
                # Creating the prompt_position_bias once and reusing it
                # if seq_len permits.
                self.prompt_position_bias = _make_prompt_alibi_bias(
                    alibi_slopes=self.alibi_slopes,
                    seq_len=self.max_seq_len,
                    dtype=self.alibi_slopes.dtype,
                )

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        if self.attn_type == AttentionType.ENCODER_DECODER:
            return self.forward_encoder_decoder(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                k_scale=layer._k_scale_float,
                v_scale=layer._k_scale_float,
            )
        # Set return shape
        output_shape = query.shape
        if query.dim() == 2:
            if attn_metadata.seq_lens_tensor is not None:
                batch_size = attn_metadata.seq_lens_tensor.shape[0] if not self.use_merged_prefill else 1
            else:
                assert attn_metadata.block_mapping is not None, \
                    "seq_lens_tensor must be provided for attention"
                batch_size = attn_metadata.block_mapping.shape[1]
            num_tokens, hidden_size = query.shape
            seq_len = num_tokens // batch_size
            query = query.view(batch_size, seq_len, -1)
        else:
            batch_size, seq_len, hidden_size = query.shape

        seq_len_kv = key.shape[0] // batch_size if key.dim() == 2 else key.shape[1]

        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        slot_mapping = attn_metadata.slot_mapping.flatten() if attn_metadata.slot_mapping is not None else None
        key_cache = None
        value_cache = None
        if kv_cache is not None and isinstance(kv_cache, tuple):
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            key_cache = self.k_cache(key, key_cache, slot_mapping)
            value_cache = self.v_cache(value, value_cache, slot_mapping)

        if attn_metadata.is_prompt:
            # Prompt run.
            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)

            attn_bias = attn_metadata.attn_bias
            position_bias = None
            # If we have alibi_slopes, incorporate them with
            if (attn_metadata.block_list is None and self.prompt_position_bias is not None
                    and self.alibi_slopes is not None):
                assert attn_bias is not None, \
                        'attn_bias must be set before calling ' \
                        'model.forward with alibi biases'
                slice_1_size = attn_bias.size(-2)
                slice_2_size = attn_bias.size(-1)
                if self.max_seq_len >= max(slice_1_size, slice_2_size):
                    # Using pre-computed prompt_position_bias subset.
                    position_bias = self.prompt_position_bias[:, :, -slice_1_size:, -slice_2_size:]

                else:
                    # For longer sequences than precomputed,
                    # recreate the bias. This is memory inefficient.
                    position_bias = _make_prompt_alibi_bias(
                        alibi_slopes=self.alibi_slopes,
                        seq_len=max(slice_1_size, slice_2_size),
                        dtype=self.alibi_slopes.dtype,
                    )

            block_list = attn_metadata.block_list if attn_metadata \
                and attn_metadata.block_list is not None else None

            common_args = self.common_attention_args(block_list, key_cache, value_cache, attn_metadata.block_size)

            if self.sliding_window:
                if hasattr(attn_metadata, 'window_attn_bias') and attn_metadata.window_attn_bias is not None:
                    attn_bias = attn_metadata.window_attn_bias
                else:
                    attn_bias = None
                    window_size = (self.sliding_window, 0)
                    common_args['window_size'] = window_size
            if self.is_chunked_attention and \
                hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None:
                attn_bias = attn_metadata.chunked_attn_bias

            out = ops.prompt_attention(impl=self.prefill_impl,
                                       query=query.view(query_shape),
                                       key=key.view(kv_shape),
                                       value=value.view(kv_shape),
                                       is_causal=True,
                                       attn_bias=attn_bias,
                                       position_bias=position_bias,
                                       valid_seq_lengths=attn_metadata.seq_lens_tensor,
                                       **common_args)

            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
            if self.sliding_window and \
                attn_metadata.window_block_list is not None:
                block_list = attn_metadata.window_block_list
                block_groups = attn_metadata.window_block_groups
                block_mapping = attn_metadata.window_block_mapping
                attn_bias = attn_metadata.window_attn_bias
            elif self.is_chunked_attention and \
                attn_metadata.chunked_block_list is not None:
                block_list = attn_metadata.chunked_block_list
                block_groups = attn_metadata.chunked_block_groups
                block_mapping = attn_metadata.chunked_block_mapping
                attn_bias = attn_metadata.chunked_attn_bias
            else:
                block_list = attn_metadata.block_list
                block_groups = attn_metadata.block_groups
                block_mapping = attn_metadata.block_mapping
                attn_bias = attn_metadata.attn_bias

            self.position_bias = None
            alibi_blocks = getattr(attn_metadata, 'alibi_blocks', None)
            if self.alibi_slopes is not None and alibi_blocks is not None:
                if self.prev_attn is not None:
                    self.position_bias = self.prev_attn.position_bias
                else:
                    # For decoding, compute position bias using alibi_blocks.
                    self.position_bias = _make_decode_alibi_bias(
                        alibi_blocks=alibi_blocks,
                        alibi_slopes=self.alibi_slopes,
                        dtype=self.alibi_slopes.dtype,
                    )

            output = HPUPagedAttention.forward_decode(query=query,
                                                      block_mapping=block_mapping,
                                                      block_bias=attn_bias,
                                                      block_groups=block_groups,
                                                      position_bias=self.position_bias,
                                                      **self.common_attention_args(block_list, key_cache, value_cache,
                                                                                   attn_metadata.block_size))

        return output.view(*output_shape)

    def common_attention_args(self, block_list=None, key_cache=None, value_cache=None, block_size=None):
        return {
            'scale': self.scale,
            'matmul_qk_op': self.matmul_qk,
            'matmul_av_op': self.matmul_av,
            'batch2block_matmul_op': self.batch2block_matmul,
            'block2batch_matmul_op': self.block2batch_matmul,
            'fsdpa_op': self.fused_scaled_dot_product_attention,
            'keys_fetch_func': self.k_cache.fetch_from_cache,
            'values_fetch_func': self.v_cache.fetch_from_cache,
            'softmax_op': self.softmax,
            'block_list': block_list,
            'key_cache': key_cache,
            'value_cache': value_cache,
            'block_size': block_size,
        }

    def forward_encoder_decoder(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        batch_size, hidden_size = query.shape

        if attn_metadata.is_prompt:
            batch_size = attn_metadata.num_prefills
            batched_tokens, _ = query.shape
            batched_kv_tokens, _, _ = key.shape
            assert batch_size > 0, ("In prefill stage the num_prefills should be > 0")
            assert batched_tokens % batch_size == 0
            assert batched_kv_tokens % batch_size == 0
            seq_len = batched_tokens // batch_size

        query = query.unsqueeze(1)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        cross_slot_mapping = attn_metadata.cross_slot_mapping.flatten(
        ) if attn_metadata.cross_slot_mapping is not None else None
        if kv_cache is not None and isinstance(kv_cache, tuple):
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            key_cache = self.k_cache(key, key_cache, cross_slot_mapping)
            value_cache = self.v_cache(value, value_cache, cross_slot_mapping)

        if attn_metadata.is_prompt:
            # Prompt run.
            batch_size = attn_metadata.num_prefills

            query_shape = (batch_size, -1, self.num_heads, self.head_size)
            kv_shape = (batch_size, -1, self.num_kv_heads, self.head_size)
            out = ops.prompt_attention(impl=self.prefill_impl,
                                       query=query.view(query_shape),
                                       key=key.view(kv_shape),
                                       value=value.view(kv_shape),
                                       attn_bias=None,
                                       is_causal=False,
                                       **self.common_attention_args())
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Enc/dec cross-attention KVs match encoder sequence length;
            # cross-attention utilizes special "cross" block tables
            block_list = attn_metadata.cross_block_list
            block_mapping = attn_metadata.cross_block_mapping
            block_groups = attn_metadata.cross_block_groups
            attn_bias = attn_metadata.cross_attn_bias
            # Decoding run.
            output = HPUPagedAttention.forward_decode(query=query,
                                                      block_mapping=block_mapping,
                                                      block_bias=attn_bias,
                                                      block_groups=block_groups,
                                                      position_bias=None,
                                                      **self.common_attention_args(block_list, key_cache, value_cache,
                                                                                   attn_metadata.block_size))
        # Reshape the output tensor.
        return output.view(batch_size, -1, hidden_size)

alibi_slopes instance-attribute

alibi_slopes = None

attn_type instance-attribute

attn_type = attn_type

batch2block_matmul instance-attribute

batch2block_matmul = (
    Matmul() if not enable_fp8_attn else FP8Matmul()
)

block2batch_matmul instance-attribute

block2batch_matmul = (
    Matmul() if not enable_fp8_attn else FP8Matmul()
)

enable_fp8_attn instance-attribute

enable_fp8_attn = (
    kv_cache_dtype == "fp8_inc"
    and get("QUANT_CONFIG", None) is None
)

fused_scaled_dot_product_attention instance-attribute

fused_scaled_dot_product_attention = (
    None
    if HPUFusedSDPA is None
    else ModuleFusedSDPA(HPUFusedSDPA)
)

head_size instance-attribute

head_size = head_size

is_chunked_attention instance-attribute

is_chunked_attention = False

k_cache instance-attribute

k_cache = (
    VLLMKVCache()
    if not enable_fp8_attn
    else VLLMFP8KVCache()
)

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

matmul_av instance-attribute

matmul_av = Matmul() if not enable_fp8_attn else FP8Matmul()

matmul_qk instance-attribute

matmul_qk = Matmul() if not enable_fp8_attn else FP8Matmul()

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = (
    num_heads if num_kv_heads is None else num_kv_heads
)

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

prefill_impl instance-attribute

prefill_impl = prompt_attn_impl

prev_attn instance-attribute

prev_attn = None

prompt_position_bias instance-attribute

prompt_position_bias = None

scale instance-attribute

scale = float(scale)

sliding_window instance-attribute

sliding_window = sliding_window

softmax instance-attribute

softmax = Softmax()

use_contiguous_pa instance-attribute

use_contiguous_pa = use_contiguous_pa

use_merged_prefill instance-attribute

use_merged_prefill = merged_prefill

v_cache instance-attribute

v_cache = (
    VLLMKVCache()
    if not enable_fp8_attn
    else VLLMFP8KVCache()
)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
) -> None:
    super(AttentionImpl, self).__init__()
    if kv_sharing_target_layer_name is not None:
        raise NotImplementedError("KV sharing is not currently supported on HPU.")
    if use_irope:
        logger.warning_once("Using irope in HPU is not supported yet, it will fall back "
                            "to global attention for long context.")
    self.enable_fp8_attn = kv_cache_dtype == 'fp8_inc' and os.environ.get('QUANT_CONFIG', None) is None
    self.kv_cache_dtype = kv_cache_dtype
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.matmul_qk = Matmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.softmax = Softmax()
    self.matmul_av = Matmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.batch2block_matmul = Matmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.block2batch_matmul = Matmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.k_cache = VLLMKVCache() if not self.enable_fp8_attn \
        else VLLMFP8KVCache()
    self.v_cache = VLLMKVCache() if not self.enable_fp8_attn \
        else VLLMFP8KVCache()
    HPUFusedSDPA = kernels.fsdpa()
    self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
        else ModuleFusedSDPA(HPUFusedSDPA)
    self.prefill_impl = get_config().prompt_attn_impl
    self.use_contiguous_pa = get_config().use_contiguous_pa
    self.use_merged_prefill = get_config().merged_prefill
    if alibi_slopes is not None:
        assert self.prefill_impl != 'flex_impl', \
            'Prefill with Flex Attention not supported with alibi slopes!'
        assert self.prefill_impl != 'fsdpa_impl', \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        assert self.use_contiguous_pa, \
            'Non-contiguous PA not supported with alibi slopes!'

    self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
    self.sliding_window = sliding_window
    self.prompt_position_bias = None
    self.prev_attn = None
    self.alibi_slopes = None
    if alibi_slopes is not None:
        slope_tensor_dtype = torch.float32 if \
            get_config().fp32_alibi_biases else torch.bfloat16
        alibi_slopes_tensor = torch.tensor(alibi_slopes, dtype=slope_tensor_dtype)
        self.alibi_slopes = alibi_slopes_tensor

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
    if head_size not in supported_head_sizes:
        raise ValueError(f"Head size {head_size} is not supported by PagedAttention. "
                         f"Supported head sizes are: {supported_head_sizes}.")

    self.attn_type = attn_type
    if (self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_DECODER
            and self.attn_type != AttentionType.ENCODER_ONLY):
        raise NotImplementedError("Encoder self-attention "
                                  "is not implemented for "
                                  "HPUAttentionImpl")

    self.is_chunked_attention = False

_maybe_init_alibi_biases

_maybe_init_alibi_biases(
    max_seq_len, prev_attn: Optional[Module] = None
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def _maybe_init_alibi_biases(
    self,
    max_seq_len,
    prev_attn: Optional[torch.nn.Module] = None,
) -> None:
    self.max_seq_len = max_seq_len
    self.prev_attn = None if prev_attn is None else prev_attn.impl
    if self.alibi_slopes is not None:
        if self.prev_attn is not None:
            self.alibi_slopes = self.prev_attn.alibi_slopes
            self.prompt_position_bias = self.prev_attn.prompt_position_bias
        else:
            # Creating the prompt_position_bias once and reusing it
            # if seq_len permits.
            self.prompt_position_bias = _make_prompt_alibi_bias(
                alibi_slopes=self.alibi_slopes,
                seq_len=self.max_seq_len,
                dtype=self.alibi_slopes.dtype,
            )
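
_make_prompt_alibi_bias itself is not listed on this page. As a rough, generic illustration of the kind of tensor being cached here (per-head ALiBi slope times key-query distance), which may differ in detail from the actual helper:

import torch

def toy_prompt_alibi_bias(alibi_slopes: torch.Tensor, seq_len: int,
                          dtype: torch.dtype) -> torch.Tensor:
    # Generic ALiBi bias: bias[0, h, q, k] = slope[h] * (k - q), zeroed for future keys.
    pos = torch.arange(seq_len)
    distance = (pos[None, :] - pos[:, None]).clamp(max=0)
    return (alibi_slopes.view(-1, 1, 1).to(dtype) * distance.to(dtype)).unsqueeze(0)

bias = toy_prompt_alibi_bias(torch.tensor([0.5, 0.25]), seq_len=4, dtype=torch.float32)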

common_attention_args

common_attention_args(
    block_list=None,
    key_cache=None,
    value_cache=None,
    block_size=None,
)
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def common_attention_args(self, block_list=None, key_cache=None, value_cache=None, block_size=None):
    return {
        'scale': self.scale,
        'matmul_qk_op': self.matmul_qk,
        'matmul_av_op': self.matmul_av,
        'batch2block_matmul_op': self.batch2block_matmul,
        'block2batch_matmul_op': self.block2batch_matmul,
        'fsdpa_op': self.fused_scaled_dot_product_attention,
        'keys_fetch_func': self.k_cache.fetch_from_cache,
        'values_fetch_func': self.v_cache.fetch_from_cache,
        'softmax_op': self.softmax,
        'block_list': block_list,
        'key_cache': key_cache,
        'value_cache': value_cache,
        'block_size': block_size,
    }

forward

forward(
    layer: AttentionLayer,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: HPUAttentionMetadata,
    output: Optional[Tensor] = None,
) -> Tensor

Forward pass with PagedAttention.

Parameters

Name            Type                   Description                                      Default
query           Tensor                 shape = [num_tokens, num_heads * head_size]      required
key             Tensor                 shape = [num_tokens, num_kv_heads * head_size]   required
value           Tensor                 shape = [num_tokens, num_kv_heads * head_size]   required
attn_metadata   HPUAttentionMetadata   Metadata for attention.                          required

Returns: shape = [num_tokens, num_heads * head_size]
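
The flat layouts above can be sanity-checked without any HPU code; a minimal shape sketch with illustrative sizes:

import torch

num_tokens, num_heads, num_kv_heads, head_size = 16, 8, 2, 64

query = torch.randn(num_tokens, num_heads * head_size)
key = torch.randn(num_tokens, num_kv_heads * head_size)
value = torch.randn(num_tokens, num_kv_heads * head_size)

# forward() views keys/values as [-1, num_kv_heads, head_size] before writing the cache.
assert key.view(-1, num_kv_heads, head_size).shape == (num_tokens, num_kv_heads, head_size)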

Source code in vllm_gaudi/attention/backends/hpu_attn.py
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: HPUAttentionMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass with PagedAttention.

    Args:
        query: shape = [num_tokens, num_heads * head_size]
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
        attn_metadata: Metadata for attention.
    Returns:
        shape = [num_tokens, num_heads * head_size]
    """
    assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
    if self.attn_type == AttentionType.ENCODER_DECODER:
        return self.forward_encoder_decoder(
            query=query,
            key=key,
            value=value,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            k_scale=layer._k_scale_float,
            v_scale=layer._k_scale_float,
        )
    # Set return shape
    output_shape = query.shape
    if query.dim() == 2:
        if attn_metadata.seq_lens_tensor is not None:
            batch_size = attn_metadata.seq_lens_tensor.shape[0] if not self.use_merged_prefill else 1
        else:
            assert attn_metadata.block_mapping is not None, \
                "seq_lens_tensor must be provided for attention"
            batch_size = attn_metadata.block_mapping.shape[1]
        num_tokens, hidden_size = query.shape
        seq_len = num_tokens // batch_size
        query = query.view(batch_size, seq_len, -1)
    else:
        batch_size, seq_len, hidden_size = query.shape

    seq_len_kv = key.shape[0] // batch_size if key.dim() == 2 else key.shape[1]

    key = key.view(-1, self.num_kv_heads, self.head_size)
    value = value.view(-1, self.num_kv_heads, self.head_size)
    slot_mapping = attn_metadata.slot_mapping.flatten() if attn_metadata.slot_mapping is not None else None
    key_cache = None
    value_cache = None
    if kv_cache is not None and isinstance(kv_cache, tuple):
        key_cache, value_cache = HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

        # Reshape the input keys and values and store them in the cache.
        # If kv_cache is not provided, the new key and value tensors are
        # not cached. This happens during the initial memory profiling run.
        key_cache = self.k_cache(key, key_cache, slot_mapping)
        value_cache = self.v_cache(value, value_cache, slot_mapping)

    if attn_metadata.is_prompt:
        # Prompt run.
        query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
        kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)

        attn_bias = attn_metadata.attn_bias
        position_bias = None
        # If we have alibi_slopes, incorporate them with
        if (attn_metadata.block_list is None and self.prompt_position_bias is not None
                and self.alibi_slopes is not None):
            assert attn_bias is not None, \
                    'attn_bias must be set before calling ' \
                    'model.forward with alibi biases'
            slice_1_size = attn_bias.size(-2)
            slice_2_size = attn_bias.size(-1)
            if self.max_seq_len >= max(slice_1_size, slice_2_size):
                # Using pre-computed prompt_position_bias subset.
                position_bias = self.prompt_position_bias[:, :, -slice_1_size:, -slice_2_size:]

            else:
                # For longer sequences than precomputed,
                # recreate the bias. This is memory inefficient.
                position_bias = _make_prompt_alibi_bias(
                    alibi_slopes=self.alibi_slopes,
                    seq_len=max(slice_1_size, slice_2_size),
                    dtype=self.alibi_slopes.dtype,
                )

        block_list = attn_metadata.block_list if attn_metadata \
            and attn_metadata.block_list is not None else None

        common_args = self.common_attention_args(block_list, key_cache, value_cache, attn_metadata.block_size)

        if self.sliding_window:
            if hasattr(attn_metadata, 'window_attn_bias') and attn_metadata.window_attn_bias is not None:
                attn_bias = attn_metadata.window_attn_bias
            else:
                attn_bias = None
                window_size = (self.sliding_window, 0)
                common_args['window_size'] = window_size
        if self.is_chunked_attention and \
            hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None:
            attn_bias = attn_metadata.chunked_attn_bias

        out = ops.prompt_attention(impl=self.prefill_impl,
                                   query=query.view(query_shape),
                                   key=key.view(kv_shape),
                                   value=value.view(kv_shape),
                                   is_causal=True,
                                   attn_bias=attn_bias,
                                   position_bias=position_bias,
                                   valid_seq_lengths=attn_metadata.seq_lens_tensor,
                                   **common_args)

        output = out.reshape(batch_size, seq_len, hidden_size)
    else:
        # Decoding run.
        if self.sliding_window and \
            attn_metadata.window_block_list is not None:
            block_list = attn_metadata.window_block_list
            block_groups = attn_metadata.window_block_groups
            block_mapping = attn_metadata.window_block_mapping
            attn_bias = attn_metadata.window_attn_bias
        elif self.is_chunked_attention and \
            attn_metadata.chunked_block_list is not None:
            block_list = attn_metadata.chunked_block_list
            block_groups = attn_metadata.chunked_block_groups
            block_mapping = attn_metadata.chunked_block_mapping
            attn_bias = attn_metadata.chunked_attn_bias
        else:
            block_list = attn_metadata.block_list
            block_groups = attn_metadata.block_groups
            block_mapping = attn_metadata.block_mapping
            attn_bias = attn_metadata.attn_bias

        self.position_bias = None
        alibi_blocks = getattr(attn_metadata, 'alibi_blocks', None)
        if self.alibi_slopes is not None and alibi_blocks is not None:
            if self.prev_attn is not None:
                self.position_bias = self.prev_attn.position_bias
            else:
                # For decoding, compute position bias using alibi_blocks.
                self.position_bias = _make_decode_alibi_bias(
                    alibi_blocks=alibi_blocks,
                    alibi_slopes=self.alibi_slopes,
                    dtype=self.alibi_slopes.dtype,
                )

        output = HPUPagedAttention.forward_decode(query=query,
                                                  block_mapping=block_mapping,
                                                  block_bias=attn_bias,
                                                  block_groups=block_groups,
                                                  position_bias=self.position_bias,
                                                  **self.common_attention_args(block_list, key_cache, value_cache,
                                                                               attn_metadata.block_size))

    return output.view(*output_shape)

forward_encoder_decoder

forward_encoder_decoder(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: HPUAttentionMetadata,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
) -> Tensor

Forward pass with xFormers and PagedAttention.

Parameters

Name            Type                   Description                                      Default
query           Tensor                 shape = [num_tokens, num_heads * head_size]      required
key             Tensor                 shape = [num_tokens, num_kv_heads * head_size]   required
value           Tensor                 shape = [num_tokens, num_kv_heads * head_size]   required
attn_metadata   HPUAttentionMetadata   Metadata for attention.                          required

Returns: shape = [num_tokens, num_heads * head_size]

Source code in vllm_gaudi/attention/backends/hpu_attn.py
def forward_encoder_decoder(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: HPUAttentionMetadata,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
) -> torch.Tensor:
    """Forward pass with xFormers and PagedAttention.

    Args:
        query: shape = [num_tokens, num_heads * head_size]
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
        attn_metadata: Metadata for attention.
    Returns:
        shape = [num_tokens, num_heads * head_size]
    """
    batch_size, hidden_size = query.shape

    if attn_metadata.is_prompt:
        batch_size = attn_metadata.num_prefills
        batched_tokens, _ = query.shape
        batched_kv_tokens, _, _ = key.shape
        assert batch_size > 0, ("In prefill stage the num_prefills should be > 0")
        assert batched_tokens % batch_size == 0
        assert batched_kv_tokens % batch_size == 0
        seq_len = batched_tokens // batch_size

    query = query.unsqueeze(1)
    if key is not None:
        assert value is not None
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
    else:
        assert value is None

    cross_slot_mapping = attn_metadata.cross_slot_mapping.flatten(
    ) if attn_metadata.cross_slot_mapping is not None else None
    if kv_cache is not None and isinstance(kv_cache, tuple):
        key_cache, value_cache = HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

        # Reshape the input keys and values and store them in the cache.
        # If kv_cache is not provided, the new key and value tensors are
        # not cached. This happens during the initial memory profiling run.
        key_cache = self.k_cache(key, key_cache, cross_slot_mapping)
        value_cache = self.v_cache(value, value_cache, cross_slot_mapping)

    if attn_metadata.is_prompt:
        # Prompt run.
        batch_size = attn_metadata.num_prefills

        query_shape = (batch_size, -1, self.num_heads, self.head_size)
        kv_shape = (batch_size, -1, self.num_kv_heads, self.head_size)
        out = ops.prompt_attention(impl=self.prefill_impl,
                                   query=query.view(query_shape),
                                   key=key.view(kv_shape),
                                   value=value.view(kv_shape),
                                   attn_bias=None,
                                   is_causal=False,
                                   **self.common_attention_args())
        output = out.reshape(batch_size, seq_len, hidden_size)
    else:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        block_list = attn_metadata.cross_block_list
        block_mapping = attn_metadata.cross_block_mapping
        block_groups = attn_metadata.cross_block_groups
        attn_bias = attn_metadata.cross_attn_bias
        # Decoding run.
        output = HPUPagedAttention.forward_decode(query=query,
                                                  block_mapping=block_mapping,
                                                  block_bias=attn_bias,
                                                  block_groups=block_groups,
                                                  position_bias=None,
                                                  **self.common_attention_args(block_list, key_cache, value_cache,
                                                                               attn_metadata.block_size))
    # Reshape the output tensor.
    return output.view(batch_size, -1, hidden_size)

HPUAttentionMetadata dataclass

Bases: HPUPagedAttentionMetadata, AttentionMetadata

Metadata for HPUAttentionBackend.

Source code in vllm_gaudi/attention/backends/hpu_attn.py
@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    """Metadata for HPUAttentionbackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    block_size: int
    slot_mapping: torch.Tensor
    attn_bias: Optional[torch.Tensor]
    seq_lens_tensor: Optional[torch.Tensor]
    context_lens_tensor: Optional[torch.Tensor]
    input_positions: torch.Tensor
    seq_lens: Optional[list[int]] = None
    encoder_seq_lens: Optional[list[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    max_encoder_seq_len: Optional[int] = None
    cross_block_list: Optional[torch.Tensor] = None
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_mapping: Optional[torch.Tensor] = None
    cross_block_groups: Optional[torch.Tensor] = None
    cross_block_usage: Optional[torch.Tensor] = None
    cross_attn_bias: Optional[torch.Tensor] = None
    window_block_list: Optional[torch.Tensor] = None
    window_slot_mapping: Optional[torch.Tensor] = None
    window_block_mapping: Optional[torch.Tensor] = None
    window_block_groups: Optional[torch.Tensor] = None
    window_block_usage: Optional[torch.Tensor] = None
    window_attn_bias: Optional[torch.Tensor] = None
    chunked_slot_mapping: Optional[torch.Tensor] = None
    chunked_attn_bias: Optional[torch.Tensor] = None
    chunked_block_mapping: Optional[torch.Tensor] = None
    chunked_block_list: Optional[torch.Tensor] = None
    chunked_block_groups: Optional[torch.Tensor] = None
    chunked_block_usage: Optional[torch.Tensor] = None

attn_bias instance-attribute

attn_bias: Optional[Tensor]

block_size instance-attribute

block_size: int

chunked_attn_bias class-attribute instance-attribute

chunked_attn_bias: Optional[Tensor] = None

chunked_block_groups class-attribute instance-attribute

chunked_block_groups: Optional[Tensor] = None

chunked_block_list class-attribute instance-attribute

chunked_block_list: Optional[Tensor] = None

chunked_block_mapping class-attribute instance-attribute

chunked_block_mapping: Optional[Tensor] = None

chunked_block_usage class-attribute instance-attribute

chunked_block_usage: Optional[Tensor] = None

chunked_slot_mapping class-attribute instance-attribute

chunked_slot_mapping: Optional[Tensor] = None

context_lens_tensor instance-attribute

context_lens_tensor: Optional[Tensor]

cross_attn_bias class-attribute instance-attribute

cross_attn_bias: Optional[Tensor] = None

cross_block_groups class-attribute instance-attribute

cross_block_groups: Optional[Tensor] = None

cross_block_list class-attribute instance-attribute

cross_block_list: Optional[Tensor] = None

cross_block_mapping class-attribute instance-attribute

cross_block_mapping: Optional[Tensor] = None

cross_block_usage class-attribute instance-attribute

cross_block_usage: Optional[Tensor] = None

cross_slot_mapping class-attribute instance-attribute

cross_slot_mapping: Optional[Tensor] = None

encoder_seq_lens class-attribute instance-attribute

encoder_seq_lens: Optional[list[int]] = None

encoder_seq_lens_tensor class-attribute instance-attribute

encoder_seq_lens_tensor: Optional[Tensor] = None

input_positions instance-attribute

input_positions: Tensor

is_prompt instance-attribute

is_prompt: bool

max_encoder_seq_len class-attribute instance-attribute

max_encoder_seq_len: Optional[int] = None

seq_lens class-attribute instance-attribute

seq_lens: Optional[list[int]] = None

seq_lens_tensor instance-attribute

seq_lens_tensor: Optional[Tensor]

slot_mapping instance-attribute

slot_mapping: Tensor

window_attn_bias class-attribute instance-attribute

window_attn_bias: Optional[Tensor] = None

window_block_groups class-attribute instance-attribute

window_block_groups: Optional[Tensor] = None

window_block_list class-attribute instance-attribute

window_block_list: Optional[Tensor] = None

window_block_mapping class-attribute instance-attribute

window_block_mapping: Optional[Tensor] = None

window_block_usage class-attribute instance-attribute

window_block_usage: Optional[Tensor] = None

window_slot_mapping class-attribute instance-attribute

window_slot_mapping: Optional[Tensor] = None

__init__

__init__(
    block_list: Optional[Tensor],
    block_mapping: Optional[Tensor],
    block_usage: Optional[Tensor],
    block_groups: Optional[Tensor],
    alibi_blocks: Optional[Tensor],
    is_prompt: bool,
    block_size: int,
    slot_mapping: Tensor,
    attn_bias: Optional[Tensor],
    seq_lens_tensor: Optional[Tensor],
    context_lens_tensor: Optional[Tensor],
    input_positions: Tensor,
    seq_lens: Optional[list[int]] = None,
    encoder_seq_lens: Optional[list[int]] = None,
    encoder_seq_lens_tensor: Optional[Tensor] = None,
    max_encoder_seq_len: Optional[int] = None,
    cross_block_list: Optional[Tensor] = None,
    cross_slot_mapping: Optional[Tensor] = None,
    cross_block_mapping: Optional[Tensor] = None,
    cross_block_groups: Optional[Tensor] = None,
    cross_block_usage: Optional[Tensor] = None,
    cross_attn_bias: Optional[Tensor] = None,
    window_block_list: Optional[Tensor] = None,
    window_slot_mapping: Optional[Tensor] = None,
    window_block_mapping: Optional[Tensor] = None,
    window_block_groups: Optional[Tensor] = None,
    window_block_usage: Optional[Tensor] = None,
    window_attn_bias: Optional[Tensor] = None,
    chunked_slot_mapping: Optional[Tensor] = None,
    chunked_attn_bias: Optional[Tensor] = None,
    chunked_block_mapping: Optional[Tensor] = None,
    chunked_block_list: Optional[Tensor] = None,
    chunked_block_groups: Optional[Tensor] = None,
    chunked_block_usage: Optional[Tensor] = None,
) -> None
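
A hypothetical prompt-phase construction with toy values, assuming vllm_gaudi is importable; real metadata is produced by HPUPagedAttentionMetadataBuilder, and every tensor below is illustrative only:

import torch
from vllm_gaudi.attention.backends.hpu_attn import HPUAttentionMetadata

meta = HPUAttentionMetadata(
    block_list=None, block_mapping=None, block_usage=None, block_groups=None,
    alibi_blocks=None,
    is_prompt=True,
    block_size=128,
    slot_mapping=torch.arange(16),            # one cache slot per prompt token (toy values)
    attn_bias=None,
    seq_lens_tensor=torch.tensor([16]),       # a single 16-token sequence
    context_lens_tensor=torch.zeros(1, dtype=torch.int32),
    input_positions=torch.arange(16),
)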

HPUMLAAttentionBackend

Bases: HPUAttentionBackend

Source code in vllm_gaudi/attention/backends/hpu_attn.py
@register_backend(AttentionBackendEnum.CUSTOM, "HPU_MLA")
class HPUMLAAttentionBackend(HPUAttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "CUSTOM"

    @staticmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        return HPUMLAImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return HPUMLAMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return (num_blocks * block_size, head_size)

get_impl_cls staticmethod

get_impl_cls() -> type[AttentionImpl]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
    return HPUMLAImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    return (num_blocks * block_size, head_size)
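
Unlike the base backend, the MLA cache is a flat (num_blocks * block_size, head_size) buffer, i.e. one row per cache slot. A quick check with illustrative numbers (576 = a 512-wide latent plus a 64-wide rope part in common MLA configurations; not a default of this module), assuming vllm_gaudi is importable:

from vllm_gaudi.attention.backends.hpu_attn import HPUMLAAttentionBackend

shape = HPUMLAAttentionBackend.get_kv_cache_shape(num_blocks=1024, block_size=128,
                                                  num_kv_heads=1, head_size=576)
assert shape == (1024 * 128, 576)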

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    return HPUMLAMetadata

get_name staticmethod

get_name() -> str
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_name() -> str:
    return "CUSTOM"

HPUMLAImpl

Bases: MLACommonImpl[HPUAttentionMetadata], Module

Source code in vllm_gaudi/attention/backends/hpu_attn.py
class HPUMLAImpl(MLACommonImpl[HPUAttentionMetadata], torch.nn.Module):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        **kwargs,
    ) -> None:
        torch.nn.Module.__init__(self)

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.kv_b_proj = kv_b_proj

        # NOTE(kzawora): restore this once https://github.com/vllm-project/vllm/pull/25385 is merged
        #MLACommonImpl.__init__(self, num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window,
        #                       kv_cache_dtype, logits_soft_cap, attn_type, kv_sharing_target_layer_name, **kwargs)

        self.enable_fp8_attn = kv_cache_dtype == 'fp8_inc' and os.environ.get('QUANT_CONFIG', None) is None
        self.matmul_qk = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.batch2block_matmul = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.block2batch_matmul = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.latent_cache_k = VLLMKVCache() if not self.enable_fp8_attn \
            else VLLMFP8KVCache()
        self.fused_scaled_dot_product_attention = kernels.fsdpa()
        self.use_merged_prefill = get_config().merged_prefill
        self.prefill_impl = get_config().prompt_attn_impl
        assert self.prefill_impl != 'fsdpa_impl' or alibi_slopes is None, \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError("HPUMLAImpl does not support one of the following: "
                                      "alibi_slopes, sliding_window, "
                                      "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonMLAImpl")

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if output is not None:
            raise NotImplementedError("output is not yet supported for MLAImplBase")

        is_prefill = attn_metadata.is_prompt

        if not is_prefill:
            # decode
            q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
            # Convert from (B, N, P) to (N, B, P)
            q_nope = q_nope.transpose(0, 1)
            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            decode_ql_nope = torch.bmm(q_nope, self.W_UK_T)
            # Convert from (N, B, L) to (B, N, L)
            decode_ql_nope = decode_ql_nope.transpose(0, 1)

        slot_mapping = attn_metadata.slot_mapping.flatten() if attn_metadata.slot_mapping is not None else None

        latent_vec_k = torch.concat((k_c_normed, k_pe.view(*k_c_normed.shape[:-1], self.qk_rope_head_dim)), dim=-1)
        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)

        # write the latent and rope to kv cache
        if kv_cache is not None and len(kv_cache) == 2:
            self.latent_cache_k(latent_vec_k, kv_cache[0], slot_mapping)
            k_cache = kv_cache[0]

        if is_prefill:
            return self._forward_prefill(q, latent_vec_k, k_cache, attn_metadata)
        else:
            return self._forward_decode(decode_ql_nope, q_pe, k_cache, attn_metadata)

    def _forward_prefill(  # type: ignore
            self, q: torch.Tensor, latent_vec_k: torch.Tensor, k_cache: torch.Tensor,
            attn_metadata: HPUAttentionMetadata) -> torch.Tensor:

        ##### get prefix cache #####
        if attn_metadata.block_list is not None:
            current = latent_vec_k
            past = self.latent_cache_k.fetch_from_cache(k_cache.unflatten(0, (-1, attn_metadata.block_size)),
                                                        attn_metadata.block_list)
            past = past.view(-1, past.shape[-1])
            current = torch.concat((past, current), dim=0)
            latent_vec_k = current
        # =========================== #

        k_c_normed, k_pe = latent_vec_k.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

        kv_nope = self.kv_b_proj(k_c_normed)[0]\
            .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        if not self.use_merged_prefill:
            assert attn_metadata.seq_lens_tensor is not None, \
                "seq_lens_tensor must be provided for prefill attention"
            batch_size = attn_metadata.seq_lens_tensor.shape[0]
        else:
            batch_size = 1
        q = q.view(batch_size, -1, self.num_heads, self.qk_head_dim)
        k = k.view(batch_size, -1, self.num_heads, self.qk_head_dim)
        v = v.view(batch_size, -1, self.num_heads, self.v_head_dim)

        to_pad = self.qk_head_dim - self.v_head_dim
        if to_pad > 0:
            v_padding = torch.zeros(*v.shape[:-1], q.shape[-1] - v.shape[-1], device=v.device, dtype=v.dtype)
            v_padded = torch.cat((v, v_padding), dim=-1)
        else:
            v_padded = v

        output = ops.prompt_attention(
            impl=self.prefill_impl,
            query=q,
            key=k,
            value=v_padded,
            is_causal=True,
            attn_bias=attn_metadata.attn_bias,
            position_bias=None,
            valid_seq_lengths=attn_metadata.seq_lens_tensor,
            scale=self.scale,
            matmul_qk_op=self.matmul_qk,
            softmax_op=self.softmax,
            matmul_av_op=self.matmul_av,
            keys_fetch_func=self.latent_cache_k.fetch_from_cache,
            values_fetch_func=None,
            fsdpa_op=self.fused_scaled_dot_product_attention.apply \
            if self.fused_scaled_dot_product_attention is not None else None)
        # remove padding
        output = output.view(batch_size, -1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]

        return output.reshape(-1, self.num_heads * v.shape[-1])

    def _forward_decode(  # type: ignore
            self, q_nope: torch.Tensor, q_pe: torch.Tensor, k_cache: torch.Tensor,
            attn_metadata: HPUAttentionMetadata) -> torch.Tensor:
        query = torch.cat([q_nope, q_pe], dim=-1)
        key_cache = k_cache.unsqueeze(1)
        value_cache = None
        output = HPUPagedAttention.forward_decode(query=query,
                                                  key_cache=key_cache,
                                                  value_cache=value_cache,
                                                  block_list=attn_metadata.block_list,
                                                  block_mapping=attn_metadata.block_mapping,
                                                  block_bias=attn_metadata.attn_bias,
                                                  block_groups=attn_metadata.block_groups,
                                                  block_size=attn_metadata.block_size,
                                                  scale=self.scale,
                                                  matmul_qk_op=self.matmul_qk,
                                                  matmul_av_op=self.matmul_av,
                                                  batch2block_matmul_op=self.batch2block_matmul,
                                                  block2batch_matmul_op=self.block2batch_matmul,
                                                  keys_fetch_func=self.latent_cache_k.fetch_from_cache,
                                                  values_fetch_func=None,
                                                  kv_lora_rank=self.kv_lora_rank)
        result = self._v_up_proj(output)
        return result

    # NOTE(Chendi): PR25184 using output buffer as default, which can't be used in HPU Graph,
    # so we override and always return a new tensor
    def _v_up_proj(self, x):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        x = torch.bmm(x, self.W_UV)
        # Convert from (N, B, V) to (B, N * V)
        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
        return x

batch2block_matmul instance-attribute

batch2block_matmul = (
    Matmul() if not enable_fp8_attn else FP8Matmul()
)

block2batch_matmul instance-attribute

block2batch_matmul = (
    Matmul() if not enable_fp8_attn else FP8Matmul()
)

enable_fp8_attn instance-attribute

enable_fp8_attn = (
    kv_cache_dtype == "fp8_inc"
    and get("QUANT_CONFIG", None) is None
)

fused_scaled_dot_product_attention instance-attribute

fused_scaled_dot_product_attention = fsdpa()

head_size instance-attribute

head_size = head_size

is_aiter_triton_fp8_bmm_enabled instance-attribute

is_aiter_triton_fp8_bmm_enabled = is_fp8bmm_enabled()

kv_b_proj instance-attribute

kv_b_proj = kv_b_proj

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_lora_rank instance-attribute

kv_lora_rank = kv_lora_rank

latent_cache_k instance-attribute

latent_cache_k = (
    VLLMKVCache()
    if not enable_fp8_attn
    else VLLMFP8KVCache()
)

matmul_av instance-attribute

matmul_av = Matmul() if not enable_fp8_attn else FP8Matmul()

matmul_qk instance-attribute

matmul_qk = Matmul() if not enable_fp8_attn else FP8Matmul()

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

prefill_impl instance-attribute

prefill_impl = prompt_attn_impl

q_lora_rank instance-attribute

q_lora_rank = q_lora_rank

qk_head_dim instance-attribute

qk_head_dim = qk_head_dim

qk_nope_head_dim instance-attribute

qk_nope_head_dim = qk_nope_head_dim

qk_rope_head_dim instance-attribute

qk_rope_head_dim = qk_rope_head_dim

scale instance-attribute

scale = float(scale)

softmax instance-attribute

softmax = Softmax()

use_merged_prefill instance-attribute

use_merged_prefill = merged_prefill

v_head_dim instance-attribute

v_head_dim = v_head_dim

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    kv_b_proj: ColumnParallelLinear,
    **kwargs,
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    # MLA Specific Arguments
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    kv_b_proj: ColumnParallelLinear,
    **kwargs,
) -> None:
    torch.nn.Module.__init__(self)

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.kv_cache_dtype = kv_cache_dtype

    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.qk_head_dim = qk_head_dim
    self.v_head_dim = v_head_dim
    self.kv_b_proj = kv_b_proj

    # NOTE(kzawora): restore this once https://github.com/vllm-project/vllm/pull/25385 is merged
    #MLACommonImpl.__init__(self, num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window,
    #                       kv_cache_dtype, logits_soft_cap, attn_type, kv_sharing_target_layer_name, **kwargs)

    self.enable_fp8_attn = kv_cache_dtype == 'fp8_inc' and os.environ.get('QUANT_CONFIG', None) is None
    self.matmul_qk = Matmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.softmax = Softmax()
    self.matmul_av = Matmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.batch2block_matmul = Matmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.block2batch_matmul = Matmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.latent_cache_k = VLLMKVCache() if not self.enable_fp8_attn \
        else VLLMFP8KVCache()
    self.fused_scaled_dot_product_attention = kernels.fsdpa()
    self.use_merged_prefill = get_config().merged_prefill
    self.prefill_impl = get_config().prompt_attn_impl
    assert self.prefill_impl != 'fsdpa_impl' or alibi_slopes is None, \
        'Prefill with FusedSDPA not supported with alibi slopes!'
    self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

    unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
    if any(unsupported_features):
        raise NotImplementedError("HPUMLAImpl does not support one of the following: "
                                  "alibi_slopes, sliding_window, "
                                  "logits_soft_cap")

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "TritonMLAImpl")

_forward_decode

_forward_decode(
    q_nope: Tensor,
    q_pe: Tensor,
    k_cache: Tensor,
    attn_metadata: HPUAttentionMetadata,
) -> Tensor
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def _forward_decode(  # type: ignore
        self, q_nope: torch.Tensor, q_pe: torch.Tensor, k_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata) -> torch.Tensor:
    query = torch.cat([q_nope, q_pe], dim=-1)
    key_cache = k_cache.unsqueeze(1)
    value_cache = None
    output = HPUPagedAttention.forward_decode(query=query,
                                              key_cache=key_cache,
                                              value_cache=value_cache,
                                              block_list=attn_metadata.block_list,
                                              block_mapping=attn_metadata.block_mapping,
                                              block_bias=attn_metadata.attn_bias,
                                              block_groups=attn_metadata.block_groups,
                                              block_size=attn_metadata.block_size,
                                              scale=self.scale,
                                              matmul_qk_op=self.matmul_qk,
                                              matmul_av_op=self.matmul_av,
                                              batch2block_matmul_op=self.batch2block_matmul,
                                              block2batch_matmul_op=self.block2batch_matmul,
                                              keys_fetch_func=self.latent_cache_k.fetch_from_cache,
                                              values_fetch_func=None,
                                              kv_lora_rank=self.kv_lora_rank)
    result = self._v_up_proj(output)
    return result

_forward_prefill

_forward_prefill(
    q: Tensor,
    latent_vec_k: Tensor,
    k_cache: Tensor,
    attn_metadata: HPUAttentionMetadata,
) -> Tensor
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def _forward_prefill(  # type: ignore
        self, q: torch.Tensor, latent_vec_k: torch.Tensor, k_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata) -> torch.Tensor:

    ##### get prefix cache #####
    if attn_metadata.block_list is not None:
        current = latent_vec_k
        past = self.latent_cache_k.fetch_from_cache(k_cache.unflatten(0, (-1, attn_metadata.block_size)),
                                                    attn_metadata.block_list)
        past = past.view(-1, past.shape[-1])
        current = torch.concat((past, current), dim=0)
        latent_vec_k = current
    # =========================== #

    k_c_normed, k_pe = latent_vec_k.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

    kv_nope = self.kv_b_proj(k_c_normed)[0]\
        .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = kv_nope\
        .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

    if not self.use_merged_prefill:
        assert attn_metadata.seq_lens_tensor is not None, \
            "seq_lens_tensor must be provided for prefill attention"
        batch_size = attn_metadata.seq_lens_tensor.shape[0]
    else:
        batch_size = 1
    q = q.view(batch_size, -1, self.num_heads, self.qk_head_dim)
    k = k.view(batch_size, -1, self.num_heads, self.qk_head_dim)
    v = v.view(batch_size, -1, self.num_heads, self.v_head_dim)

    to_pad = self.qk_head_dim - self.v_head_dim
    if to_pad > 0:
        v_padding = torch.zeros(*v.shape[:-1], q.shape[-1] - v.shape[-1], device=v.device, dtype=v.dtype)
        v_padded = torch.cat((v, v_padding), dim=-1)
    else:
        v_padded = v

    output = ops.prompt_attention(
        impl=self.prefill_impl,
        query=q,
        key=k,
        value=v_padded,
        is_causal=True,
        attn_bias=attn_metadata.attn_bias,
        position_bias=None,
        valid_seq_lengths=attn_metadata.seq_lens_tensor,
        scale=self.scale,
        matmul_qk_op=self.matmul_qk,
        softmax_op=self.softmax,
        matmul_av_op=self.matmul_av,
        keys_fetch_func=self.latent_cache_k.fetch_from_cache,
        values_fetch_func=None,
        fsdpa_op=self.fused_scaled_dot_product_attention.apply \
        if self.fused_scaled_dot_product_attention is not None else None)
    # remove padding
    output = output.view(batch_size, -1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]

    return output.reshape(-1, self.num_heads * v.shape[-1])
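
The zero-padding of V at the end of _forward_prefill widens v_head_dim up to qk_head_dim so that query, key and value share one head size for the attention call; the padded columns stay zero in the output and are sliced away afterwards. A hedged CPU sketch with toy sizes, using torch.nn.functional.scaled_dot_product_attention as a stand-in for ops.prompt_attention (no KV cache, SDPA layout (batch, heads, seq, dim)):

import torch
import torch.nn.functional as F

batch, heads, seq = 1, 2, 6
qk_head_dim, v_head_dim = 24, 16  # toy values, qk_head_dim > v_head_dim

q = torch.randn(batch, heads, seq, qk_head_dim)
k = torch.randn(batch, heads, seq, qk_head_dim)
v = torch.randn(batch, heads, seq, v_head_dim)

# Zero-pad V so all three tensors share the same head dim.
v_padded = torch.cat(
    (v, torch.zeros(batch, heads, seq, qk_head_dim - v_head_dim)), dim=-1)

out = F.scaled_dot_product_attention(q, k, v_padded, is_causal=True)

# The padded output columns are exact zeros, so slicing recovers the result.
out = out[..., :v_head_dim]
print(out.shape)  # torch.Size([1, 2, 6, 16])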

_v_up_proj

_v_up_proj(x)
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def _v_up_proj(self, x):
    # Convert from (B, N, L) to (N, B, L)
    x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
    # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
    x = torch.bmm(x, self.W_UV)
    # Convert from (N, B, V) to (B, N * V)
    x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
    return x
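
The batched matmul in _v_up_proj runs once per attention head. A short shape sketch with toy sizes (W_UV here is a random stand-in for the real up-projection weight):

import torch

batch, num_heads = 3, 4
kv_lora_rank, v_head_dim = 16, 8  # toy values

W_UV = torch.randn(num_heads, kv_lora_rank, v_head_dim)  # stand-in weight
x = torch.randn(batch, num_heads, kv_lora_rank)

x = x.view(-1, num_heads, kv_lora_rank).transpose(0, 1)    # (N, B, L)
x = torch.bmm(x, W_UV)                                     # (N, B, V)
x = x.transpose(0, 1).reshape(-1, num_heads * v_head_dim)  # (B, N * V)
print(x.shape)  # torch.Size([3, 32])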

forward

forward(
    layer: AttentionLayer,
    q: Tensor,
    k_c_normed: Tensor,
    k_pe: Tensor,
    kv_cache: Tensor,
    attn_metadata: HPUAttentionMetadata,
    output: Optional[Tensor] = None,
) -> Tensor
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def forward(
    self,
    layer: AttentionLayer,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: HPUAttentionMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if output is not None:
        raise NotImplementedError("output is not yet supported for MLAImplBase")

    is_prefill = attn_metadata.is_prompt

    if not is_prefill:
        # decode
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        decode_ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        decode_ql_nope = decode_ql_nope.transpose(0, 1)

    slot_mapping = attn_metadata.slot_mapping.flatten() if attn_metadata.slot_mapping is not None else None

    latent_vec_k = torch.concat((k_c_normed, k_pe.view(*k_c_normed.shape[:-1], self.qk_rope_head_dim)), dim=-1)
    latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)

    # write the latent and rope to kv cache
    if kv_cache is not None and len(kv_cache) == 2:
        self.latent_cache_k(latent_vec_k, kv_cache[0], slot_mapping)
        k_cache = kv_cache[0]

    if is_prefill:
        return self._forward_prefill(q, latent_vec_k, k_cache, attn_metadata)
    else:
        return self._forward_decode(decode_ql_nope, q_pe, k_cache, attn_metadata)
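
Two tensor manipulations in forward are worth isolating: the decode-time absorption of the key up-projection into the query (the bmm with W_UK_T) and the packing of the compressed KV plus its rotary part into one latent row per token. The sketch below reproduces both with toy dimensions and random stand-in weights; it is illustrative only and does not touch the HPU cache ops:

import torch

num_tokens, num_heads = 5, 4
qk_nope_head_dim, qk_rope_head_dim, kv_lora_rank = 12, 8, 16  # toy sizes

# Decode path: fold W_UK into the query so attention runs in latent space.
W_UK_T = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)  # stand-in weight
q = torch.randn(num_tokens, num_heads, qk_nope_head_dim + qk_rope_head_dim)
q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
decode_ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T).transpose(0, 1)
print(decode_ql_nope.shape)  # torch.Size([5, 4, 16])

# Both paths: one latent row per token is written to the single KV cache tensor.
k_c_normed = torch.randn(num_tokens, kv_lora_rank)
k_pe = torch.randn(num_tokens, 1, qk_rope_head_dim)
latent_vec_k = torch.concat(
    (k_c_normed, k_pe.view(*k_c_normed.shape[:-1], qk_rope_head_dim)), dim=-1)
latent_vec_k = latent_vec_k.view(-1, qk_rope_head_dim + kv_lora_rank)
print(latent_vec_k.shape)  # torch.Size([5, 24])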

HPUMLAMetadata dataclass

Bases: HPUAttentionMetadata, AttentionMetadata

Source code in vllm_gaudi/attention/backends/hpu_attn.py
@dataclass
class HPUMLAMetadata(HPUAttentionMetadata, AttentionMetadata):
    pass

__init__

__init__(
    block_list: Optional[Tensor],
    block_mapping: Optional[Tensor],
    block_usage: Optional[Tensor],
    block_groups: Optional[Tensor],
    alibi_blocks: Optional[Tensor],
    is_prompt: bool,
    block_size: int,
    slot_mapping: Tensor,
    attn_bias: Optional[Tensor],
    seq_lens_tensor: Optional[Tensor],
    context_lens_tensor: Optional[Tensor],
    input_positions: Tensor,
    seq_lens: Optional[list[int]] = None,
    encoder_seq_lens: Optional[list[int]] = None,
    encoder_seq_lens_tensor: Optional[Tensor] = None,
    max_encoder_seq_len: Optional[int] = None,
    cross_block_list: Optional[Tensor] = None,
    cross_slot_mapping: Optional[Tensor] = None,
    cross_block_mapping: Optional[Tensor] = None,
    cross_block_groups: Optional[Tensor] = None,
    cross_block_usage: Optional[Tensor] = None,
    cross_attn_bias: Optional[Tensor] = None,
    window_block_list: Optional[Tensor] = None,
    window_slot_mapping: Optional[Tensor] = None,
    window_block_mapping: Optional[Tensor] = None,
    window_block_groups: Optional[Tensor] = None,
    window_block_usage: Optional[Tensor] = None,
    window_attn_bias: Optional[Tensor] = None,
    chunked_slot_mapping: Optional[Tensor] = None,
    chunked_attn_bias: Optional[Tensor] = None,
    chunked_block_mapping: Optional[Tensor] = None,
    chunked_block_list: Optional[Tensor] = None,
    chunked_block_groups: Optional[Tensor] = None,
    chunked_block_usage: Optional[Tensor] = None,
) -> None

HPUUnifiedAttentionBackend

Bases: HPUAttentionBackend

Source code in vllm_gaudi/attention/backends/hpu_attn.py
@register_backend(AttentionBackendEnum.CUSTOM, "HPU_UA")
class HPUUnifiedAttentionBackend(HPUAttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "CUSTOM"

    @staticmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        return HPUUnifiedAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return HPUUnifiedAttentionMetadata

get_impl_cls staticmethod

get_impl_cls() -> type[AttentionImpl]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
    return HPUUnifiedAttentionImpl

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    return HPUUnifiedAttentionMetadata

get_name staticmethod

get_name() -> str
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_name() -> str:
    return "CUSTOM"

HPUUnifiedAttentionImpl

Bases: AttentionImpl, Module

Source code in vllm_gaudi/attention/backends/hpu_attn.py
class HPUUnifiedAttentionImpl(AttentionImpl, torch.nn.Module):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        super(AttentionImpl, self).__init__()

        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(f"Head size {head_size} is not supported by PagedAttention. "
                             f"Supported head sizes are: {supported_head_sizes}.")

        unsupported_features = {
            'KV sharing': kv_sharing_target_layer_name is not None,
            'Alibi': alibi_slopes is not None,
            'Sliding window': sliding_window is not None,
            'non-GQA attention': num_kv_heads is None,
            'Encoder attn': attn_type != AttentionType.DECODER,
            'fp32 softmax': get_config().fp32_softmax,
        }
        for feature, check in unsupported_features.items():
            if check:
                raise NotImplementedError(feature + ' is not implemented for HPU unified attn')

        if use_irope:
            logger.warning_once("Using irope in HPU is not supported yet, it will fall back "
                                "to global attention for long context.")
        self.enable_fp8_attn = kv_cache_dtype == 'fp8_inc' and os.environ.get('QUANT_CONFIG', None) is None
        self.kv_cache_dtype = kv_cache_dtype
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.k_cache = VLLMKVCache() if not self.enable_fp8_attn \
            else VLLMFP8KVCache()
        self.v_cache = VLLMKVCache() if not self.enable_fp8_attn \
            else VLLMFP8KVCache()

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor, torch.Tensor],
        attn_metadata: HPUUnifiedAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        key_cache, value_cache = kv_cache
        query_shape = query.shape
        if query.dim() == 3:
            query = query.flatten(0, 1)
            key = key.flatten(0, 1)
            value = value.flatten(0, 1)
        query = query.unflatten(-1, (-1, self.head_size))
        key = key.unflatten(-1, (-1, self.head_size))
        value = value.unflatten(-1, (-1, self.head_size))
        key_cache = self.k_cache(key, key_cache, attn_metadata.slot_mapping)
        value_cache = self.v_cache(value, value_cache, attn_metadata.slot_mapping)
        output = unified_attn(
            query=query,
            key=key,
            value=value,
            key_cache=key_cache,
            value_cache=value_cache,
            scale=self.scale,
            metadata=attn_metadata,
        )
        output = output.unflatten(0, (query_shape[0], query_shape[1])).flatten(-2, -1)
        return output

enable_fp8_attn instance-attribute

enable_fp8_attn = (
    kv_cache_dtype == "fp8_inc"
    and get("QUANT_CONFIG", None) is None
)

head_size instance-attribute

head_size = head_size

k_cache instance-attribute

k_cache = (
    VLLMKVCache()
    if not enable_fp8_attn
    else VLLMFP8KVCache()
)

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = float(scale)

v_cache instance-attribute

v_cache = (
    VLLMKVCache()
    if not enable_fp8_attn
    else VLLMFP8KVCache()
)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
) -> None:
    super(AttentionImpl, self).__init__()

    supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
    if head_size not in supported_head_sizes:
        raise ValueError(f"Head size {head_size} is not supported by PagedAttention. "
                         f"Supported head sizes are: {supported_head_sizes}.")

    unsupported_features = {
        'KV sharing': kv_sharing_target_layer_name is not None,
        'Alibi': alibi_slopes is not None,
        'Sliding window': sliding_window is not None,
        'non-GQA attention': num_kv_heads is None,
        'Encoder attn': attn_type != AttentionType.DECODER,
        'fp32 softmax': get_config().fp32_softmax,
    }
    for feature, check in unsupported_features.items():
        if check:
            raise NotImplementedError(feature + ' is not implemented for HPU unified attn')

    if use_irope:
        logger.warning_once("Using irope in HPU is not supported yet, it will fall back "
                            "to global attention for long context.")
    self.enable_fp8_attn = kv_cache_dtype == 'fp8_inc' and os.environ.get('QUANT_CONFIG', None) is None
    self.kv_cache_dtype = kv_cache_dtype
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    self.k_cache = VLLMKVCache() if not self.enable_fp8_attn \
        else VLLMFP8KVCache()
    self.v_cache = VLLMKVCache() if not self.enable_fp8_attn \
        else VLLMFP8KVCache()

forward

forward(
    layer: AttentionLayer,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: tuple[Tensor, Tensor],
    attn_metadata: HPUUnifiedAttentionMetadata,
    output: Optional[Tensor] = None,
) -> Tensor
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: tuple[torch.Tensor, torch.Tensor],
    attn_metadata: HPUUnifiedAttentionMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    key_cache, value_cache = kv_cache
    query_shape = query.shape
    if query.dim() == 3:
        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)
    query = query.unflatten(-1, (-1, self.head_size))
    key = key.unflatten(-1, (-1, self.head_size))
    value = value.unflatten(-1, (-1, self.head_size))
    key_cache = self.k_cache(key, key_cache, attn_metadata.slot_mapping)
    value_cache = self.v_cache(value, value_cache, attn_metadata.slot_mapping)
    output = unified_attn(
        query=query,
        key=key,
        value=value,
        key_cache=key_cache,
        value_cache=value_cache,
        scale=self.scale,
        metadata=attn_metadata,
    )
    output = output.unflatten(0, (query_shape[0], query_shape[1])).flatten(-2, -1)
    return output
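
The reshaping around unified_attn is the easiest part of this forward to misread: when projections arrive as (batch, seq, num_heads * head_size) they are flattened to per-token rows, given an explicit head axis, and the output is folded back into the original layout. A hedged sketch with toy sizes; an identity pass-through stands in for unified_attn, which is HPU-specific:

import torch

batch, seq, num_heads, head_size = 2, 3, 4, 8  # toy sizes
query = torch.randn(batch, seq, num_heads * head_size)

query_shape = query.shape
q = query.flatten(0, 1)               # (batch*seq, num_heads*head_size)
q = q.unflatten(-1, (-1, head_size))  # (batch*seq, num_heads, head_size)

out = q  # placeholder for unified_attn(...), which keeps this layout

out = out.unflatten(0, (query_shape[0], query_shape[1])).flatten(-2, -1)
print(out.shape)  # torch.Size([2, 3, 32]), back to (batch, seq, hidden)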

_make_decode_alibi_bias

_make_decode_alibi_bias(
    alibi_blocks: Tensor, alibi_slopes: Tensor, dtype: dtype
) -> Tensor

Create the ALiBi position bias tensor for the decode stage. Uses the stored alibi_blocks and slopes for the final scaling. Scales with the number of blocks rather than with the batch size.

Parameters:

Name           Type     Description                            Default
alibi_blocks   Tensor   shape = [num_blocks, block_size]       required
alibi_slopes   Tensor   shape = [num_heads]                    required
dtype          dtype    torch.dtype                            required

Returns:

Type     Description
Tensor   A per-head bias tensor of shape [num_blocks, num_heads, block_size].
         Each row encodes position-dependent ALiBi slopes for decoding steps.

Source code in vllm_gaudi/attention/backends/hpu_attn.py
def _make_decode_alibi_bias(
    alibi_blocks: torch.Tensor,
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Create the ALiBi position bias tensor for decode stage.
    Uses stored alibi_blocks and slopes for final scaling.
    Scales with number of blocks, not with batch size.

    Args:
        alibi_blocks: shape = [num_blocks, block_size]
        alibi_slopes: shape = [num_heads]
        dtype: torch.dtype

    Returns:
        A per-head bias tensor of shape [num_blocks, num_heads, block_size].
        Each row encodes position-dependent ALiBi slopes for decoding steps.
    """
    num_heads = alibi_slopes.shape[0]
    per_head_bias = torch.empty(
        alibi_blocks.size(0),
        num_heads,
        alibi_blocks.size(-1),
        device=alibi_slopes.device,
        dtype=dtype,
    )
    # NOTE(Tanner):
    # .copy_ was not performing broadcasting of bias
    # to all 32 heads in Eager mode.
    per_head_bias[:, :] = alibi_blocks.unsqueeze(-2)
    per_head_bias.mul_(alibi_slopes[None, :, None])

    return per_head_bias
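
A small numeric check of the broadcasting above (toy block contents chosen for readability):

import torch

alibi_blocks = torch.tensor([[0., -1., -2., -3.],
                             [-4., -5., -6., -7.]])  # [num_blocks=2, block_size=4]
alibi_slopes = torch.tensor([0.5, 0.25])             # [num_heads=2]

per_head_bias = torch.empty(2, 2, 4)
per_head_bias[:, :] = alibi_blocks.unsqueeze(-2)     # broadcast each block over all heads
per_head_bias.mul_(alibi_slopes[None, :, None])      # scale every head by its own slope

print(per_head_bias[0])
# tensor([[ 0.0000, -0.5000, -1.0000, -1.5000],
#         [ 0.0000, -0.2500, -0.5000, -0.7500]])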

_make_prompt_alibi_bias

_make_prompt_alibi_bias(
    alibi_slopes: Tensor, seq_len: int, dtype: dtype
) -> Tensor

Create the ALiBi position bias tensor for the prompt stage. This tensor is reused or tiled as needed for each forward pass. It does not scale with batch size or number of blocks.

Parameters:

Name           Type     Description            Default
alibi_slopes   Tensor   shape = [num_heads]    required
seq_len        int      sequence length        required
dtype          dtype    torch.dtype            required

Returns:

Type     Description
Tensor   A per-head bias tensor of shape [1, num_heads, seq_len, seq_len].
         This bias encodes positional information via ALiBi slopes.

Source code in vllm_gaudi/attention/backends/hpu_attn.py
def _make_prompt_alibi_bias(
    alibi_slopes: torch.Tensor,
    seq_len: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Create the ALiBi position bias tensor for prompt stage.
    This tensor is reused or tiled as needed for each forward pass.
    Does not scale with batch size or number of blocks.

    Args:
        alibi_slopes: shape = [num_heads]
        seq_len: int
        dtype: torch.dtype

    Returns:
        A per-head bias tensor of shape [1, num_heads, seq_len, seq_len].
        This bias encodes positional information via ALiBi slopes.
    """
    # Create the bias matrix for positional differences
    bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
    bias = bias[None, :] - bias[:, None]  # Shape: [seq_len, seq_len]

    #padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    per_head_bias = torch.empty(
        1,
        num_heads,
        seq_len,
        seq_len,  # Directly use seq_len instead of padded_len
        device=alibi_slopes.device,
        dtype=dtype,
    )

    # Copy the bias matrix into each head
    per_head_bias[:, :] = bias

    # Scale the bias by the ALiBi slopes
    per_head_bias.mul_(alibi_slopes[:, None, None])

    return per_head_bias
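
A matching toy run of the prompt-stage bias (seq_len=4, two heads). Entry (i, j) is (j - i) scaled by the head's slope; with is_causal=True in _forward_prefill only the non-positive lower-triangular part contributes:

import torch

seq_len = 4
alibi_slopes = torch.tensor([0.5, 0.25])  # [num_heads=2]

bias = torch.arange(seq_len, dtype=torch.float32)
bias = bias[None, :] - bias[:, None]      # entry (i, j) = j - i

per_head_bias = torch.empty(1, 2, seq_len, seq_len)
per_head_bias[:, :] = bias
per_head_bias.mul_(alibi_slopes[:, None, None])

print(per_head_bias[0, 0])
# tensor([[ 0.0000,  0.5000,  1.0000,  1.5000],
#         [-0.5000,  0.0000,  0.5000,  1.0000],
#         [-1.0000, -0.5000,  0.0000,  0.5000],
#         [-1.5000, -1.0000, -0.5000,  0.0000]])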