跳到内容

vllm_gaudi.ops.hpu_layernorm

HPURMSNorm

基类：RMSNorm

源代码位于 vllm_gaudi/ops/hpu_layernorm.py
@RMSNorm.register_oot
class HPURMSNorm(RMSNorm):
    """RMSNorm layer backed by the Habana (HPU) fused RMSNorm kernel.

    Registered as an out-of-tree replacement for ``RMSNorm`` through
    ``RMSNorm.register_oot``; ``forward_oot`` is the HPU forward path.
    """

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Apply RMS normalization on HPU, optionally fusing a residual add.

        Args:
            x: Input activations to normalize.
            residual: Optional residual tensor. When given, ``x`` is viewed
                with ``residual``'s shape and added to it (out of place), and
                the sum is what gets normalized.

        Returns:
            The normalized tensor when ``residual`` is None. Otherwise the
            tuple ``(normed, new_residual)``: ``normed`` is viewed back to
            ``x``'s original shape and ``new_residual`` is ``residual + x``.
        """
        from vllm_gaudi.extension.kernels import rms_norm

        HPUFusedRMSNorm = rms_norm()
        if HPUFusedRMSNorm is None:
            # NOTE(review): the upstream HPU kernel loader returns None when
            # the fused kernel cannot be imported; assumed the same here.
            # Fall back to the reference implementation instead of failing
            # with AttributeError on ``.apply``.
            return self.forward_native(x, residual)

        if residual is not None:
            orig_shape = x.shape
            # Out-of-place add: the caller's residual tensor is not mutated.
            residual = residual + x.view(residual.shape)
            # NOTE(review): the kernel is documented by the original author
            # as requiring 3D inputs; no reshape to 3D happens here, so the
            # caller presumably supplies 3D tensors -- TODO confirm.
            x = HPUFusedRMSNorm.apply(residual, self.weight, self.variance_epsilon)
            return x.view(orig_shape), residual

        return HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)

forward_oot

forward_oot(
    x: Tensor, residual: Optional[Tensor] = None
) -> Union[Tensor, tuple[Tensor, Tensor]]
源代码位于 vllm_gaudi/ops/hpu_layernorm.py
def forward_oot(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Apply RMS normalization on HPU, optionally fusing a residual add.

    Args:
        x: Input activations to normalize.
        residual: Optional residual tensor. When given, ``x`` is viewed with
            ``residual``'s shape and added to it (out of place), and the sum
            is what gets normalized.

    Returns:
        The normalized tensor when ``residual`` is None. Otherwise the tuple
        ``(normed, new_residual)``: ``normed`` is viewed back to ``x``'s
        original shape and ``new_residual`` is ``residual + x``.
    """
    from vllm_gaudi.extension.kernels import rms_norm

    HPUFusedRMSNorm = rms_norm()
    if HPUFusedRMSNorm is None:
        # NOTE(review): the upstream HPU kernel loader returns None when the
        # fused kernel cannot be imported; assumed the same here. Fall back
        # to the reference implementation instead of failing with
        # AttributeError on ``.apply``.
        return self.forward_native(x, residual)

    if residual is not None:
        orig_shape = x.shape
        # Out-of-place add: the caller's residual tensor is not mutated.
        residual = residual + x.view(residual.shape)
        # NOTE(review): the kernel is documented by the original author as
        # requiring 3D inputs; no reshape to 3D happens here, so the caller
        # presumably supplies 3D tensors -- TODO confirm.
        x = HPUFusedRMSNorm.apply(residual, self.weight, self.variance_epsilon)
        return x.view(orig_shape), residual

    return HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)