Skip to content

Commit 05bdcbe

Browse files
fakeYanyiz-liu
andauthored
support aclgraph (#426)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> This PR supports the access of vllm-acend to the piecewise_graph feature provided by the v1 engine. 1. register unifiled_ascend_attention_with_output for piecewise_graph to split graph. 2. support NPUGraph to accelerate kernel launch. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> support npugraph to default, Users can disenable the npugraph feature by configuring enforce_eager. This has corresponding requirements for the versions of torch_npu and CANN, and they need to support graph capture. ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> it turn to default --------- Signed-off-by: Bug Hunter Yan <[email protected]> Signed-off-by: Yizhou Liu <[email protected]> Co-authored-by: Yizhou Liu <[email protected]>
1 parent 5c6d05a commit 05bdcbe

15 files changed

+447
-112
lines changed

.github/workflows/vllm_ascend_test.yaml

+8-6
Original file line numberDiff line numberDiff line change
@@ -115,24 +115,26 @@ jobs:
115115
- name: Install vllm-project/vllm-ascend
116116
run: |
117117
pip install -r requirements-dev.txt
118-
pip install -e .
118+
pip install -v --no-build-isolation -e .
119119
120-
- name: Run vllm-project/vllm-ascend test on V0 engine
120+
- name: Run vllm-project/vllm-ascend test for V1 Engine
121121
env:
122-
VLLM_USE_V1: 0
122+
VLLM_USE_V1: 1
123+
VLLM_WORKER_MULTIPROC_METHOD: spawn
123124
run: |
124125
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
125126
pytest -sv tests/singlecard/test_offline_inference.py
126127
pytest -sv tests/ops
128+
pytest -sv tests/compile
127129
else
128130
pytest -sv tests/multicard/test_offline_inference_distributed.py
129131
pytest -sv tests/ops
132+
pytest -sv tests/compile
130133
fi
131134
132-
- name: Run vllm-project/vllm-ascend test for V1 Engine
135+
- name: Run vllm-project/vllm-ascend test on V0 engine
133136
env:
134-
VLLM_USE_V1: 1
135-
VLLM_WORKER_MULTIPROC_METHOD: spawn
137+
VLLM_USE_V1: 0
136138
run: |
137139
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
138140
pytest -sv tests/singlecard/test_offline_inference.py

csrc/ops.h

+18-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <vector>
2323
#include "kernels/types.h"
24+
#include "torch_npu/csrc/aten/common/from_blob.h"
2425

2526
namespace vllm_ascend {
2627
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
@@ -29,4 +30,20 @@ namespace vllm_ascend {
2930
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
3031
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
3132
uint32_t aivNum);
32-
}
33+
34+
torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
35+
if (!tensor.is_privateuseone()) {
36+
throw std::runtime_error("Tensor must be on NPU device");
37+
}
38+
// Get the raw data pointer
39+
void* data_ptr = tensor.data_ptr();
40+
// Get tensor sizes and strides
41+
std::vector<int64_t> sizes = tensor.sizes().vec();
42+
std::vector<int64_t> strides = tensor.strides().vec();
43+
// Get tensor options (dtype, device)
44+
auto options = tensor.options();
45+
// Create a new tensor from the raw data pointer
46+
auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
47+
return new_tensor;
48+
}
49+
}

