跳到内容

vllm_gaudi.extension.kernels

_kernel

_kernel(name)
源代码位于 vllm_gaudi/extension/kernels.py
def _kernel(name):

    def loader(fn):

        @cache
        def loader_impl():
            try:
                return fn()
            except (ImportError, AttributeError):
                from .utils import logger
                logger().warning(f"Could not import HPU {name} kernel. "
                                 "vLLM will use native implementation")

        return loader_impl

    return loader

block_softmax_adjustment

block_softmax_adjustment()
源代码位于 vllm_gaudi/extension/kernels.py
@_kernel("block_softmax_adjustment")
def block_softmax_adjustment():
    """Look up and return ``torch.ops.hpu.block_softmax_adjustment``."""
    import torch

    # Resolve the ``hpu`` op namespace first, then the op itself.
    hpu_ops = torch.ops.hpu
    return hpu_ops.block_softmax_adjustment

fsdpa

fsdpa()
源代码位于 vllm_gaudi/extension/kernels.py
@_kernel("FusedSDPA")
def fsdpa():
    """Import and return ``FusedSDPA`` from ``habana_frameworks.torch.hpex.kernels``."""
    # Deferred import: the module is only touched when this loader is called.
    from habana_frameworks.torch.hpex.kernels import FusedSDPA as fused_sdpa
    return fused_sdpa

rms_norm

rms_norm()
源代码位于 vllm_gaudi/extension/kernels.py
@_kernel("FusedRMSNorm")
def rms_norm():
    """Import and return ``FusedRMSNorm`` from ``habana_frameworks.torch.hpex.normalization``."""
    # Deferred import: the module is only touched when this loader is called.
    from habana_frameworks.torch.hpex.normalization import (
        FusedRMSNorm as fused_rms_norm,
    )
    return fused_rms_norm