@RMSNorm.register_oot
class HPURMSNorm(RMSNorm):
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
from vllm_gaudi.extension.kernels import rms_norm
HPUFusedRMSNorm = rms_norm()
if residual is not None:
orig_shape = x.shape
residual = residual + x.view(residual.shape)
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
x = HPUFusedRMSNorm.apply(residual, self.weight, self.variance_epsilon)
return x.view(orig_shape), residual
x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
return x