Commit 8145cea

Merge branch 'vllm-project:main' into feat/command-tool-parser
2 parents d9eb8e1 + 0b40747

9 files changed: +43 -4 lines changed

vllm/platforms/cpu.py

Lines changed: 7 additions & 0 deletions
@@ -75,6 +75,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         return psutil.virtual_memory().total
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.cpu.set_device(device)
+
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
         return False

vllm/platforms/cuda.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def set_device(cls, device: torch.device) -> None:
         """
         Set the device for the current platform.
         """
-        super().set_device(device)
+        torch.cuda.set_device(device)
         # With this trick we can force the device to be set eagerly
         # see https://github.com/pytorch/pytorch/issues/155668
         # for why and when it is needed

vllm/platforms/hpu.py

Lines changed: 7 additions & 0 deletions
@@ -45,6 +45,13 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
     def inference_mode(cls):
         return torch.no_grad()
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.hpu.set_device(device)
+
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

vllm/platforms/interface.py

Lines changed: 1 addition & 1 deletion
@@ -305,7 +305,7 @@ def set_device(cls, device: torch.device) -> None:
         """
         Set the device for the current platform.
        """
-        torch.cuda.set_device(device)
+        raise NotImplementedError
 
     @classmethod
     def pre_register_and_update(cls,
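Taken together with the per-platform hunks above, this makes Platform.set_device an abstract hook: the base class now raises NotImplementedError and each platform overrides it with its own torch namespace, which is also why cuda.py above stops calling super().set_device(device). A minimal sketch of the pattern, with hypothetical class names (only set_device and the torch calls come from the diff):

import torch


class Platform:
    # Base interface: concrete platforms must override set_device.
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        raise NotImplementedError


class CudaPlatformSketch(Platform):
    # Mirrors the CUDA/ROCm overrides: delegate to torch.cuda.
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        torch.cuda.set_device(device)


class CpuPlatformSketch(Platform):
    # Mirrors the CPU override: delegate to torch.cpu (a no-op on CPU).
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        torch.cpu.set_device(device)


# Callers pick one platform class and dispatch through it,
# never hard-coding a torch backend namespace:
platform = CpuPlatformSketch
platform.set_device(torch.device("cpu"))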

vllm/platforms/rocm.py

Lines changed: 11 additions & 0 deletions
@@ -241,6 +241,17 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         logger.info("Using ROCmFlashAttention backend.")
         return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.cuda.set_device(device)
+        # With this trick we can force the device to be set eagerly
+        # see https://github.com/pytorch/pytorch/issues/155668
+        # for why and when it is needed
+        _ = torch.zeros(1, device=device)
+
     @classmethod
     @lru_cache(maxsize=8)
     def get_device_capability(cls,
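The ROCm override also copies the eager-initialization trick from the CUDA path: torch.cuda.set_device on its own can leave the device context uncreated until first use (see the linked pytorch/pytorch#155668), so allocating a one-element tensor forces the context to come up immediately. A standalone sketch of the same trick, guarded so it only runs where a CUDA/ROCm device is actually available:

import torch


def set_device_eagerly(device: torch.device) -> None:
    # Record the target device for the current thread...
    torch.cuda.set_device(device)
    # ...then touch it with a tiny allocation so the device context
    # is created now instead of lazily on the first kernel launch
    # (see https://github.com/pytorch/pytorch/issues/155668).
    _ = torch.zeros(1, device=device)


if torch.cuda.is_available():
    set_device_eagerly(torch.device("cuda:0"))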

vllm/platforms/tpu.py

Lines changed: 7 additions & 0 deletions
@@ -55,6 +55,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
         logger.info("Using Pallas V1 backend.")
         return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.tpu.set_device(device)
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         chip_type, _ = device.get_local_chips()

vllm/platforms/xpu.py

Lines changed: 7 additions & 0 deletions
@@ -45,6 +45,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
         logger.info("Using Flash Attention backend on V1 engine.")
         return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
 
+    @classmethod
+    def set_device(cls, device: torch.device) -> None:
+        """
+        Set the device for the current platform.
+        """
+        torch.xpu.set_device(device)
+
     @classmethod
     def get_device_capability(
         cls,

vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ def init_device(self):
         # This env var set by Ray causes exceptions with graph building.
         os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
         self.device = torch.device(f"cuda:{self.local_rank}")
-        torch.cuda.set_device(self.device)
+        current_platform.set_device(self.device)
 
         _check_if_gpu_supports_dtype(self.model_config.dtype)
         gc.collect()
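With this change the worker no longer hard-codes torch.cuda.set_device; it dispatches through current_platform, so the same init_device flow works on any backend that implements the hook. A sketch of the resulting call-site shape, assuming the platform exposes a device_type string as vLLM's Platform classes do (the local_rank value here is illustrative):

import torch
from vllm.platforms import current_platform

local_rank = 0  # illustrative; the real worker reads this from its config
device = torch.device(f"{current_platform.device_type}:{local_rank}")
current_platform.set_device(device)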

vllm/v1/worker/xpu_worker.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def init_device(self):
         if self.device_config.device.type == "xpu" and current_platform.is_xpu(
         ):
             self.device = torch.device(f"xpu:{self.local_rank}")
-            torch.xpu.set_device(self.device)
+            current_platform.set_device(self.device)
             torch.xpu.empty_cache()
             self.init_gpu_memory = torch.xpu.get_device_properties(
                 self.local_rank).total_memory
