Commit cf0022c

fix the unittest
Signed-off-by: ganyi <[email protected]>
1 parent: 0baa807

File tree: 1 file changed (+25, -6 lines)

tests/e2e/singlecard/ops/test_rotary_embedding.py

@@ -233,8 +233,18 @@ def forward(
         offsets: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
-        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
-        query, key = self.rope.forward_native(positions, q, k, offsets)
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.chunk(3, dim=-1)
+        query, key = torch.ops._C.rotary_embedding(
+            positions,
+            q,
+            k,
+            self.rope.head_size,
+            self.rope.cos_sin_cache,
+            self.rope.is_neox_style,
+        )
+        query = query.view(q.shape)
+        key = key.view(k.shape)
         o = self.o_proj(query)
         return o
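Note on this hunk: `self.rope.forward_native` is pure PyTorch, so Dynamo would trace it into primitive aten ops and the later string match on "_C.rotary_embedding" would never fire. Calling the registered custom op directly keeps it as a single node in the captured graph. Below is a minimal, self-contained sketch of that behavior, assuming PyTorch >= 2.4 for `torch.library.custom_op`; the op name `demo_lib::toy_rope` and its body are hypothetical stand-ins, not part of vLLM or torch_npu.

import torch

@torch.library.custom_op("demo_lib::toy_rope", mutates_args=())
def toy_rope(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Stand-in kernel body; a real op would apply the rotary transform here.
    return q + k

@toy_rope.register_fake
def _(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation rule used while tracing without real data.
    return torch.empty_like(q)

def checking_backend(gm: torch.fx.GraphModule, example_inputs):
    # Same trick the test below uses: a registered custom op survives
    # tracing as one call_function node, so its name is in the graph text.
    assert "demo_lib.toy_rope" in str(gm.graph)
    return gm

compiled = torch.compile(lambda q, k: toy_rope(q, k), backend=checking_backend)
compiled(torch.randn(4, 8), torch.randn(4, 8))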

@@ -257,14 +267,16 @@ def test_capture_rotary_embedding_in_aclgraph(
     dtype: torch.dtype,
     seed: int,
     device: str,
-    max_position_embeddings: int,
-    base: int,
+    max_position_embeddings: int = 8192,
+    base: int = 10000,
 ):
     """Test if the rotary embedding can be captured in aclgraph."""
     torch.manual_seed(seed)
     torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
     model = ModelwithRotaryEmbedding(
-        hidden_size=num_tokens,
+        hidden_size=num_heads * head_size,
         num_heads=num_heads,
         head_size=head_size,
         rotary_dim=rotary_dim,
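Note on this hunk: `hidden_size=num_tokens` was a bug that coupled the model's feature width to the batch length; the fused QKV projection only chunks into per-head-consistent q/k/v when hidden_size equals num_heads * head_size. A standalone shape check of that invariant (the dimensions 16/8/64 are illustrative, not the test's parametrization):

import torch

num_tokens, num_heads, head_size = 16, 8, 64
hidden_size = num_heads * head_size  # 512; the invariant the fix restores
qkv_proj = torch.nn.Linear(hidden_size, 3 * hidden_size)
qkv = qkv_proj(torch.randn(num_tokens, hidden_size))
q, k, v = qkv.chunk(3, dim=-1)
assert q.shape == (num_tokens, num_heads * head_size)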
@@ -274,13 +286,20 @@ def test_capture_rotary_embedding_in_aclgraph(
         dtype=dtype,
     )
 
+    def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
+        # Validate if the rotary_embedding custom kernel is indeed inside
+        # the graph by string match
+        graph = str(gm.graph)
+        assert "_C.rotary_embedding" in graph
+        return gm
+
     static_positions = torch.randint(0, max_position_embeddings,
                                      (num_tokens, ))
     static_hidden_states = torch.randn(num_tokens,
                                        num_heads * head_size,
                                        dtype=dtype,
                                        device="npu")
-    compiled_model = torch.compile(model)
+    compiled_model = torch.compile(model, backend=custom_op_checking_backend)
     stream = torch.npu.Stream()
     stream.wait_stream(torch.npu.current_stream())
     with torch.npu.stream(stream):
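The tail of the diff is the standard warm-up-on-a-side-stream pattern that precedes graph capture (analogous to CUDA graphs). A hedged sketch of how the truncated block plausibly continues, reusing `compiled_model`, `static_positions`, and `static_hidden_states` from the test body above and assuming torch_npu mirrors torch.cuda's stream API; the warm-up count and the capture step itself are assumptions, not taken from this commit:

import torch
import torch_npu  # noqa: F401  # provides the torch.npu device module

stream = torch.npu.Stream()
stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(stream):
    # Warm-up iterations on a side stream so lazy initialization
    # (compilation, allocator pools) does not happen during capture.
    for _ in range(3):
        out = compiled_model(static_positions, static_hidden_states)
torch.npu.current_stream().wait_stream(stream)
# Graph capture/replay would follow here, via torch_npu's graph capture
# API if the installed build mirrors torch.cuda's CUDAGraph interface.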
