diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 065e4d1124..181231ed23 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -19,22 +19,22 @@ def model_parallel_initialized():
 
 def init_ascend_model_parallel(
     expert_parallel_size: int = 1,
-    world_size: Optional[int] = None,
     backend: Optional[str] = None,
 ):
     if model_parallel_initialized():
         return
     assert torch.distributed.is_initialized()
-    world_size = world_size or torch.distributed.get_world_size()
+    world_size = torch.distributed.get_world_size()
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)
-    num_expert_parallel_groups = world_size // expert_parallel_size
+    # The layout of all ranks: ExternalDP * EP
+    # ExternalDP is the data parallel group that is not part of the model,
+    # every dp rank can generate independently (in verl integration).
+    all_ranks = torch.arange(world_size).reshape(-1, expert_parallel_size)
 
     global _MC2
-    group_ranks = []
-    for i in range(num_expert_parallel_groups):
-        ranks = list(range(i, world_size, num_expert_parallel_groups))
-        group_ranks.append(ranks)
+    group_ranks = all_ranks.unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
 
     _MC2 = init_model_parallel_group(group_ranks,
                                      get_world_group().local_rank,
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index d72613601b..f38ed75a0d 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -546,8 +546,7 @@ def _init_worker_distributed_environment(
     ensure_model_parallel_initialized(
         parallel_config.tensor_parallel_size,
         parallel_config.pipeline_parallel_size)
-    init_ascend_model_parallel(parallel_config.expert_parallel_size,
-                               parallel_config.world_size_across_dp)
+    init_ascend_model_parallel(parallel_config.expert_parallel_size)
     ensure_kv_transfer_initialized(vllm_config)
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index dc004f4b89..d8918abf9c 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -261,8 +261,7 @@ def _init_worker_distributed_environment(self) -> None:
         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,
             self.parallel_config.pipeline_parallel_size)
-        init_ascend_model_parallel(self.parallel_config.expert_parallel_size,
-                                   self.parallel_config.world_size_across_dp)
+        init_ascend_model_parallel(self.parallel_config.expert_parallel_size)
         ensure_kv_transfer_initialized(self.vllm_config)
 
     def _init_profiler(self):
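
The sketch below is not part of the patch; it only illustrates, with assumed example values (world_size = 8, expert_parallel_size = 4), what the new group construction computes compared with the strided scheme it replaces. The key behavioral change is that each MC2/EP group now becomes a contiguous block of ranks taken row by row from the (ExternalDP, EP) grid, instead of a strided slice across the whole world.

# Minimal sketch (not the vllm-ascend code itself): compare the new
# contiguous EP grouping with the old strided grouping.
# Assumes world_size is divisible by expert_parallel_size.
import torch

world_size = 8
expert_parallel_size = 4

# New layout: view ranks as an (ExternalDP, EP) grid; each row is one EP group.
all_ranks = torch.arange(world_size).reshape(-1, expert_parallel_size)
new_groups = [row.tolist() for row in all_ranks.unbind(0)]
print(new_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Old layout (removed by this diff): strided groups across the whole world.
num_expert_parallel_groups = world_size // expert_parallel_size
old_groups = [
    list(range(i, world_size, num_expert_parallel_groups))
    for i in range(num_expert_parallel_groups)
]
print(old_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]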