Commit 9abfa65

[Dist][Bugfix] Fix mc2 process group
Signed-off-by: MengqingCao <[email protected]>
1 parent: 89129a8

3 files changed: +4, -13 lines

vllm_ascend/distributed/parallel_state.py
Lines changed: 2 additions & 9 deletions

@@ -1,7 +1,7 @@
 from typing import Optional

 import torch
-from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
+from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, get_ep_group,
                                              init_model_parallel_group)

 # Currently, mc2 op need their own group coordinator.
@@ -18,23 +18,16 @@ def model_parallel_initialized():


 def init_ascend_model_parallel(
-    expert_parallel_size: int = 1,
-    world_size: Optional[int] = None,
     backend: Optional[str] = None,
 ):
     if model_parallel_initialized():
         return
     assert torch.distributed.is_initialized()
-    world_size = world_size or torch.distributed.get_world_size()
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)
-    num_expert_parallel_groups = world_size // expert_parallel_size

     global _MC2
-    group_ranks = []
-    for i in range(num_expert_parallel_groups):
-        ranks = list(range(i, world_size, num_expert_parallel_groups))
-        group_ranks.append(ranks)
+    group_ranks = get_ep_group().ranks

     _MC2 = init_model_parallel_group(group_ranks,
                                      get_world_group().local_rank,
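For context, here is a minimal standalone sketch (not part of the commit) of the rank grouping that the deleted loop computed; the function name old_mc2_group_ranks and the world_size/expert_parallel_size values are made-up examples:

# Hedged sketch, not the actual vllm_ascend module: reproduces the rank
# grouping that the removed loop in init_ascend_model_parallel performed.
def old_mc2_group_ranks(world_size: int, expert_parallel_size: int):
    num_expert_parallel_groups = world_size // expert_parallel_size
    group_ranks = []
    for i in range(num_expert_parallel_groups):
        # Strided groups, e.g. world_size=8, expert_parallel_size=4
        # -> [[0, 2, 4, 6], [1, 3, 5, 7]]
        group_ranks.append(list(range(i, world_size, num_expert_parallel_groups)))
    return group_ranks

print(old_mc2_group_ranks(8, 4))  # [[0, 2, 4, 6], [1, 3, 5, 7]]

After this commit those ranks are no longer recomputed from a separately passed expert_parallel_size; the mc2 coordinator is built from get_ep_group().ranks, so its membership follows the expert-parallel group that vLLM's own parallel_state has already set up rather than a locally derived one.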

vllm_ascend/worker/worker.py
Lines changed: 1 addition & 2 deletions

@@ -546,8 +546,7 @@ def _init_worker_distributed_environment(
     ensure_model_parallel_initialized(
         parallel_config.tensor_parallel_size,
         parallel_config.pipeline_parallel_size)
-    init_ascend_model_parallel(parallel_config.expert_parallel_size,
-                               parallel_config.world_size_across_dp)
+    init_ascend_model_parallel()
     ensure_kv_transfer_initialized(vllm_config)

vllm_ascend/worker/worker_v1.py
Lines changed: 1 addition & 2 deletions

@@ -261,8 +261,7 @@ def _init_worker_distributed_environment(self) -> None:
         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,
             self.parallel_config.pipeline_parallel_size)
-        init_ascend_model_parallel(self.parallel_config.expert_parallel_size,
-                                   self.parallel_config.world_size_across_dp)
+        init_ascend_model_parallel()
         ensure_kv_transfer_initialized(self.vllm_config)

     def _init_profiler(self):
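Both worker call sites collapse to the same pattern. A hedged sketch of the resulting call order is below; the wrapper name init_worker_model_parallel is hypothetical, and the import paths are assumptions about what the workers already pull in, while the three calls themselves come straight from the diff context:

# Hedged sketch, not the repository code: worker-side initialization order
# after this commit. Import locations are assumptions, not verified paths.
from vllm.distributed import (ensure_kv_transfer_initialized,
                              ensure_model_parallel_initialized)
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel

def init_worker_model_parallel(vllm_config, parallel_config):
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size)
    # expert_parallel_size / world_size_across_dp are no longer passed in:
    # init_ascend_model_parallel now takes its ranks from the EP group itself.
    init_ascend_model_parallel()
    ensure_kv_transfer_initialized(vllm_config)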
