
Commit 3386e09

ut: add ut for qwen2_vl.py (#2096)
### What this PR does / why we need it?
Add unit tests for qwen2_vl.py.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Not involved.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@555e722

Signed-off-by: Ronald1995 <[email protected]>
1 parent 936df1c commit 3386e09

File tree

1 file changed: +200 -0 lines changed


tests/ut/models/test_qwen2_vl.py

Lines changed: 200 additions & 0 deletions
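Assuming the repository's standard pytest-based unit-test setup, the new tests can presumably be run locally with `pytest tests/ut/models/test_qwen2_vl.py`; the exact invocation is not specified in this PR.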
@@ -0,0 +1,200 @@
import pytest
import torch
from pytest_mock import MockerFixture
from vllm.model_executor.layers.activation import QuickGELU

from tests.ut.base import PytestBase
from vllm_ascend.models.qwen2_vl import (AscendQwen2VisionAttention,
                                         AscendQwen2VisionBlock)

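# TestAscendQwen2VisionAttention exercises the Ascend-specific vision attention
# wrapper in isolation: init_attention() patches the parent
# Qwen2VisionAttention.__init__ so the module can be constructed without real
# weights or a distributed setup, then asserts the arguments forwarded to that
# parent constructor.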
class TestAscendQwen2VisionAttention(PytestBase):

    def init_attention(
        self,
        mocker,
        embed_dim=1000,
        num_heads=10,
        projection_size=100,
        quant_config=None,
        prefix="",
    ):
        mocker_attn = mocker.patch(
            "vllm_ascend.models.qwen2_vl.Qwen2VisionAttention.__init__")

        attention = AscendQwen2VisionAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            projection_size=projection_size,
            quant_config=quant_config,
            prefix=prefix,
        )
        args, kwargs = mocker_attn.call_args
        assert args == (embed_dim, num_heads, projection_size, None, "")
        assert not kwargs
        attention.num_attention_heads_per_partition = num_heads
        return attention

    def test_attn_init_should_normal(self, mocker: MockerFixture):
        embed_dim = 1000
        num_heads = 10
        projection_size = 100
        quant_config = None
        prefix = ""
        vit = self.init_attention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            projection_size=projection_size,
            quant_config=quant_config,
            prefix=prefix,
            mocker=mocker,
        )
        assert vit.hidden_size_per_attention_head == 10

    def test_attn_init_should_raise_error(self, mocker: MockerFixture):
        embed_dim = 1000
        num_heads = 7
        projection_size = 100
        quant_config = None
        prefix = ""
        with pytest.raises(AssertionError):
            # projection_size must be divisible by num_heads
            self.init_attention(
                mocker=mocker,
                embed_dim=embed_dim,
                num_heads=num_heads,
                projection_size=projection_size,
                quant_config=quant_config,
                prefix=prefix,
            )

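    # test_attn_forward stubs out qkv/split_qkv/proj and the torch_npu kernels
    # (npu_rotary_mul, _npu_flash_attention_unpad) with lightweight lambdas,
    # then checks the arguments each mock received and the final output shape
    # (seq=100, batch=3, hidden=10 heads * 128 head_dim = 1280).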
    def test_attn_forward(self, mocker: MockerFixture):
        attention = self.init_attention(mocker=mocker)
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        x = torch.rand((100, 3, 10 * 3 * 128))  # s, b, head * 3 * head_dim
        cu_seqlens = torch.tensor([10, 50, 100])
        cos = torch.rand((1, 100, 1, 128))
        sin = torch.rand((1, 100, 1, 128))

        qkv = lambda x: (x, 0)  # noqa
        split_qkv = lambda x: [  # noqa
            torch.rand((100, 3, 10, 128)) for i in range(3)
        ]  # noqa
        npu_rotary_mul = lambda q, cos, sin: q  # noqa
        _npu_flash_attention_unpad = lambda **kwargs: kwargs["out"]  # noqa
        proj = lambda x: (x, 0)  # noqa

        mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv)
        mocker_split_qkv = mocker.patch.object(
            attention,
            "split_qkv",
            side_effect=split_qkv,
        )
        mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul",
                                             side_effect=npu_rotary_mul)
        mocker_npu_flash_attention_unpad = mocker.patch(
            "torch_npu._npu_flash_attention_unpad",
            side_effect=_npu_flash_attention_unpad,
        )
        mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj)
        attention.__dict__["qkv"] = mocker_qkv
        attention.__dict__["split_qkv"] = mocker_split_qkv
        attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul
        attention.__dict__["_npu_flash_attention_unpad"] = (
            mocker_npu_flash_attention_unpad)
        attention.__dict__["proj"] = mocker_proj

        output = attention.forward(
            x=x,
            cu_seqlens=cu_seqlens,
            cos=cos,
            sin=sin,
        )
        qkv_args, qkv_kwargs = mocker_qkv.call_args
        assert qkv_args == (x, )
        assert not qkv_kwargs

        split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args
        assert split_qkv_args == (x, )
        assert not split_qkv_kwargs

        npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args
        assert npu_rotary_mul_args[1:] == (cos, sin)
        assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128])
        assert not npu_rotary_mul_kwargs

        assert output.shape == torch.Size([100, 3, 1280])


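# TestAscendQwen2VisionBlock builds the block with both parent constructors
# (Qwen2VisionBlock and AscendQwen2VisionAttention) patched out, and
# init_vision_block() verifies the keyword arguments used when the Ascend
# attention submodule is created.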
class TestAscendQwen2VisionBlock(PytestBase):

    def init_vision_block(
        self,
        mocker,
        dim=100,
        num_heads=10,
        mlp_ratio=0.5,
    ):
        mocker_vit = mocker.patch(
            "vllm.model_executor.models.qwen2_vl.Qwen2VisionBlock.__init__",
            return_value=None,
        )

        mocker_attn = mocker.patch(
            "vllm_ascend.models.qwen2_vl.AscendQwen2VisionAttention.__init__",
            return_value=None,
        )

        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        vision_block = AscendQwen2VisionBlock(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
        )
        args, kwargs = mocker_vit.call_args
        assert args == (dim, num_heads, mlp_ratio, QuickGELU, None, None, "")
        assert not kwargs

        args1, kwargs1 = mocker_attn.call_args
        assert not args1
        assert kwargs1 == {
            "embed_dim": dim,
            "num_heads": num_heads,
            "projection_size": dim,
            "quant_config": None,
            "prefix": ".attn",
        }
        return vision_block

    def test_init_vision_block_should_normal(
        self,
        mocker: MockerFixture,
    ):
        vision_block = self.init_vision_block(mocker)
        assert isinstance(vision_block, AscendQwen2VisionBlock)

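    # test_vision_block_forward replaces attn and mlp with mocks that return
    # the original x unchanged; the residual additions in the block's forward
    # then presumably accumulate to x + x + x, which the final `x * 3` assert
    # verifies.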
    def test_vision_block_forward(self, mocker: MockerFixture):
        x = torch.randint(1, 100, (100, 3, 1280))  # s, b, d
        cu_seqlens = torch.tensor([10, 50, 100])
        cos = torch.rand((1, 100, 1, 128))
        sin = torch.rand((1, 100, 1, 128))
        vision_block = self.init_vision_block(mocker)
        mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x)
        mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x)
        vision_block.__dict__["attn"] = mocker_attn
        vision_block.__dict__["mlp"] = mocker_mlp

        output = vision_block.forward(x.clone(), cu_seqlens, cos, sin)

        _, attn_kwargs = mocker_attn.call_args
        assert attn_kwargs == {
            "cu_seqlens": cu_seqlens,
            "cos": cos,
            "sin": sin,
        }

        assert torch.all(x * 3 == output)
