 import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
+from vllm.utils import logger
 
 
 class NPUCommunicator(DeviceCommunicatorBase):
@@ -34,6 +35,12 @@ def __init__(self,
         # init device according to rank
         self.device = torch.npu.current_device()
 
+        if self.use_all2all:
+            from vllm.distributed.device_communicators.all2all import \
+                NaiveAll2AllManager
+            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+            logger.info("Using naive all2all manager.")
+
     def all_to_all(self,
                    input_: torch.Tensor,
                    scatter_dim: int = 0,
@@ -73,3 +80,17 @@ def all_to_all(self,
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
         return output_tensor
+
+    # TODO: Add unit tests for dispatch and combine
+    def dispatch(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_manager is not None
+        hidden_states, router_logits = self.all2all_manager.dispatch(
+            hidden_states, router_logits)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        assert self.all2all_manager is not None
+        hidden_states = self.all2all_manager.combine(hidden_states)
+        return hidden_states
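For context, a minimal sketch of how a caller could exercise the new `dispatch`/`combine` pair in an expert-parallel MoE forward pass. Here `comm` stands in for an initialized `NPUCommunicator` (constructed with all2all enabled), and `apply_experts` / `moe_forward` are hypothetical helpers used only for illustration; none of this is part of the diff itself.

```python
import torch


def apply_experts(hidden_states: torch.Tensor,
                  router_logits: torch.Tensor) -> torch.Tensor:
    # Placeholder for the per-rank expert computation; a real MoE layer
    # would run its locally hosted experts on the received tokens here.
    return hidden_states


def moe_forward(comm, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
    # Scatter tokens (and their router logits) across ranks so each rank
    # receives the tokens routed to the experts it hosts.
    hidden_states, router_logits = comm.dispatch(hidden_states, router_logits)
    # Run the local experts on the received tokens.
    hidden_states = apply_experts(hidden_states, router_logits)
    # All-to-all the expert outputs back so every rank ends up with the
    # outputs for the tokens it originally held.
    return comm.combine(hidden_states)
```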