
Commit 3a34b11

[0.9.1][Dist][Bugfix] Fix mc2 process group to resolve self.cpu_group is None (#1831)
### What this PR does / why we need it?

This PR fixes a bug that raised the error `self.cpu_group is None`. It was mainly caused by the wrong group ranks of the process groups maintained in vllm-ascend. We need to take the external DP size into account to ensure the MC2 process group works correctly with `external_launch` mode.

Related fixes: #1396 #1154

Signed-off-by: MengqingCao <[email protected]>
1 parent 5be1d8c commit 3a34b11

File tree:
- vllm_ascend/distributed/parallel_state.py
- vllm_ascend/worker/worker.py
- vllm_ascend/worker/worker_v1.py

3 files changed: +9 -11 lines


vllm_ascend/distributed/parallel_state.py

Lines changed: 7 additions & 7 deletions
@@ -19,22 +19,22 @@ def model_parallel_initialized():
 
 def init_ascend_model_parallel(
     expert_parallel_size: int = 1,
-    world_size: Optional[int] = None,
     backend: Optional[str] = None,
 ):
     if model_parallel_initialized():
         return
     assert torch.distributed.is_initialized()
-    world_size = world_size or torch.distributed.get_world_size()
+    world_size = torch.distributed.get_world_size()
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)
-    num_expert_parallel_groups = world_size // expert_parallel_size
 
+    # The layout of all ranks: ExternalDP * EP
+    # ExternalDP is the data parallel group that is not part of the model,
+    # every dp rank can generate independently (in verl integration).
+    all_ranks = torch.arange(world_size).reshape(-1, expert_parallel_size)
     global _MC2
-    group_ranks = []
-    for i in range(num_expert_parallel_groups):
-        ranks = list(range(i, world_size, num_expert_parallel_groups))
-        group_ranks.append(ranks)
+    group_ranks = all_ranks.unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
 
     _MC2 = init_model_parallel_group(group_ranks,
                                      get_world_group().local_rank,
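
For illustration only (not part of the commit), the minimal Python sketch below contrasts the two group layouts, assuming a hypothetical setup with world_size = 8 and expert_parallel_size = 4 (i.e. an external DP size of 2). The old stride-based grouping interleaved ranks from different external-DP replicas into the same MC2 group; the new reshape keeps each replica's EP ranks contiguous, matching the ExternalDP * EP layout described in the added comments.

# Standalone sketch with assumed example values (not from the commit):
# compare the old and new MC2 group-rank layouts.
import torch

world_size = 8               # hypothetical: 2 external-DP replicas x 4 EP ranks
expert_parallel_size = 4

# Old layout: stride-based grouping that ignores external DP.
num_expert_parallel_groups = world_size // expert_parallel_size
old_groups = [
    list(range(i, world_size, num_expert_parallel_groups))
    for i in range(num_expert_parallel_groups)
]
print(old_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]] -- mixes the two DP replicas

# New layout: ExternalDP * EP, one contiguous EP group per external-DP replica.
all_ranks = torch.arange(world_size).reshape(-1, expert_parallel_size)
new_groups = [x.tolist() for x in all_ranks.unbind(0)]
print(new_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]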

vllm_ascend/worker/worker.py

Lines changed: 1 addition & 2 deletions
@@ -546,8 +546,7 @@ def _init_worker_distributed_environment(
     ensure_model_parallel_initialized(
         parallel_config.tensor_parallel_size,
         parallel_config.pipeline_parallel_size)
-    init_ascend_model_parallel(parallel_config.expert_parallel_size,
-                               parallel_config.world_size_across_dp)
+    init_ascend_model_parallel(parallel_config.expert_parallel_size)
     ensure_kv_transfer_initialized(vllm_config)

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 2 deletions
@@ -261,8 +261,7 @@ def _init_worker_distributed_environment(self) -> None:
         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,
             self.parallel_config.pipeline_parallel_size)
-        init_ascend_model_parallel(self.parallel_config.expert_parallel_size,
-                                   self.parallel_config.world_size_across_dp)
+        init_ascend_model_parallel(self.parallel_config.expert_parallel_size)
         ensure_kv_transfer_initialized(self.vllm_config)
 
     def _init_profiler(self):
