
vllm_gaudi.distributed.device_communicators.hpu_communicator

HpuCommunicator

Bases: DeviceCommunicatorBase

Source code in vllm_gaudi/distributed/device_communicators/hpu_communicator.py
class HpuCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)

        self.dp_group: Optional[GroupCoordinator] = None
        self.dp_rank = 0
        self.dp_world_size = 1
        # assume EP is enabled along with DP
        if "ep" in unique_name:
            self.dp_group = get_dp_group()
            self.dp_rank = self.dp_group.rank_in_group
            self.dp_world_size = self.dp_group.world_size
            self.tp_group = get_tp_group()
        self.world_size = dist.get_world_size(group=self.cpu_group)
        self.rank = dist.get_rank(group=self.cpu_group)

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
        # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
        # (which is required for tensor parallel HPUGraph inference)
        htorch.core.mark_step()
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((world_size, ) + input_size, dtype=input_.dtype, device=input_.device)
        # All-gather.
        htorch.core.mark_step()
        dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
        # Reshape
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] + (world_size * input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def dispatch(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
                 is_sequence_parallel: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
        assert self.dp_group is not None
        assert hidden_states.dim() == 2, "Input hidden states must be 2D"

        dp_metadata = get_hpu_dp_metadata()
        if dp_metadata is not None:
            hidden_states_across_dp = dp_metadata.hidden_states_across_dp
            router_logits_across_dp = dp_metadata.router_logits_across_dp
        else:
            # create hidden_states_across_dp tensor
            input_size = hidden_states.size()
            # Allocate output tensor.
            output_size = list(input_size)
            if is_sequence_parallel:
                # if sequence parallel is enabled, hidden states have already been chunked by sp_size
                output_size[0] *= self.world_size
            else:
                output_size[0] *= self.dp_world_size
            hidden_states_across_dp = torch.empty(output_size, dtype=hidden_states.dtype, device=hidden_states.device)

            # create router_logits_across_dp tensor
            router_logits_size = router_logits.size()
            router_logits_output_size = list(router_logits_size)
            if is_sequence_parallel:
                router_logits_output_size[0] *= self.world_size
            else:
                router_logits_output_size[0] *= self.dp_world_size
            router_logits_across_dp = torch.empty(router_logits_output_size,
                                                  dtype=router_logits.dtype,
                                                  device=router_logits.device)

        torch.distributed.all_gather_into_tensor(
            hidden_states_across_dp,
            hidden_states,
            group=get_ep_group().device_group if is_sequence_parallel else self.dp_group.device_group)

        torch.distributed.all_gather_into_tensor(
            router_logits_across_dp,
            router_logits,
            group=get_ep_group().device_group if is_sequence_parallel else self.dp_group.device_group)
        return hidden_states_across_dp, router_logits_across_dp

    def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False) -> torch.Tensor:
        if htorch.utils.internal.is_lazy():
            htorch.core.mark_step()
        assert self.dp_group is not None
        assert hidden_states.dim() == 2, "Input hidden states must be 2D"

        dp_metadata = get_hpu_dp_metadata()
        if dp_metadata is not None:
            local_hidden_states = dp_metadata.local_hidden_states
        else:
            local_num_tokens = hidden_states.size(0) // self.world_size if is_sequence_parallel else hidden_states.size(
                0) // self.dp_world_size
            local_hidden_states = torch.empty((local_num_tokens, hidden_states.size(-1)),
                                              device=hidden_states.device,
                                              dtype=hidden_states.dtype)

        torch.distributed.reduce_scatter_tensor(
            local_hidden_states,
            hidden_states,
            group=get_ep_group().device_group if is_sequence_parallel else self.dp_group.device_group)
        hidden_states = local_hidden_states
        return hidden_states

dp_group instance-attribute

dp_group: Optional[GroupCoordinator] = None

dp_rank instance-attribute

dp_rank = 0

dp_world_size instance-attribute

dp_world_size = 1

rank instance-attribute

rank = get_rank(group=cpu_group)

tp_group instance-attribute

tp_group = get_tp_group()

world_size instance-attribute

world_size = get_world_size(group=cpu_group)

__init__

__init__(
    cpu_group: ProcessGroup,
    device: Optional[device] = None,
    device_group: Optional[ProcessGroup] = None,
    unique_name: str = "",
)
Source code in vllm_gaudi/distributed/device_communicators/hpu_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: Optional[torch.device] = None,
             device_group: Optional[ProcessGroup] = None,
             unique_name: str = ""):
    super().__init__(cpu_group, device, device_group, unique_name)

    self.dp_group: Optional[GroupCoordinator] = None
    self.dp_rank = 0
    self.dp_world_size = 1
    # assume EP is enabled along with DP
    if "ep" in unique_name:
        self.dp_group = get_dp_group()
        self.dp_rank = self.dp_group.rank_in_group
        self.dp_world_size = self.dp_group.world_size
        self.tp_group = get_tp_group()
    self.world_size = dist.get_world_size(group=self.cpu_group)
    self.rank = dist.get_rank(group=self.cpu_group)

all_gather

all_gather(input_: Tensor, dim: int = -1) -> Tensor
Source code in vllm_gaudi/distributed/device_communicators/hpu_communicator.py
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    world_size = self.world_size
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # Allocate output tensor.
    output_tensor = torch.empty((world_size, ) + input_size, dtype=input_.dtype, device=input_.device)
    # All-gather.
    htorch.core.mark_step()
    dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
    # Reshape
    output_tensor = output_tensor.movedim(0, dim)
    output_tensor = output_tensor.reshape(input_size[:dim] + (world_size * input_size[dim], ) +
                                          input_size[dim + 1:])
    return output_tensor
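
A minimal CPU-only sketch of the reshape step above: the collective is replaced by torch.stack, so it runs without HPUs or an initialized process group, and the shapes are illustrative.

import torch

# Illustrative stand-ins; in HpuCommunicator, world_size comes from the process group.
world_size = 4
dim = 1
input_ = torch.randn(2, 3, 5)               # per-rank input
input_size = input_.size()

# Stand-in for dist.all_gather_into_tensor: every rank contributes one slice
# along a new leading dimension of size world_size.
output_tensor = torch.stack([input_] * world_size, dim=0)   # (4, 2, 3, 5)

# Same reshape as all_gather above: move the gathered axis to `dim` and fold
# it into that dimension.
output_tensor = output_tensor.movedim(0, dim)               # (2, 4, 3, 5)
output_tensor = output_tensor.reshape(input_size[:dim] + (world_size * input_size[dim], ) +
                                      input_size[dim + 1:])
print(output_tensor.shape)                                  # torch.Size([2, 12, 5])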

all_reduce

all_reduce(input_: Tensor) -> Tensor
Source code in vllm_gaudi/distributed/device_communicators/hpu_communicator.py
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
    # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
    # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
    # (which is required for tensor parallel HPUGraph inference)
    htorch.core.mark_step()
    dist.all_reduce(input_, group=self.device_group)
    return input_
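
For reference, a CPU-only illustration of what every rank holds after the in-place all-reduce with the default ReduceOp.SUM; it does not model mark_step, which is the HPU-graph-related workaround noted in the FIXME above.

import torch

# CPU stand-ins for the per-rank tensors that dist.all_reduce would sum in place.
world_size = 4
rank_inputs = [torch.full((2, 2), float(r)) for r in range(world_size)]

# With ReduceOp.SUM, every rank's input_ is overwritten with the element-wise
# sum of all ranks' inputs.
reduced = torch.stack(rank_inputs).sum(dim=0)
print(reduced)                              # every element is 0 + 1 + 2 + 3 = 6.0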

combine

combine(
    hidden_states: Tensor,
    is_sequence_parallel: bool = False,
) -> Tensor
Source code in vllm_gaudi/distributed/device_communicators/hpu_communicator.py
def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False) -> torch.Tensor:
    if htorch.utils.internal.is_lazy():
        htorch.core.mark_step()
    assert self.dp_group is not None
    assert hidden_states.dim() == 2, "Input hidden states must be 2D"

    dp_metadata = get_hpu_dp_metadata()
    if dp_metadata is not None:
        local_hidden_states = dp_metadata.local_hidden_states
    else:
        local_num_tokens = hidden_states.size(0) // self.world_size if is_sequence_parallel else hidden_states.size(
            0) // self.dp_world_size
        local_hidden_states = torch.empty((local_num_tokens, hidden_states.size(-1)),
                                          device=hidden_states.device,
                                          dtype=hidden_states.dtype)

    torch.distributed.reduce_scatter_tensor(
        local_hidden_states,
        hidden_states,
        group=get_ep_group().device_group if is_sequence_parallel else self.dp_group.device_group)
    hidden_states = local_hidden_states
    return hidden_states
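
A CPU-only sketch of the reduce_scatter_tensor semantics used by combine (no process group involved; sizes and inputs are illustrative stand-ins). With sequence parallelism the scatter runs over the EP group instead, so the divisor is self.world_size rather than dp_world_size.

import torch

# Illustrative sizes: 2 DP ranks, 3 local tokens per rank, hidden size 4.
dp_world_size, local_num_tokens, hidden_size = 2, 3, 4

# Each rank enters combine() holding the full (dp_world_size * T, H) tensor
# produced after dispatch() and the MoE experts.
per_rank_inputs = [torch.randn(dp_world_size * local_num_tokens, hidden_size)
                   for _ in range(dp_world_size)]

# reduce_scatter_tensor semantics: element-wise sum across ranks, then each
# rank keeps only its own (T, H) chunk of the result.
summed = torch.stack(per_rank_inputs).sum(dim=0)
for rank, local_hidden_states in enumerate(summed.chunk(dp_world_size, dim=0)):
    print(rank, local_hidden_states.shape)  # torch.Size([3, 4]) on every rank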

dispatch

dispatch(
    hidden_states: Tensor,
    router_logits: Tensor,
    is_sequence_parallel: bool = False,
) -> tuple[Tensor, Tensor]
Source code in vllm_gaudi/distributed/device_communicators/hpu_communicator.py
def dispatch(self,
             hidden_states: torch.Tensor,
             router_logits: torch.Tensor,
             is_sequence_parallel: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    assert self.dp_group is not None
    assert hidden_states.dim() == 2, "Input hidden states must be 2D"

    dp_metadata = get_hpu_dp_metadata()
    if dp_metadata is not None:
        hidden_states_across_dp = dp_metadata.hidden_states_across_dp
        router_logits_across_dp = dp_metadata.router_logits_across_dp
    else:
        # create hidden_states_across_dp tensor
        input_size = hidden_states.size()
        # Allocate output tensor.
        output_size = list(input_size)
        if is_sequence_parallel:
            # if sequence parallel is enabled, hidden states have already been chunked by sp_size
            output_size[0] *= self.world_size
        else:
            output_size[0] *= self.dp_world_size
        hidden_states_across_dp = torch.empty(output_size, dtype=hidden_states.dtype, device=hidden_states.device)

        # create router_logits_across_dp tensor
        router_logits_size = router_logits.size()
        router_logits_output_size = list(router_logits_size)
        if is_sequence_parallel:
            router_logits_output_size[0] *= self.world_size
        else:
            router_logits_output_size[0] *= self.dp_world_size
        router_logits_across_dp = torch.empty(router_logits_output_size,
                                              dtype=router_logits.dtype,
                                              device=router_logits.device)

    torch.distributed.all_gather_into_tensor(
        hidden_states_across_dp,
        hidden_states,
        group=get_ep_group().device_group if is_sequence_parallel else self.dp_group.device_group)

    torch.distributed.all_gather_into_tensor(
        router_logits_across_dp,
        router_logits,
        group=get_ep_group().device_group if is_sequence_parallel else self.dp_group.device_group)
    return hidden_states_across_dp, router_logits_across_dp
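
A CPU-only sketch of the per-rank result of dispatch(): the two all-gathers are replaced by torch.cat and the sizes are illustrative. With sequence parallelism the gather runs over the EP group instead, so the leading dimension grows by self.world_size rather than dp_world_size.

import torch

# Illustrative sizes: 2 DP ranks, 3 tokens per rank, hidden size 4, 8 experts.
dp_world_size, num_tokens, hidden_size, num_experts = 2, 3, 4, 8
hidden_per_rank = [torch.randn(num_tokens, hidden_size) for _ in range(dp_world_size)]
logits_per_rank = [torch.randn(num_tokens, num_experts) for _ in range(dp_world_size)]

# all_gather_into_tensor along dim 0: every rank ends up with all ranks' rows
# concatenated in rank order.
hidden_states_across_dp = torch.cat(hidden_per_rank, dim=0)
router_logits_across_dp = torch.cat(logits_per_rank, dim=0)
print(hidden_states_across_dp.shape)        # torch.Size([6, 4])
print(router_logits_across_dp.shape)        # torch.Size([6, 8])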