跳到内容

vllm_gaudi.ops.hpu_multihead_attn

HpuMultiHeadAttention

基类: MultiHeadAttention

源代码位于 `vllm_gaudi/ops/hpu_multihead_attn.py`
class HpuMultiHeadAttention(MultiHeadAttention):
    """HPU flavour of ``MultiHeadAttention``.

    Only ``forward`` is overridden. The attention product is computed by the
    HPU fused SDPA kernel when the runtime config's ``prompt_attn_impl`` is
    ``'fsdpa_impl'``; otherwise ``torch.nn.functional.scaled_dot_product_attention``
    is used.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Run non-causal multi-head attention of ``query`` over ``key``/``value``.

        Accepted input layouts:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        Returns:
            A tensor of shape (batch_size, query_seq_len, -1): the per-head
            outputs merged back into a single trailing dimension.
        """
        batch, q_seq = query.shape[:2]
        kv_seq = key.shape[1]

        # Split the trailing dimension into heads -> (B, S, H, D).
        q_heads = query.view(batch, q_seq, self.num_heads, self.head_size)
        k_heads = key.view(batch, kv_seq, self.num_kv_heads, self.head_size)
        v_heads = value.view(batch, kv_seq, self.num_kv_heads, self.head_size)

        group_size = self.num_queries_per_kv
        if group_size > 1:
            # MQA / GQA: replicate every KV head along the head axis so each
            # query head has a matching key/value head.
            k_heads = torch.repeat_interleave(k_heads, group_size, dim=2)
            v_heads = torch.repeat_interleave(v_heads, group_size, dim=2)

        # Both attention paths below consume (B, H, S, D).
        q_heads = q_heads.transpose(1, 2)
        k_heads = k_heads.transpose(1, 2)
        v_heads = v_heads.transpose(1, 2)

        from vllm_gaudi.extension.runtime import get_config

        use_fused = get_config().prompt_attn_impl == 'fsdpa_impl'
        if not use_fused:
            attn = F.scaled_dot_product_attention(q_heads, k_heads, v_heads, scale=self.scale)
        else:
            # Imported lazily: only needed when the fused HPU path is selected.
            from vllm_gaudi.extension.utils import ModuleFusedSDPA
            import vllm_gaudi.extension.kernels as kernels

            fused_sdpa = ModuleFusedSDPA(kernels.fsdpa())
            attn = fused_sdpa(
                q_heads,
                k_heads,
                v_heads,
                None,  # presumably the attn_mask slot; confirm against ModuleFusedSDPA
                dropout_p=0.0,
                is_causal=False,
                scale=self.scale,
                softmax_mode="fast",
                recompute_mode=True,
                valid_sequence_lengths=None,
            )

        # (B, H, S, D) -> (B, S, H, D), then merge heads into one dimension.
        return attn.transpose(1, 2).reshape(batch, q_seq, -1)

forward

forward(
    query: Tensor, key: Tensor, value: Tensor
) -> Tensor

输入形状：`(batch_size x seq_len x hidden_size)` 或 `(batch_size x seq_len x num_heads x head_size)`

源代码位于 `vllm_gaudi/ops/hpu_multihead_attn.py`
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
) -> torch.Tensor:
    """Compute non-causal multi-head attention on HPU or via torch SDPA.

    Accepted input layouts:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)

    The fused HPU kernel is used when the runtime config's
    ``prompt_attn_impl`` equals ``'fsdpa_impl'``; otherwise
    ``F.scaled_dot_product_attention`` is used.

    Returns:
        A tensor of shape (batch_size, query_seq_len, -1).
    """
    batch_size, query_len = query.shape[:2]
    key_len = key.shape[1]

    def to_heads(t, seq_len, heads):
        # (B, S, ...) -> (B, S, heads, head_size)
        return t.view(batch_size, seq_len, heads, self.head_size)

    q = to_heads(query, query_len, self.num_heads)
    k = to_heads(key, key_len, self.num_kv_heads)
    v = to_heads(value, key_len, self.num_kv_heads)

    repeats = self.num_queries_per_kv
    if repeats > 1:
        # MQA / GQA: expand KV heads along the head axis to match query heads.
        k = torch.repeat_interleave(k, repeats, dim=2)
        v = torch.repeat_interleave(v, repeats, dim=2)

    # (B, S, H, D) -> (B, H, S, D), the layout both attention paths consume.
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

    from vllm_gaudi.extension.runtime import get_config

    if get_config().prompt_attn_impl != 'fsdpa_impl':
        attn_out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
    else:
        # Lazy imports: only pulled in when the fused HPU path is selected.
        from vllm_gaudi.extension.utils import ModuleFusedSDPA
        import vllm_gaudi.extension.kernels as kernels

        fused_op = ModuleFusedSDPA(kernels.fsdpa())
        attn_out = fused_op(q,
                            k,
                            v,
                            None,  # presumably attn_mask; confirm against ModuleFusedSDPA
                            dropout_p=0.0,
                            is_causal=False,
                            scale=self.scale,
                            softmax_mode="fast",
                            recompute_mode=True,
                            valid_sequence_lengths=None)

    # Back to (B, S, H, D), then merge heads into one trailing dimension.
    return attn_out.transpose(1, 2).reshape(batch_size, query_len, -1)