Commit 434e8bf

refactor: use platform-agnostic device control for DP engine core
Refactor the DP engine core to use the platform-specific `device_control_env_var` attribute instead of hardcoding `CUDA_VISIBLE_DEVICES`. This change improves platform compatibility and code maintainability. The update includes:

1. Moving the `device_id_to_physical_device_id` function to a shared utils file
2. Updating imports in the CUDA and ROCm platform files
3. Replacing the CUDA-specific environment variable setting with a platform-agnostic approach in the `DPEngineCoreProc` class

This refactoring lets the codebase support additional platforms more cleanly.

Signed-off-by: Jade Zheng <[email protected]>
1 parent 93a126f commit 434e8bf
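
For context, a minimal sketch of the pattern this commit relies on: each platform class declares which environment variable controls device visibility, so generic code can read `current_platform.device_control_env_var` instead of hardcoding `CUDA_VISIBLE_DEVICES`. The class bodies and the non-CUDA platform below are illustrative assumptions, not the exact vLLM definitions.

# Illustrative sketch (not the exact vLLM source): platforms advertise
# their device-visibility variable via a class attribute.
class Platform:
    # CUDA-style default; subclasses override as needed.
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

class CudaPlatform(Platform):
    device_control_env_var = "CUDA_VISIBLE_DEVICES"

class AscendPlatform(Platform):  # hypothetical non-CUDA platform
    device_control_env_var = "ASCEND_RT_VISIBLE_DEVICES"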

4 files changed: +35, -39 lines

vllm/platforms/cuda.py (+5, -22)
@@ -18,6 +18,7 @@
 from vllm.utils import import_pynvml
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .utils import device_id_to_physical_device_id
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
@@ -37,24 +38,6 @@
     torch.backends.cuda.enable_cudnn_sdp(False)
 
 
-def device_id_to_physical_device_id(device_id: int) -> int:
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
-        if device_ids == [""]:
-            msg = (
-                "CUDA_VISIBLE_DEVICES is set to empty string, which means"
-                " GPU support is disabled. If you are using ray, please unset"
-                " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
-                " worker/actor. "
-                "Check https://github.com/vllm-project/vllm/issues/8402 for"
-                " more information.")
-            raise RuntimeError(msg)
-        physical_device_id = device_ids[device_id]
-        return int(physical_device_id)
-    else:
-        return device_id
-
-
 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 
     @wraps(fn)
@@ -328,7 +311,7 @@ def get_device_capability(cls,
                               device_id: int = 0
                               ) -> Optional[DeviceCapability]:
         try:
-            physical_device_id = device_id_to_physical_device_id(device_id)
+            physical_device_id = device_id_to_physical_device_id(device_id, cls.device_control_env_var)
             handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
             major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
             return DeviceCapability(major=major, minor=minor)
@@ -350,20 +333,20 @@ def has_device_capability(
     @classmethod
     @with_nvml_context
     def get_device_name(cls, device_id: int = 0) -> str:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = device_id_to_physical_device_id(device_id, cls.device_control_env_var)
         return cls._get_physical_device_name(physical_device_id)
 
     @classmethod
     @with_nvml_context
     def get_device_uuid(cls, device_id: int = 0) -> str:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = device_id_to_physical_device_id(device_id, cls.device_control_env_var)
         handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
         return pynvml.nvmlDeviceGetUUID(handle)
 
     @classmethod
     @with_nvml_context
     def get_device_total_memory(cls, device_id: int = 0) -> int:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = device_id_to_physical_device_id(device_id, cls.device_control_env_var)
         handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
         return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
 
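
The effect of the new second argument: when the process is restricted via `CUDA_VISIBLE_DEVICES`, logical device indices are translated back to physical indices before the NVML query. A small sketch of the mapping, assuming a machine where only devices 2 and 3 are exposed:

# Sketch: with CUDA_VISIBLE_DEVICES="2,3", logical device 0 is physical
# device 2, so pynvml.nvmlDeviceGetHandleByIndex receives index 2.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

from vllm.platforms.utils import device_id_to_physical_device_id

assert device_id_to_physical_device_id(0, "CUDA_VISIBLE_DEVICES") == 2
assert device_id_to_physical_device_id(1, "CUDA_VISIBLE_DEVICES") == 3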

vllm/platforms/rocm.py (+2, -10)
@@ -10,6 +10,7 @@
 from vllm.logger import init_logger
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
+from .utils import device_id_to_physical_device_id
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
@@ -89,15 +90,6 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def device_id_to_physical_device_id(device_id: int) -> int:
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
-        physical_device_id = device_ids[device_id]
-        return int(physical_device_id)
-    else:
-        return device_id
-
-
 def on_mi250_mi300() -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
@@ -223,7 +215,7 @@ def is_fully_connected(physical_device_ids: List[int]) -> bool:
     @with_amdsmi_context
     @lru_cache(maxsize=8)
     def get_device_name(cls, device_id: int = 0) -> str:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = device_id_to_physical_device_id(device_id, cls.device_control_env_var)
         handle = amdsmi_get_processor_handles()[physical_device_id]
         return amdsmi_get_gpu_asic_info(handle)["market_name"]
 
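
One behavioral note: the ROCm-local copy deleted above had no guard for an empty device list, so by switching to the shared helper ROCm also gains the empty-string check shown in `vllm/platforms/utils.py` below. A quick sketch, assuming ROCm's `device_control_env_var` is `CUDA_VISIBLE_DEVICES` as the deleted code suggests:

# Sketch: the shared helper raises a descriptive RuntimeError where the
# old ROCm-local copy would have attempted int("") and failed with a
# less helpful ValueError.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from vllm.platforms.utils import device_id_to_physical_device_id

try:
    device_id_to_physical_device_id(0, "CUDA_VISIBLE_DEVICES")
except RuntimeError as exc:
    print(exc)  # points at issue #8402 and how to unset the variable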

vllm/platforms/utils.py (+21, new file)
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+
+def device_id_to_physical_device_id(device_id: int,
+                                    device_control_env_var: str) -> int:
+    if device_control_env_var in os.environ:
+        device_ids = os.environ[device_control_env_var].split(",")
+        if device_ids == [""]:
+            msg = (
+                f"{device_control_env_var} is set to empty string, which means"
+                " current platform support is disabled. If you are using ray,"
+                f" please unset the environment variable `{device_control_env_var}`"
+                " inside the worker/actor. "
+                "Check https://github.com/vllm-project/vllm/issues/8402 for"
+                " more information.")
+            raise RuntimeError(msg)
+        physical_device_id = device_ids[device_id]
+        return int(physical_device_id)
+    else:
+        return device_id
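
Because the helper now takes the variable name as a parameter, the same function serves any platform. A usage sketch with a made-up control variable ("NPU_VISIBLE_DEVICES" is purely illustrative):

# Usage sketch; "NPU_VISIBLE_DEVICES" is a hypothetical variable name.
import os

from vllm.platforms.utils import device_id_to_physical_device_id

os.environ["NPU_VISIBLE_DEVICES"] = "4,5,6,7"
assert device_id_to_physical_device_id(2, "NPU_VISIBLE_DEVICES") == 6

# When the variable is unset, logical and physical ids coincide.
del os.environ["NPU_VISIBLE_DEVICES"]
assert device_id_to_physical_device_id(2, "NPU_VISIBLE_DEVICES") == 2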

vllm/v1/engine/core.py (+7, -7)
@@ -592,13 +592,13 @@ def __init__(
         assert 0 <= local_dp_rank <= dp_rank < dp_size
 
         from vllm.platforms import current_platform
-        if current_platform.is_cuda_alike():
-            from vllm.platforms.cuda import device_id_to_physical_device_id
-            tp_size = vllm_config.parallel_config.tensor_parallel_size
-            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
-                str(device_id_to_physical_device_id(i))
-                for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
-                               tp_size))
+        from vllm.platforms.utils import device_id_to_physical_device_id
+        device_control_env_var = current_platform.device_control_env_var
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+        os.environ[device_control_env_var] = ",".join(
+            str(device_id_to_physical_device_id(i, device_control_env_var))
+            for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
+                           tp_size))
 
         self.local_dp_rank = local_dp_rank
         self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
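
The generator expression above assigns each data-parallel engine a contiguous block of `tp_size` logical devices, which the helper then remaps to physical ids before re-exporting the control variable. A worked sketch of the arithmetic (values are illustrative):

# Sketch: with tensor_parallel_size=2, local DP rank 1 claims logical
# devices 2 and 3, since range(1 * 2, (1 + 1) * 2) == range(2, 4).
tp_size = 2
local_dp_rank = 1
logical_ids = list(range(local_dp_rank * tp_size, (local_dp_rank + 1) * tp_size))
assert logical_ids == [2, 3]

# If the parent process was itself restricted, e.g. with the control
# variable set to "4,5,6,7", the helper remaps logical 2 and 3 to
# physical "6,7" before the variable is re-exported for this engine.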
