vllm_gaudi.extension.utils

FP8Matmul

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class FP8Matmul(torch.nn.Module):

    def __init__(
        self,
        scale_input=1.0,
        scale_other=1.0,
    ):
        super().__init__()
        self.scale_input = scale_input
        self.scale_other = scale_other

    def quant_input(self, x, scale):
        return torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)[0]

    def matmul_fp8(self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None):
        return torch.ops.hpu.fp8_gemm_v2(
            A=x,
            trans_A=False,
            B=other,
            trans_B=False,
            D=None,
            out_dtype=out_dtype,
            A_scale_inv=scale_input_inv,
            B_scale_inv=scale_other_inv,
            bias=None,
            accumulate=False,
        )

    def forward(self, input, other):
        qinput = self.quant_input(input, self.scale_input)
        qother = self.quant_input(other, self.scale_other)
        output = self.matmul_fp8(
            qinput,
            qother,
            out_dtype=torch.bfloat16,
            scale_input_inv=1.0 / self.scale_input,
            scale_other_inv=1.0 / self.scale_other,
        )
        return output

scale_input instance-attribute

scale_input = scale_input

scale_other instance-attribute

scale_other = scale_other

__init__

__init__(scale_input=1.0, scale_other=1.0)
Source code in vllm_gaudi/extension/utils.py
def __init__(
    self,
    scale_input=1.0,
    scale_other=1.0,
):
    super().__init__()
    self.scale_input = scale_input
    self.scale_other = scale_other

forward

forward(input, other)
Source code in vllm_gaudi/extension/utils.py
def forward(self, input, other):
    qinput = self.quant_input(input, self.scale_input)
    qother = self.quant_input(other, self.scale_other)
    output = self.matmul_fp8(
        qinput,
        qother,
        out_dtype=torch.bfloat16,
        scale_input_inv=1.0 / self.scale_input,
        scale_other_inv=1.0 / self.scale_other,
    )
    return output

matmul_fp8

matmul_fp8(
    x,
    other,
    out_dtype,
    scale_input_inv=None,
    scale_other_inv=None,
)
Source code in vllm_gaudi/extension/utils.py
def matmul_fp8(self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None):
    return torch.ops.hpu.fp8_gemm_v2(
        A=x,
        trans_A=False,
        B=other,
        trans_B=False,
        D=None,
        out_dtype=out_dtype,
        A_scale_inv=scale_input_inv,
        B_scale_inv=scale_other_inv,
        bias=None,
        accumulate=False,
    )

quant_input

quant_input(x, scale)
Source code in vllm_gaudi/extension/utils.py
def quant_input(self, x, scale):
    return torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)[0]
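
A minimal usage sketch of FP8Matmul, not part of the library source: it assumes an Intel Gaudi device with the habana_frameworks PyTorch bridge installed, since both the cast and the GEMM go through torch.ops.hpu kernels, and it uses placeholder scales (real deployments would derive them from calibration).

# Hypothetical example: requires habana_frameworks (Intel Gaudi) for torch.ops.hpu.
import torch
import habana_frameworks.torch  # noqa: F401  (registers the 'hpu' device and ops)
from vllm_gaudi.extension.utils import FP8Matmul

x = torch.randn(16, 64, dtype=torch.bfloat16, device="hpu")
w = torch.randn(64, 32, dtype=torch.bfloat16, device="hpu")

mm = FP8Matmul(scale_input=1.0, scale_other=1.0)  # placeholder scales
out = mm(x, w)  # operands cast to float8_e4m3fn, GEMM runs in FP8, result is bfloat16
print(out.shape)  # torch.Size([16, 32])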

Matmul

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class Matmul(torch.nn.Module):

    def __init__(self):
        super(Matmul, self).__init__()

    def forward(self, x, y, **kwargs):
        return torch.matmul(x, y, **kwargs)

__init__

__init__()
Source code in vllm_gaudi/extension/utils.py
def __init__(self):
    super(Matmul, self).__init__()

forward

forward(x, y, **kwargs)
Source code in vllm_gaudi/extension/utils.py
def forward(self, x, y, **kwargs):
    return torch.matmul(x, y, **kwargs)
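
Matmul is a thin module wrapper around torch.matmul, presumably so matmul call sites can be swapped or patched at module granularity. A quick illustrative example (not from the library source):

import torch
from vllm_gaudi.extension.utils import Matmul

mm = Matmul()
a = torch.randn(4, 8)
b = torch.randn(8, 2)
print(mm(a, b).shape)  # torch.Size([4, 2]), same as torch.matmul(a, b).shape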

ModuleFusedSDPA

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class ModuleFusedSDPA(torch.nn.Module):

    def __init__(self, fusedSDPA):
        super().__init__()
        assert fusedSDPA is not None, f'fusedSDPA kernel is None'
        self._hpu_kernel_fsdpa = fusedSDPA

    def forward(
        self,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        softmax_mode,
        recompute_mode,
        valid_sequence_lengths,
        padding_side="left",
        window_size=None,
    ):
        if window_size is not None:
            return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
                                                recompute_mode, valid_sequence_lengths, padding_side, False, False,
                                                window_size)
        else:
            return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
                                                recompute_mode, valid_sequence_lengths, padding_side)

_hpu_kernel_fsdpa instance-attribute

_hpu_kernel_fsdpa = fusedSDPA

__init__

__init__(fusedSDPA)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, fusedSDPA):
    super().__init__()
    assert fusedSDPA is not None, f'fusedSDPA kernel is None'
    self._hpu_kernel_fsdpa = fusedSDPA

forward

forward(
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
    recompute_mode,
    valid_sequence_lengths,
    padding_side="left",
    window_size=None,
)
Source code in vllm_gaudi/extension/utils.py
def forward(
    self,
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
    recompute_mode,
    valid_sequence_lengths,
    padding_side="left",
    window_size=None,
):
    if window_size is not None:
        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
                                            recompute_mode, valid_sequence_lengths, padding_side, False, False,
                                            window_size)
    else:
        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
                                            recompute_mode, valid_sequence_lengths, padding_side)
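
A hypothetical usage sketch for ModuleFusedSDPA, assuming an HPU and Habana's fused SDPA kernel (habana_frameworks.torch.hpex.kernels.FusedSDPA); the argument values shown are illustrative only, not recommended settings.

# Illustrative only: requires an Intel Gaudi device and habana_frameworks.
import torch
import habana_frameworks.torch  # noqa: F401
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_gaudi.extension.utils import ModuleFusedSDPA

sdpa = ModuleFusedSDPA(FusedSDPA)
q = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16, device="hpu")
k = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16, device="hpu")
v = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16, device="hpu")

out = sdpa(q, k, v,
           attn_mask=None, dropout_p=0.0, is_causal=True, scale=None,
           softmax_mode="None", recompute_mode=False,
           valid_sequence_lengths=None, padding_side="left")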

Softmax

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class Softmax(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x, dim=None, inv_head=None):
        return torch.softmax(x, dim)

__init__

__init__()
Source code in vllm_gaudi/extension/utils.py
def __init__(self):
    super().__init__()

forward

forward(x, dim=None, inv_head=None)
Source code in vllm_gaudi/extension/utils.py
def forward(self, x, dim=None, inv_head=None):
    return torch.softmax(x, dim)
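
The inv_head argument is accepted but unused here, presumably so patched variants can share the same signature. A small CPU-only example (illustrative):

import torch
from vllm_gaudi.extension.utils import Softmax

sm = Softmax()
logits = torch.randn(2, 5)
probs = sm(logits, dim=-1)
print(probs.sum(dim=-1))  # each row sums to ~1.0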

VLLMFP8KVCache

Bases: VLLMKVCache

Source code in vllm_gaudi/extension/utils.py
class VLLMFP8KVCache(VLLMKVCache):

    def __init__(self, input_scale=1.0):
        super(VLLMKVCache, self).__init__()
        self.use_contiguous_pa = get_config().use_contiguous_pa
        self.input_scale = input_scale
        self.output_scale = 1.0 / self.input_scale

    def quant_input(self, input):
        return torch.ops.hpu.cast_to_fp8_v2(input, self.input_scale, False, False, torch.float8_e4m3fn)[0]

    def dequant_output(self, output):
        return torch.ops.hpu.cast_from_fp8(output, self.output_scale, torch.bfloat16)

    def forward(self, input, *args, **kwargs):
        qinput = self.quant_input(input)
        return super().forward(qinput, *args, **kwargs)

    def fetch_from_cache(self, quant_cache, blocks, permutations=None):
        if permutations:
            output_cache = super().fetch_from_cache(quant_cache, blocks, permutations)
            for i in range(len(output_cache)):
                output_cache[i] = self.dequant_output(output_cache[i])
            return output_cache
        output_cache = super().fetch_from_cache(quant_cache, blocks)
        return self.dequant_output(output_cache)

input_scale instance-attribute

input_scale = input_scale

output_scale instance-attribute

output_scale = 1.0 / input_scale

use_contiguous_pa instance-attribute

use_contiguous_pa = use_contiguous_pa

__init__

__init__(input_scale=1.0)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, input_scale=1.0):
    super(VLLMKVCache, self).__init__()
    self.use_contiguous_pa = get_config().use_contiguous_pa
    self.input_scale = input_scale
    self.output_scale = 1.0 / self.input_scale

dequant_output

dequant_output(output)
Source code in vllm_gaudi/extension/utils.py
def dequant_output(self, output):
    return torch.ops.hpu.cast_from_fp8(output, self.output_scale, torch.bfloat16)

fetch_from_cache

fetch_from_cache(quant_cache, blocks, permutations=None)
Source code in vllm_gaudi/extension/utils.py
def fetch_from_cache(self, quant_cache, blocks, permutations=None):
    if permutations:
        output_cache = super().fetch_from_cache(quant_cache, blocks, permutations)
        for i in range(len(output_cache)):
            output_cache[i] = self.dequant_output(output_cache[i])
        return output_cache
    output_cache = super().fetch_from_cache(quant_cache, blocks)
    return self.dequant_output(output_cache)

forward

forward(input, *args, **kwargs)
Source code in vllm_gaudi/extension/utils.py
def forward(self, input, *args, **kwargs):
    qinput = self.quant_input(input)
    return super().forward(qinput, *args, **kwargs)

quant_input

quant_input(input)
Source code in vllm_gaudi/extension/utils.py
def quant_input(self, input):
    return torch.ops.hpu.cast_to_fp8_v2(input, self.input_scale, False, False, torch.float8_e4m3fn)[0]
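
A hypothetical sketch of storing and fetching FP8-quantized entries with VLLMFP8KVCache; it assumes an HPU and that the vllm_gaudi extension config has been initialized, since the constructor reads get_config().use_contiguous_pa. Shapes and the scale are placeholders.

# Illustrative only: requires habana_frameworks and an initialized extension config.
import torch
import habana_frameworks.torch  # noqa: F401
from vllm_gaudi.extension.utils import VLLMFP8KVCache

kv = VLLMFP8KVCache(input_scale=1.0)  # placeholder scale
cache = torch.zeros(32, 16, 8, dtype=torch.float8_e4m3fn, device="hpu")
keys = torch.randn(4, 16, 8, dtype=torch.bfloat16, device="hpu")
slots = torch.tensor([0, 1, 2, 3], device="hpu")

cache = kv(keys, cache, slots)                # keys quantized to FP8, written in place
blocks = torch.tensor([0, 1], device="hpu")
fetched = kv.fetch_from_cache(cache, blocks)  # read back and dequantized to bfloat16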

VLLMKVCache

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class VLLMKVCache(torch.nn.Module):

    def __init__(self):
        super(VLLMKVCache, self).__init__()
        self.use_contiguous_pa = get_config().use_contiguous_pa

    def forward(self, input, cache, slot_mapping):
        # In cross-attention kv cache forward inputs are None in decode
        # We don't want to store them in the cache in such case
        if input is not None:
            cache.index_copy_(0, slot_mapping, input)
        return cache

    def fetch_from_cache(self, cache, blocks):
        if self.use_contiguous_pa:
            return cache[:blocks.size(0)]
        else:
            return cache.index_select(0, blocks)

use_contiguous_pa instance-attribute

use_contiguous_pa = use_contiguous_pa

__init__

__init__()
Source code in vllm_gaudi/extension/utils.py
def __init__(self):
    super(VLLMKVCache, self).__init__()
    self.use_contiguous_pa = get_config().use_contiguous_pa

fetch_from_cache

fetch_from_cache(cache, blocks)
Source code in vllm_gaudi/extension/utils.py
def fetch_from_cache(self, cache, blocks):
    if self.use_contiguous_pa:
        return cache[:blocks.size(0)]
    else:
        return cache.index_select(0, blocks)

forward

forward(input, cache, slot_mapping)
Source code in vllm_gaudi/extension/utils.py
def forward(self, input, cache, slot_mapping):
    # In cross-attention kv cache forward inputs are None in decode
    # We don't want to store them in the cache in such case
    if input is not None:
        cache.index_copy_(0, slot_mapping, input)
    return cache
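
A minimal sketch of the unquantized VLLMKVCache: forward writes the new entries into the given slots in place, and fetch_from_cache returns either a contiguous prefix or the selected blocks, depending on use_contiguous_pa. It assumes the extension config has been initialized (the constructor calls get_config()).

# Illustrative only: assumes get_config() has been initialized by the extension.
import torch
from vllm_gaudi.extension.utils import VLLMKVCache

kv = VLLMKVCache()
cache = torch.zeros(32, 16, 8)          # [num_slots, ...]
keys = torch.randn(4, 16, 8)
slots = torch.tensor([0, 1, 2, 3])

cache = kv(keys, cache, slots)          # index_copy_ into the requested slots
blocks = torch.tensor([0, 1])
fetched = kv.fetch_from_cache(cache, blocks)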

align_and_pad

align_and_pad(data, bucketing, padding_gen)
Source code in vllm_gaudi/extension/utils.py
def align_and_pad(data, bucketing, padding_gen):
    bs = len(data)
    target_bs, target_len = bucketing
    if target_bs == 1 and bs > 1:
        data = [list(itertools.chain(*data))]
    data = [pad_list(x, target_len, padding_gen) for x in data]
    padding = itertools.islice(padding_gen, target_len)
    data = pad_list(data, target_bs, itertools.tee(padding, target_bs - len(data)))
    return data
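
An illustrative call (the values are made up): every row is padded to target_len with values drawn from padding_gen, the batch itself is padded to target_bs, and when target_bs is 1 the rows are first flattened into a single row.

import itertools
from vllm_gaudi.extension.utils import align_and_pad

data = [[1, 2], [3]]
padded = align_and_pad(data, bucketing=(2, 4), padding_gen=itertools.repeat(-1))
print(padded)  # [[1, 2, -1, -1], [3, -1, -1, -1]]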

is_fake_hpu cached

is_fake_hpu() -> bool
Source code in vllm_gaudi/extension/utils.py
@lru_cache(maxsize=None)
def is_fake_hpu() -> bool:
    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'
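
Because the result is cached with lru_cache, VLLM_USE_FAKE_HPU must be set before the first call. A small example:

import os
os.environ["VLLM_USE_FAKE_HPU"] = "1"  # set before the first call; the result is cached

from vllm_gaudi.extension.utils import is_fake_hpu
print(is_fake_hpu())  # True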

pad_list

pad_list(input, target_len, val_generator)
Source code in vllm_gaudi/extension/utils.py
def pad_list(input, target_len, val_generator):
    padding = target_len - len(input)
    if padding > 0:
        input.extend(itertools.islice(val_generator, padding))
    return input
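
pad_list extends the input list in place with values drawn from the generator and returns it; if the list is already long enough it is returned unchanged. For example:

import itertools
from vllm_gaudi.extension.utils import pad_list

values = [1, 2, 3]
print(pad_list(values, 6, itertools.repeat(0)))  # [1, 2, 3, 0, 0, 0]
print(pad_list(values, 2, itertools.repeat(0)))  # unchanged: [1, 2, 3, 0, 0, 0]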

with_default

with_default(value: Optional[Any], default: Any) -> Any
Source code in vllm_gaudi/extension/utils.py
def with_default(value: Optional[Any], default: Any) -> Any:
    if value is not None:
        return value
    return default
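
Unlike a plain truthiness fallback (value or default), only None triggers the default here, so falsy values such as 0 are preserved:

from vllm_gaudi.extension.utils import with_default

print(with_default(None, 42))  # 42
print(with_default(0, 42))     # 0 (falsy values are kept; only None falls back)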