class ConvertScaleToHwAligned:
    """Round scales up to hardware-aligned powers of two, clamped per device.

    The device generation ("GAUDI2" or "GAUDI3") is chosen once in
    ``__init__`` and selects the row of ``FP8_143_SCALES_TRAITS`` that
    ``calc`` uses.
    """

    def __init__(self, device_type="GAUDI3"):
        """Select the device generation.

        Args:
            device_type: Accepted for interface compatibility but currently
                ignored; the generation is derived from ``is_hpu_gaudi2``.
                NOTE(review): confirm whether callers expect this argument to
                be honored -- doing so would change results for any caller
                that passes it explicitly.
        """
        # ``is_hpu_gaudi2`` is defined outside this class. If it is a predicate
        # function, testing the bare name is always truthy (function objects
        # are truthy), which would pin every device to "GAUDI2". Call it in
        # that case; otherwise use its truthiness exactly as before.
        detected = is_hpu_gaudi2() if callable(is_hpu_gaudi2) else is_hpu_gaudi2
        self.device_type = "GAUDI2" if detected else "GAUDI3"

    def calc(self, scale):
        """Return ``scale`` aligned to the device's supported scale grid.

        Args:
            scale: Tensor of scales. Assumed strictly positive -- ``log2`` of a
                non-positive value yields -inf/NaN (TODO confirm with callers).

        Returns:
            Tensor ``2 ** (ceil(log2(p) / f) * f)`` clamped element-wise to
            ``[min_scale, max_scale]``, where ``p`` is ``ScaleToPow2().calc``
            applied to the (possibly rescaled) input and
            ``(min_scale, max_scale, f)`` is
            ``FP8_143_SCALES_TRAITS[self.device_type]``.
        """
        if self.device_type == "GAUDI2":
            # Apply the Gaudi2-specific factor before any rounding.
            scale = scale * get_hpu_gaudi2_scale_factor()
        scale_pow2 = ScaleToPow2().calc(scale)
        min_scale, max_scale, scale_factor = FP8_143_SCALES_TRAITS[self.device_type]
        # Round the base-2 exponent up to the next multiple of ``scale_factor``
        # (presumably the exponent spacing of hardware-supported scales --
        # confirm against FP8_143_SCALES_TRAITS).
        exponent = torch.ceil(torch.log2(scale_pow2) / scale_factor) * scale_factor
        aligned = 2.0**exponent
        # Bounds use the (possibly rescaled) scale's dtype and device so the
        # element-wise min/max below operate on matching tensors.
        lower = torch.tensor(min_scale, dtype=scale.dtype, device=scale.device)
        upper = torch.tensor(max_scale, dtype=scale.dtype, device=scale.device)
        return torch.minimum(torch.maximum(aligned, lower), upper)