
Commit ff33d48

[MoE][Dist] Fix Qwen MoE accuracy bug in DP scenario
Signed-off-by: MengqingCao <[email protected]>
1 parent 72eceff

File tree

1 file changed: +21 −0 lines changed


vllm_ascend/distributed/communicator.py

Lines changed: 21 additions & 0 deletions
@@ -20,6 +20,7 @@
 import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
+from vllm.utils import logger


 class NPUCommunicator(DeviceCommunicatorBase):
@@ -34,6 +35,12 @@ def __init__(self,
         # init device according to rank
         self.device = torch.npu.current_device()

+        if self.use_all2all:
+            from vllm.distributed.device_communicators.all2all import \
+                NaiveAll2AllManager
+            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+            logger.info("Using naive all2all manager.")
+
     def all_to_all(self,
                    input_: torch.Tensor,
                    scatter_dim: int = 0,
@@ -73,3 +80,17 @@ def all_to_all(self,
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
         return output_tensor
+
+    # TODO: Add ut for dispatch and combine
+    def dispatch(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_manager is not None
+        hidden_states, router_logits = self.all2all_manager.dispatch(
+            hidden_states, router_logits)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        assert self.all2all_manager is not None
+        hidden_states = self.all2all_manager.combine(hidden_states)
+        return hidden_states
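
For orientation, a minimal usage sketch of the new hooks (not part of the commit): under data parallelism, a MoE layer dispatches its tokens to the ranks hosting the routed experts, runs the local experts, then combines the outputs back to their original ranks. The moe_forward helper and the experts callable below are hypothetical stand-ins; only dispatch and combine come from the diff above.

    from typing import Callable

    import torch


    def moe_forward(
            comm,  # assumed: an NPUCommunicator patched as above
            experts: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
            hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> torch.Tensor:
        # All-to-all dispatch: each DP rank hands its tokens to the
        # ranks that own the experts the router selected.
        hidden_states, router_logits = comm.dispatch(
            hidden_states, router_logits)
        # Run the locally hosted experts (caller-supplied stand-in).
        hidden_states = experts(hidden_states, router_logits)
        # All-to-all combine: route each token's expert output back
        # to the rank that originally held it.
        return comm.combine(hidden_states)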
