vllm_gaudi.ops.hpu_fp8

Fp8LinearMethod

Bases: Fp8LinearMethod

Source code in vllm_gaudi/ops/hpu_fp8.py
class Fp8LinearMethod(OrigFp8LinearMethod):

    def create_weights(self, *args, **kwargs) -> None:
        if hpu_ops.is_hpu_gaudi2:
            kwargs['weight_loader'] = hpu_ops.gaudi_weight_wrapper(kwargs.get('weight_loader'))
        super().create_weights(*args, **kwargs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.quant_config = self.quant_config
        if self.block_quant:
            layer = hpu_ops.fp8_block_linear_postprocess_weights(layer, envs.VLLM_HPU_FORCE_CHANNEL_FP8)
            return
        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = hpu_ops.scaled_fp8_quant(layer.weight, scale=None)
            weight = qweight.t()

        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module
        else:
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.

            weight, weight_scale, input_scale = hpu_ops.process_fp8_weight_tensor_strategy(
                weight,
                weight_scale,
                layer.logical_widths,
                getattr(layer, "input_scale", None),
            )
            if self.act_q_static:
                assert input_scale is not None
                input_scale = input_scale.max()
            weight = weight.t()

        # Update layer with new values.
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        layer.input_scale = (Parameter(input_scale, requires_grad=False) if input_scale is not None else None)

    def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            return hpu_ops.apply_block_fp8_linear_hpu(
                input=x,
                layer=layer,
                block_size=self.quant_config.weight_block_size,
                bias=bias,
                do_unpad=True,
                force_channel_fp8=envs.VLLM_HPU_FORCE_CHANNEL_FP8,
            )

        weight_scale = layer.weight_scale.transpose(0, 1) if layer.weight_scale.dim() > 1 else layer.weight_scale
        input_scale = getattr(layer, 'input_scale', None)
        return hpu_ops.apply_fp8_linear_hpu(input=x,
                                            weight=layer.weight,
                                            weight_scale=weight_scale,
                                            input_scale=input_scale,
                                            bias=bias,
                                            trans_B=False)

    def dequant_fp8_weight(self, layer) -> torch.Tensor:
        if hasattr(layer, "updated_fp8_weight") and layer.updated_fp8_weight:
            return layer.weight
        dequant_weight = hpu_ops.dequant_block_fp8_weight_naive(
            layer.weight,
            layer.weight_scale_inv.data,
            self.quant_config.weight_block_size,
            original_M=layer.orig_M,
            original_N=layer.orig_N,
            do_unpad=True,
        )
        return dequant_weight

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm_gaudi/ops/hpu_fp8.py
def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    if self.block_quant:
        assert self.quant_config.weight_block_size is not None
        return hpu_ops.apply_block_fp8_linear_hpu(
            input=x,
            layer=layer,
            block_size=self.quant_config.weight_block_size,
            bias=bias,
            do_unpad=True,
            force_channel_fp8=envs.VLLM_HPU_FORCE_CHANNEL_FP8,
        )

    weight_scale = layer.weight_scale.transpose(0, 1) if layer.weight_scale.dim() > 1 else layer.weight_scale
    input_scale = getattr(layer, 'input_scale', None)
    return hpu_ops.apply_fp8_linear_hpu(input=x,
                                        weight=layer.weight,
                                        weight_scale=weight_scale,
                                        input_scale=input_scale,
                                        bias=bias,
                                        trans_B=False)
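
For the non-block-quantized path, apply hands the GEMM to hpu_ops.apply_fp8_linear_hpu with the weight already stored as (in_features, out_features) (hence trans_B=False). The sketch below is plain PyTorch with illustrative names, not the HPU kernel; it only shows roughly what a per-tensor fp8 linear computes.

import torch

def fp8_linear_reference(x, weight_fp8, weight_scale, input_scale=None, bias=None):
    # Dequantize the fp8 weight with its per-tensor scale; the real kernel fuses this.
    w = weight_fp8.to(x.dtype) * weight_scale
    if input_scale is not None:
        # With static activation quantization the input is also quantized to fp8 and
        # rescaled; only that quantize/dequantize round trip is modeled here.
        x = (x / input_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).to(x.dtype) * input_scale
    out = x @ w  # weight laid out as (in_features, out_features), matching trans_B=False
    return out + bias if bias is not None else out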

create_weights

create_weights(*args, **kwargs) -> None
Source code in vllm_gaudi/ops/hpu_fp8.py
def create_weights(self, *args, **kwargs) -> None:
    if hpu_ops.is_hpu_gaudi2:
        kwargs['weight_loader'] = hpu_ops.gaudi_weight_wrapper(kwargs.get('weight_loader'))
    super().create_weights(*args, **kwargs)
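
On Gaudi2, create_weights intercepts the weight_loader passed down by the upstream Fp8LinearMethod so fp8 tensors can be adapted for the hardware as they are loaded. The snippet below is a hypothetical illustration of that interception pattern only; the actual conversion logic lives in hpu_ops.gaudi_weight_wrapper.

def example_weight_wrapper(orig_loader):
    # Hypothetical wrapper, shown only to illustrate the pattern: the real
    # gaudi_weight_wrapper adjusts fp8 tensors for Gaudi2 before delegating.
    def wrapped(param, loaded_weight, *args, **kwargs):
        # ... hardware-specific adjustment of loaded_weight would happen here ...
        return orig_loader(param, loaded_weight, *args, **kwargs)
    return wrapped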

dequant_fp8_weight

dequant_fp8_weight(layer) -> Tensor
Source code in vllm_gaudi/ops/hpu_fp8.py
def dequant_fp8_weight(self, layer) -> torch.Tensor:
    if hasattr(layer, "updated_fp8_weight") and layer.updated_fp8_weight:
        return layer.weight
    dequant_weight = hpu_ops.dequant_block_fp8_weight_naive(
        layer.weight,
        layer.weight_scale_inv.data,
        self.quant_config.weight_block_size,
        original_M=layer.orig_M,
        original_N=layer.orig_N,
        do_unpad=True,
    )
    return dequant_weight
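
dequant_block_fp8_weight_naive expands a grid of per-block scales back over the fp8 weight. Below is a minimal plain-PyTorch reference of that idea, assuming a [ceil(M/bm), ceil(N/bn)] scale grid; it is illustrative, not the vllm_gaudi implementation, and ignores the original_M/original_N unpadding handled above.

import torch

def dequant_block_fp8_reference(weight_fp8, scale_inv, block_size):
    bm, bn = block_size
    M, N = weight_fp8.shape
    # Broadcast each per-block scale over its bm x bn tile, then crop any padding.
    scales = scale_inv.repeat_interleave(bm, dim=0).repeat_interleave(bn, dim=1)[:M, :N]
    return weight_fp8.to(torch.bfloat16) * scales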

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm_gaudi/ops/hpu_fp8.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    layer.quant_config = self.quant_config
    if self.block_quant:
        layer = hpu_ops.fp8_block_linear_postprocess_weights(layer, envs.VLLM_HPU_FORCE_CHANNEL_FP8)
        return
    # If checkpoint not serialized fp8, quantize the weights.
    elif not self.quant_config.is_checkpoint_fp8_serialized:
        qweight, weight_scale = hpu_ops.scaled_fp8_quant(layer.weight, scale=None)
        weight = qweight.t()

    # If checkpoint is fp8 per-tensor, handle that there are N scales for N
    # shards in a fused module
    else:
        weight = layer.weight
        weight_scale = layer.weight_scale

        # If using w8a8, torch._scaled_mm needs per tensor, so
        # requantize the logical shards as a single weight.

        weight, weight_scale, input_scale = hpu_ops.process_fp8_weight_tensor_strategy(
            weight,
            weight_scale,
            layer.logical_widths,
            getattr(layer, "input_scale", None),
        )
        if self.act_q_static:
            assert input_scale is not None
            input_scale = input_scale.max()
        weight = weight.t()

    # Update layer with new values.
    layer.weight = Parameter(weight.data, requires_grad=False)
    layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
    layer.input_scale = (Parameter(input_scale, requires_grad=False) if input_scale is not None else None)
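
For per-tensor fp8 checkpoints, a fused module carries one scale per logical shard, so the shards are requantized to a single shared scale before the layer parameters are replaced. A plain-PyTorch sketch of that idea (illustrative; the actual logic is hpu_ops.process_fp8_weight_tensor_strategy):

import torch

def requantize_shards_reference(weight_fp8, shard_scales, logical_widths):
    # One scale per logical shard of the fused weight; take the largest so a
    # single per-tensor scale covers every shard.
    max_scale = shard_scales.max()
    out_shards = []
    start = 0
    for width, scale in zip(logical_widths, shard_scales):
        shard = weight_fp8[start:start + width].to(torch.float32) * scale  # dequantize
        out_shards.append((shard / max_scale).to(torch.float8_e4m3fn))     # requantize
        start += width
    return torch.cat(out_shards, dim=0), max_scale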

HPUFp8MoEMethod

Bases: Fp8MoEMethod

Source code in vllm_gaudi/ops/hpu_fp8.py
class HPUFp8MoEMethod(Fp8MoEMethod):

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)

        # Disable marlin
        self.use_marlin = False

        # disable DeepGemm support.
        self.allow_deep_gemm = False

    def create_weights(self, *args, **kwargs) -> None:
        if hpu_ops.is_hpu_gaudi2:
            kwargs['weight_loader'] = hpu_ops.gaudi_weight_wrapper(kwargs.get('weight_loader'))
        kwargs['weight_loader'] = hpu_ops.synced_weight_loader(kwargs.get('weight_loader'))
        super().create_weights(*args, **kwargs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_experts = layer.local_num_experts
        ep_shift = layer.ep_rank * num_experts

        experts_min, experts_max = ep_shift, num_experts + ep_shift - 1
        if self.block_quant and not envs.VLLM_HPU_FORCE_CHANNEL_FP8:
            layer.moe_op = VllmMixtureOfExpertsOpFP8(
                num_experts,
                experts_min,
                experts_max,
            )
        else:
            layer.moe_op = VllmMixtureOfExpertsOpFP8PerChannel(
                num_experts,
                experts_min,
                experts_max,
            )
        if self.block_quant:
            layer = hpu_ops.fp8_block_moe_prepare_weights(layer, envs.VLLM_HPU_FORCE_CHANNEL_FP8)
        else:
            layer = hpu_ops.fp8_channel_moe_prepare_weights(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        **kwargs,
    ) -> torch.Tensor:
        input_shape = x.shape
        x = x.view(-1, x.shape[-1])
        if use_grouped_topk or custom_routing_function is not None:
            topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias)
        else:
            import torch.nn.functional as F
            topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
            topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
            topk_weights = topk_weights.to(x.dtype)
        topk_ids = topk_ids.view(*x.shape[:-1], -1)
        topk_weights = topk_weights.view(*x.shape[:-1], -1)
        output = layer.moe_op(
            x,
            topk_ids.to(torch.int64),
            topk_weights.to(x.dtype),
            permuted_weights=True,
            activation=activation,
        )
        return output.view(*input_shape)

allow_deep_gemm instance-attribute

allow_deep_gemm = False

use_marlin instance-attribute

use_marlin = False

__init__

__init__(quant_config: Fp8Config, layer: Module)
Source code in vllm_gaudi/ops/hpu_fp8.py
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
    super().__init__(quant_config, layer)

    # Disable marlin
    self.use_marlin = False

    # disable DeepGemm support.
    self.allow_deep_gemm = False

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    **kwargs,
) -> Tensor
Source code in vllm_gaudi/ops/hpu_fp8.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    **kwargs,
) -> torch.Tensor:
    input_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if use_grouped_topk or custom_routing_function is not None:
        topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)
    else:
        import torch.nn.functional as F
        topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(x.dtype)
    topk_ids = topk_ids.view(*x.shape[:-1], -1)
    topk_weights = topk_weights.view(*x.shape[:-1], -1)
    output = layer.moe_op(
        x,
        topk_ids.to(torch.int64),
        topk_weights.to(x.dtype),
        permuted_weights=True,
        activation=activation,
    )
    return output.view(*input_shape)
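
The fallback routing branch above is ordinary softmax top-k selection with renormalization over the selected experts. A tiny standalone example of that computation:

import torch
import torch.nn.functional as F

router_logits = torch.randn(4, 8)                        # 4 tokens, 8 experts
probs = F.softmax(router_logits, dim=1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(probs, k=2, dim=-1)  # top-2 experts per token
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)   # renormalize over the top-k
print(topk_ids.shape, topk_weights.shape)                # torch.Size([4, 2]) twice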

create_weights

create_weights(*args, **kwargs) -> None
Source code in vllm_gaudi/ops/hpu_fp8.py
def create_weights(self, *args, **kwargs) -> None:
    if hpu_ops.is_hpu_gaudi2:
        kwargs['weight_loader'] = hpu_ops.gaudi_weight_wrapper(kwargs.get('weight_loader'))
    kwargs['weight_loader'] = hpu_ops.synced_weight_loader(kwargs.get('weight_loader'))
    super().create_weights(*args, **kwargs)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm_gaudi/ops/hpu_fp8.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    num_experts = layer.local_num_experts
    ep_shift = layer.ep_rank * num_experts

    experts_min, experts_max = ep_shift, num_experts + ep_shift - 1
    if self.block_quant and not envs.VLLM_HPU_FORCE_CHANNEL_FP8:
        layer.moe_op = VllmMixtureOfExpertsOpFP8(
            num_experts,
            experts_min,
            experts_max,
        )
    else:
        layer.moe_op = VllmMixtureOfExpertsOpFP8PerChannel(
            num_experts,
            experts_min,
            experts_max,
        )
    if self.block_quant:
        layer = hpu_ops.fp8_block_moe_prepare_weights(layer, envs.VLLM_HPU_FORCE_CHANNEL_FP8)
    else:
        layer = hpu_ops.fp8_channel_moe_prepare_weights(layer)
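
The expert index range handed to the MoE op is derived from the expert-parallel rank. A small worked example of the arithmetic above:

# With 4 experts per rank, expert-parallel rank 1 owns global experts 4..7.
local_num_experts = 4
ep_rank = 1
ep_shift = ep_rank * local_num_experts            # 4
experts_min = ep_shift                            # 4
experts_max = local_num_experts + ep_shift - 1    # 7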