csrc/torch_binding.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
103103
TORCH_LIBRARY_EXPAND(_C, ops)
104104
{
105105
// vLLM-Ascend custom ops
106+
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
107+
ops.impl("weak_ref_tensor", torch::kPrivateUse1, &vllm_ascend::weak_ref_tensor);
106108

107109
// Rotary embedding
108110
// Apply GPT-NeoX style rotary embedding to query and key.

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ requires = [
1111
"scipy",
1212
"setuptools>=64",
1313
"setuptools-scm>=8",
14-
"torch_npu",
15-
"torch >= 2.5.1",
14+
"torch_npu==2.5.1rc1",
15+
"torch>=2.5.1",
1616
"torchvision<0.21.0",
1717
]
1818
build-backend = "setuptools.build_meta"

requirements-dev.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
-r requirements-lint.txt
2+
-r requirements.txt
23
modelscope
34
pytest >= 6.0
45
pytest-asyncio

requirements.txt

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ cmake>=3.26
33
decorator
44
numpy<2.0.0
55
packaging
6+
pip
67
pybind11
78
pyyaml
89
scipy
910
setuptools>=64
1011
setuptools-scm>=8
11-
torch_npu
12-
torch >= 2.5.1
12+
torch>=2.5.1
1313
torchvision<0.21.0
14+
wheel

tests/compile/__init__.py

Whitespace-only changes.

tests/compile/test_simple.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""
3+
Test the piecewise compilation with a simple model so that we
4+
can exactly calculate the expected output and side effects.
5+
"""
6+
7+
import pytest
8+
import torch
9+
from torch import nn
10+
from torch.library import Library
11+
from vllm.compilation.counter import compilation_counter
12+
from vllm.compilation.decorators import support_torch_compile
13+
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
14+
set_current_vllm_config)
15+
from vllm.utils import direct_register_custom_op
16+
17+
global_counter = 0
18+
19+
# create a library to hold the custom op
20+
silly_lib = Library("silly", "FRAGMENT") # noqa
21+
22+
23+
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
24+
out: torch.Tensor) -> None:
25+
global global_counter
26+
global_counter += 1
27+
print(f"{global_counter=}")
28+
out.copy_(q)
29+
out[0] += 1
30+
31+
32+
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
33+
out: torch.Tensor) -> None:
34+
return
35+
36+
37+
direct_register_custom_op(
38+
op_name="attention",
39+
op_func=silly_attention,
40+
mutates_args=["out"],
41+
fake_impl=silly_attention_fake,
42+
dispatch_key="PrivateUse1",
43+
target_lib=silly_lib,
44+
)
45+
46+
47+
@support_torch_compile
48+
class SillyModel(nn.Module):
49+
50+
def __init__(self,
51+
*,
52+
vllm_config: VllmConfig,
53+
prefix: str = "",
54+
**kwargs) -> None:
55+
super().__init__()
56+
57+
def forward(self, x: torch.Tensor) -> torch.Tensor:
58+
"""
59+
Overall effect:
60+
x += 1
61+
x[0] += 2
62+
global_counter += 2
63+
"""
64+
x = x + 1
65+
x = x + 2
66+
out = torch.empty_like(x)
67+
torch.ops.silly.attention(x, x, x, out)
68+
x = out
69+
x = x - 2
70+
x = x - 1
71+
out = torch.empty_like(x)
72+
torch.ops.silly.attention(x, x, x, out)
73+
x = out
74+
x = x + 1
75+
return x
76+
77+
78+
@pytest.mark.skipif(True, reason="requires unreleased components")
79+
def test_simple_piecewise_compile():
80+
81+
vllm_config = VllmConfig(compilation_config=CompilationConfig(
82+
level=CompilationLevel.PIECEWISE,
83+
use_inductor=False,
84+
use_cudagraph=True,
85+
splitting_ops=["silly.attention"],
86+
cudagraph_copy_inputs=True,
87+
cudagraph_capture_sizes=[1, 2],
88+
))
89+
vllm_config.compilation_config.pass_config.enable_fusion = False
90+
with set_current_vllm_config(vllm_config):
91+
model = SillyModel(vllm_config=vllm_config, prefix="")
92+
93+
inputs = torch.randn(100).npu()
94+
95+
with compilation_counter.expect(
96+
num_graphs_seen=1, # one graph for the model
97+
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
98+
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
99+
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
100+
num_cudagraph_caputured=
101+
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
102+
):
103+
104+
model(inputs)
105+
106+
model(torch.randn(2).npu())
107+
model(torch.randn(1).npu())
108+
109+
input = torch.zeros(2).npu()
110+
global global_counter
111+
global_counter = 0
112+
output = model(input)
113+
assert global_counter == 2
114+
assert torch.allclose(output.cpu(), torch.tensor([3.0, 1.0]))
115+
116+
117+
if __name__ == "__main__":
118+
test_simple_piecewise_compile()

tests/multicard/test_offline_inference_distributed.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def test_models_distributed(model: str,
4747
dtype=dtype,
4848
tensor_parallel_size=4,
4949
distributed_executor_backend=distributed_executor_backend,
50+
enforce_eager=True,
5051
) as vllm_model:
5152
vllm_model.generate_greedy(example_prompts, max_tokens)
5253

tests/singlecard/test_offline_inference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
5050
with VllmRunner(model,
5151
max_model_len=8192,
5252
dtype=dtype,
53-
enforce_eager=False,
53+
enforce_eager=True,
5454
gpu_memory_utilization=0.7) as vllm_model:
5555
vllm_model.generate_greedy(example_prompts, max_tokens)
5656

0 commit comments

Comments
 (0)