From aa3efaac46368065c47bd2c6a4e46eac12ff9c26 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 10 Jul 2025 08:07:08 -0400 Subject: [PATCH 1/6] Add pytorch symm memory communicator Signed-off-by: ilmarkov --- .../device_communicators/cuda_communicator.py | 15 +++ .../device_communicators/custom_all_reduce.py | 10 ++ .../device_communicators/symm_mem.py | 96 +++++++++++++++++++ vllm/envs.py | 5 + 4 files changed, 126 insertions(+) create mode 100644 vllm/distributed/device_communicators/symm_mem.py diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 4ab8f3d938fc..8a5a5cc25ae3 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -44,6 +44,8 @@ def __init__(self, PyNcclCommunicator) from vllm.distributed.device_communicators.quick_all_reduce import ( QuickAllReduce) + from vllm.distributed.device_communicators.symm_mem import ( + SymmMemCommunicator) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -54,6 +56,7 @@ def __init__(self, self.ca_comm: Optional[CustomAllreduce] = None self.qr_comm: Optional[QuickAllReduce] = None + self.symm_mem_comm: Optional[SymmMemCommunicator] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -69,6 +72,12 @@ def __init__(self, # currently be an MI300 series. self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) + if envs.VLLM_USE_SYMM_MEM and current_platform.is_cuda(): + self.symm_mem_comm = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) + if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -105,6 +114,12 @@ def all_reduce(self, input_): out = ca_comm.custom_all_reduce(input_) assert out is not None return out + symm_mem_comm = self.symm_mem_comm + if symm_mem_comm is not None and not symm_mem_comm.disabled and \ + symm_mem_comm.should_use_symm_mem(input_): + out = symm_mem_comm.all_reduce(input_) + assert out is not None + return out pynccl_comm = self.pynccl_comm assert pynccl_comm is not None out = pynccl_comm.all_reduce(input_) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 7dd104a4fcc4..704f3ae86314 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -49,6 +49,14 @@ def is_weak_contiguous(inp: torch.Tensor): class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + MB = 1024 * 1024 + # Max sizes for each world size in case symmetric memory is available + _MAX_SIZES = { + 2: MB, # 1 MB + 4: MB, # 1 MB + 6: MB // 2, # 512 KB + 8: MB // 2, # 512 KB + } # max_size: max supported allreduce size def __init__(self, @@ -109,6 +117,8 @@ def __init__(self, # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device + if current_platform.is_cuda() and envs.VLLM_USE_SYMM_MEM: + max_size = CustomAllreduce._MAX_SIZES[world_size] cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py new file mode 100644 index 000000000000..36ffbb9ed8d6 --- /dev/null +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.logger import init_logger + +try: + import torch.distributed._symmetric_memory as torch_symm_mem + + symm_mem_available = True +except ImportError: + symm_mem_available = False + +logger = init_logger(__name__) + + +class SymmMemCommunicator: + MB = 1024 * 1024 + # Max sizes for each world size + _MAX_SIZES = { + 2: 8 * MB, + 4: 32 * MB, + 6: 64 * MB, + 8: 256 * MB, + } + + def __init__(self, group: ProcessGroup, device: Union[int, str, + torch.device]): + self.disabled = True + + if not symm_mem_available: + return + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + torch.cuda.set_device(device) + self.dtype = torch.bfloat16 + self.device = device + self.group = group + self.world_size = dist.get_world_size(self.group) + if self.world_size not in self._MAX_SIZES: + logger.warning( + "SymmMemCommunicator: World size %d not supported, " + "communicator is not available.", + self.world_size, + ) + return + self.buffer = torch_symm_mem.empty( + self._MAX_SIZES[self.world_size] // self.dtype.itemsize, + device=self.device, + dtype=self.dtype, + ) + handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + if handle.multicast_ptr == 0: + logger.warning("SymmMemCommunicator: symmetric memory " + "multicast operations are not supported.") + return + self.disabled = False + + def should_use_symm_mem(self, inp: torch.Tensor): + if self.disabled: + return False + if inp.dtype != self.dtype: + return False + inp_size = inp.numel() * inp.element_size() + if inp_size % 4 != 0: + return False + return inp_size <= self._MAX_SIZES[self.world_size] + + def all_reduce( + self, + inp: torch.Tensor, + *, + out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + if not self.should_use_symm_mem(inp): + return None + if out is None: + out = torch.empty_like(inp) + self.buffer[:inp.numel()].copy_(inp.view(-1)) + if self.world_size in [2, 4]: + # Use two-shot all-reduce for 2 and 4 GPUs + torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + else: + # Use multi-mem all-reduce for 6 and 8 GPUs + torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + out.copy_(self.buffer[:inp.numel()].view(out.shape)) + return out diff --git a/vllm/envs.py b/vllm/envs.py index 212eaf015a83..9d5cf26bee09 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -155,6 +155,7 @@ VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False + VLLM_USE_SYMM_MEM: bool = False def get_default_cache_root(): @@ -1093,6 +1094,10 @@ def get_vllm_port() -> Optional[int]: # never removed from memory until the server terminates. "VLLM_ENABLE_RESPONSES_API_STORE": lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), + + # Whether to use pytorch symmetric memory for allreduce + "VLLM_USE_SYMM_MEM": + lambda: bool(int(os.getenv("VLLM_USE_SYMM_MEM", "0"))), } # --8<-- [end:env-vars-definition] From ff33a5ad5c125d6763ee1f9ced681efec4234673 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 10 Jul 2025 09:44:42 -0400 Subject: [PATCH 2/6] Update symm mem constants and funcs Signed-off-by: ilmarkov --- .../device_communicators/custom_all_reduce.py | 10 +++++----- vllm/distributed/device_communicators/symm_mem.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 704f3ae86314..08dec65e9801 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -49,13 +49,13 @@ def is_weak_contiguous(inp: torch.Tensor): class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] - MB = 1024 * 1024 + MiB = 1024 * 1024 # Max sizes for each world size in case symmetric memory is available _MAX_SIZES = { - 2: MB, # 1 MB - 4: MB, # 1 MB - 6: MB // 2, # 512 KB - 8: MB // 2, # 512 KB + 2: 2 * MiB, # 1 MB + 4: 2 * MiB, # 1 MB + 6: MiB, # 512 KB + 8: MiB // 2, # 512 KB } # max_size: max supported allreduce size diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index 36ffbb9ed8d6..f5ca45a7e880 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -19,13 +19,13 @@ class SymmMemCommunicator: - MB = 1024 * 1024 + MiB = 1024 * 1024 # Max sizes for each world size _MAX_SIZES = { - 2: 8 * MB, - 4: 32 * MB, - 6: 64 * MB, - 8: 256 * MB, + 2: 8 * MiB, + 4: 32 * MiB, + 6: 128 * MiB, + 8: 128 * MiB, } def __init__(self, group: ProcessGroup, device: Union[int, str, From e2e8e0ca13af6f7eb66e62c9d7468746342b72d0 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 10 Jul 2025 10:32:32 -0400 Subject: [PATCH 3/6] Rename env Signed-off-by: ilmarkov --- vllm/distributed/device_communicators/cuda_communicator.py | 2 +- vllm/distributed/device_communicators/custom_all_reduce.py | 2 +- vllm/envs.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 8a5a5cc25ae3..a05915a1c2ea 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -72,7 +72,7 @@ def __init__(self, # currently be an MI300 series. self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) - if envs.VLLM_USE_SYMM_MEM and current_platform.is_cuda(): + if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): self.symm_mem_comm = SymmMemCommunicator( group=self.cpu_group, device=self.device, diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 08dec65e9801..0d1f4a63cad0 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -117,7 +117,7 @@ def __init__(self, # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device - if current_platform.is_cuda() and envs.VLLM_USE_SYMM_MEM: + if current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM: max_size = CustomAllreduce._MAX_SIZES[world_size] cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES diff --git a/vllm/envs.py b/vllm/envs.py index 9d5cf26bee09..47f50981dfcc 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -155,7 +155,7 @@ VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False - VLLM_USE_SYMM_MEM: bool = False + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False def get_default_cache_root(): @@ -1096,8 +1096,8 @@ def get_vllm_port() -> Optional[int]: lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), # Whether to use pytorch symmetric memory for allreduce - "VLLM_USE_SYMM_MEM": - lambda: bool(int(os.getenv("VLLM_USE_SYMM_MEM", "0"))), + "VLLM_ALLREDUCE_USE_SYMM_MEM": + lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), } # --8<-- [end:env-vars-definition] From f3a267cf560d39c35fe737b12c2439a7fd7736df Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 10 Jul 2025 10:34:35 -0400 Subject: [PATCH 4/6] Upd Signed-off-by: ilmarkov --- vllm/distributed/device_communicators/symm_mem.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index f5ca45a7e880..be8349615692 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -7,6 +7,7 @@ from torch.distributed import ProcessGroup from vllm.logger import init_logger +from vllm.platforms import current_platform try: import torch.distributed._symmetric_memory as torch_symm_mem @@ -34,6 +35,11 @@ def __init__(self, group: ProcessGroup, device: Union[int, str, if not symm_mem_available: return + + if not current_platform.is_cuda(): + logger.warning("SymmMemCommunicator: symmetric " + "memory is not available.") + return if isinstance(device, int): device = torch.device(f"cuda:{device}") elif isinstance(device, str): From 0bf30025f61162300ec3302de1e54065026e035e Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 22 Jul 2025 05:50:04 -0400 Subject: [PATCH 5/6] Fixes after rebase Signed-off-by: ilmarkov --- tests/distributed/test_symm_mem_allreduce.py | 108 +++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 tests/distributed/test_symm_mem_allreduce.py diff --git a/tests/distributed/test_symm_mem_allreduce.py b/tests/distributed/test_symm_mem_allreduce.py new file mode 100644 index 000000000000..5a804a389123 --- /dev/null +++ b/tests/distributed/test_symm_mem_allreduce.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +import typing + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_group, + init_distributed_environment, + initialize_model_parallel) +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +torch.manual_seed(42) +random.seed(44) + +test_size_elements = 4 * 1024 * 1024 + + +def symm_mem_allreduce_worker(local_rank: int, world_size: int): + monkeypatch = pytest.MonkeyPatch() + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + dtype = torch.bfloat16 + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + cuda_communicator = typing.cast(CudaCommunicator, + get_tp_group().device_communicator) + symm_mem_comm = cuda_communicator.symm_mem_comm + if symm_mem_comm is None or symm_mem_comm.disabled: + pytest.skip("SymmMemCommunicator is not available or disabled.") + + inp_direct_symm_mem = torch.randint(1, + 23, (test_size_elements, ), + dtype=dtype, + device=device) + if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): + pytest.skip( + "SymmMemCommunicator isn't used for this world and input size." + ) + + original_inp_direct_symm_mem = inp_direct_symm_mem.clone() + out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) + assert out_direct_symm_mem is not None + + group = get_tensor_model_parallel_group().device_group + dist.all_reduce(original_inp_direct_symm_mem, group=group) + torch.testing.assert_close(out_direct_symm_mem, + original_inp_direct_symm_mem, + atol=2.5, + rtol=0.1) + + # Test tensor_model_parallel_all_reduce which should use symm_mem + inp_tensor_parallel = torch.randint(-23, + 1, (test_size_elements, ), + dtype=dtype, + device=device) + original_inp_tensor_parallel = inp_tensor_parallel.clone() + out_tensor_parallel = tensor_model_parallel_all_reduce( + inp_tensor_parallel) + dist.all_reduce(original_inp_tensor_parallel, group=group) + torch.testing.assert_close(out_tensor_parallel, + original_inp_tensor_parallel, + atol=2.5, + rtol=0.1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="SymmMemAllreduce is only available for CUDA platforms.") +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, + pipeline_parallel_size): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + # Enable SymmMemCommunicator + monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1") + + mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size) + cleanup_dist_env_and_memory() From f5b5f42ab248f4e39fc48f470beb4f61bff47013 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 22 Jul 2025 15:46:12 +0000 Subject: [PATCH 6/6] Add hopper support Signed-off-by: ilmarkov --- docs/design/multiprocessing.md | 2 +- tools/check_pickle_imports.py | 2 +- ...ll_reduce_utils.py => all_reduce_utils.py} | 33 ++++++++++++++++ .../device_communicators/cuda_communicator.py | 2 +- .../device_communicators/custom_all_reduce.py | 22 +++++------ .../device_communicators/symm_mem.py | 39 ++++++++++++------- 6 files changed, 69 insertions(+), 31 deletions(-) rename vllm/distributed/device_communicators/{custom_all_reduce_utils.py => all_reduce_utils.py} (93%) diff --git a/docs/design/multiprocessing.md b/docs/design/multiprocessing.md index 06ebd7725858..247072d1cb27 100644 --- a/docs/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`. There are other miscellaneous places hard-coding the use of `spawn`: -- +- - Related PRs: diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py index 5e99dc63ebe0..ca582a6aabcb 100644 --- a/tools/check_pickle_imports.py +++ b/tools/check_pickle_imports.py @@ -37,7 +37,7 @@ 'vllm/distributed/utils.py', 'vllm/distributed/parallel_state.py', 'vllm/engine/multiprocessing/client.py', - 'vllm/distributed/device_communicators/custom_all_reduce_utils.py', + 'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/shm_broadcast.py', 'vllm/engine/multiprocessing/engine.py', 'benchmarks/kernels/graph_machete_bench.py', diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py similarity index 93% rename from vllm/distributed/device_communicators/custom_all_reduce_utils.py rename to vllm/distributed/device_communicators/all_reduce_utils.py index 7c6001e87039..5c64e7d5c4ba 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -23,6 +23,39 @@ logger = init_logger(__name__) +MiB = 1024 * 1024 +# Max size for each world size in case symmetric memory is available +# For different SM architectures +CUSTOM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: MiB // 2, # 512 KB + 8: MiB // 4, # 256 KB + }, + "10.0": { + 2: 2 * MiB, # 2 MB + 4: 2 * MiB, # 2 MB + 6: 2 * MiB, # 2 MB + 8: 2 * MiB, # 2 MB + } +} + +SYMM_MEM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: 64 * MiB, # 64 MB + 8: 64 * MiB, # 64 MB + }, + "10.0": { + 2: 8 * MiB, # 8 MB + 4: 32 * MiB, # 32 MB + 6: 128 * MiB, # 128 MB + 8: 128 * MiB, # 128 MB + } +} + def producer(batch_src: Sequence[int], producer_queue, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index a05915a1c2ea..8c5c356da2d8 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -115,7 +115,7 @@ def all_reduce(self, input_): assert out is not None return out symm_mem_comm = self.symm_mem_comm - if symm_mem_comm is not None and not symm_mem_comm.disabled and \ + if symm_mem_comm is not None and \ symm_mem_comm.should_use_symm_mem(input_): out = symm_mem_comm.all_reduce(input_) assert out is not None diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 0d1f4a63cad0..533512e2636c 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -10,8 +10,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.device_communicators.custom_all_reduce_utils import ( - gpu_p2p_access_check) +from vllm.distributed.device_communicators.all_reduce_utils import ( + CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform @@ -49,14 +49,6 @@ def is_weak_contiguous(inp: torch.Tensor): class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] - MiB = 1024 * 1024 - # Max sizes for each world size in case symmetric memory is available - _MAX_SIZES = { - 2: 2 * MiB, # 1 MB - 4: 2 * MiB, # 1 MB - 6: MiB, # 512 KB - 8: MiB // 2, # 512 KB - } # max_size: max supported allreduce size def __init__(self, @@ -117,9 +109,13 @@ def __init__(self, # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device - if current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM: - max_size = CustomAllreduce._MAX_SIZES[world_size] - + device_capability = current_platform.get_device_capability( + ).as_version_str() + if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM + and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): + max_size = min( + CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], + max_size) cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index be8349615692..d907e1b833d0 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -6,6 +6,8 @@ import torch.distributed as dist from torch.distributed import ProcessGroup +from vllm.distributed.device_communicators.all_reduce_utils import ( + SYMM_MEM_ALL_REDUCE_MAX_SIZES) from vllm.logger import init_logger from vllm.platforms import current_platform @@ -20,13 +22,9 @@ class SymmMemCommunicator: - MiB = 1024 * 1024 - # Max sizes for each world size - _MAX_SIZES = { - 2: 8 * MiB, - 4: 32 * MiB, - 6: 128 * MiB, - 8: 128 * MiB, + _WORLD_SIZES_MULTIMEM = { + "9.0": [4, 6, 8], + "10.0": [6, 8], } def __init__(self, group: ProcessGroup, device: Union[int, str, @@ -49,15 +47,27 @@ def __init__(self, group: ProcessGroup, device: Union[int, str, self.device = device self.group = group self.world_size = dist.get_world_size(self.group) - if self.world_size not in self._MAX_SIZES: + self.device_capability = current_platform.get_device_capability( + ).as_version_str() + if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES: + logger.warning( + "SymmMemCommunicator: Device capability %s not supported, " + "communicator is not available.", + self.device_capability, + ) + return + if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[ + self.device_capability]: logger.warning( "SymmMemCommunicator: World size %d not supported, " "communicator is not available.", self.world_size, ) return + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ + self.world_size] self.buffer = torch_symm_mem.empty( - self._MAX_SIZES[self.world_size] // self.dtype.itemsize, + self.max_size // self.dtype.itemsize, device=self.device, dtype=self.dtype, ) @@ -76,7 +86,7 @@ def should_use_symm_mem(self, inp: torch.Tensor): inp_size = inp.numel() * inp.element_size() if inp_size % 4 != 0: return False - return inp_size <= self._MAX_SIZES[self.world_size] + return inp_size < self.max_size def all_reduce( self, @@ -88,14 +98,13 @@ def all_reduce( if out is None: out = torch.empty_like(inp) self.buffer[:inp.numel()].copy_(inp.view(-1)) - if self.world_size in [2, 4]: - # Use two-shot all-reduce for 2 and 4 GPUs - torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], + if self.world_size in self._WORLD_SIZES_MULTIMEM[ + self.device_capability]: + torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], "sum", self.group.group_name) else: - # Use multi-mem all-reduce for 6 and 8 GPUs - torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], + torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], "sum", self.group.group_name) out.copy_(self.buffer[:inp.numel()].view(out.shape))