@@ -233,8 +233,18 @@ def forward(
         offsets: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # we simulate a simple attention layer to test whether it can be seamlessly captured into aclgraph
-        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
-        query, key = self.rope.forward_native(positions, q, k, offsets)
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.chunk(3, dim=-1)
+        query, key = torch.ops._C.rotary_embedding(
+            positions,
+            q,
+            k,
+            self.rope.head_size,
+            self.rope.cos_sin_cache,
+            self.rope.is_neox_style,
+        )
+        query = query.view(q.shape)
+        key = key.view(k.shape)
         o = self.o_proj(query)
         return o
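For context, the change swaps the Python-level `forward_native` for the registered `torch.ops._C.rotary_embedding` custom op, so the kernel shows up as a single node in the captured graph. Below is a minimal eager sketch of the neox-style math such a kernel applies to q and k, assuming `cos_sin_cache` stores the cos half and sin half (each `rotary_dim // 2` wide) per position; the function name and layout assumption are illustrative, not the kernel's actual internals:

import torch

def neox_rope_reference(positions, x, cos_sin_cache, head_size, rotary_dim):
    # x: [num_tokens, num_heads * head_size]; positions: [num_tokens] (long)
    num_tokens = x.shape[0]
    # assumed cache layout: [cos | sin], each half rotary_dim // 2 wide
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos = cos.unsqueeze(1)  # broadcast over the head dimension
    sin = sin.unsqueeze(1)
    x = x.view(num_tokens, -1, head_size)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    # rotate the first rotary_dim channels, pass the rest through unchanged
    x_rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((x_rot, x_pass), dim=-1).view(num_tokens, -1)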
@@ -257,14 +267,16 @@ def test_capture_rotary_embedding_in_aclgraph(
     dtype: torch.dtype,
     seed: int,
     device: str,
-    max_position_embeddings: int,
-    base: int,
+    max_position_embeddings: int = 8192,
+    base: int = 10000,
 ):
     """Test if the rotary embedding can be captured in aclgraph."""
     torch.manual_seed(seed)
     torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
     model = ModelwithRotaryEmbedding(
-        hidden_size=num_tokens,
+        hidden_size=num_heads * head_size,
         num_heads=num_heads,
         head_size=head_size,
         rotary_dim=rotary_dim,
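The `hidden_size` fix matters because the model's qkv projection expects its input feature dimension to be `num_heads * head_size`, which `num_tokens` only matched by coincidence. A quick shape check with arbitrary illustrative values:

import torch

num_tokens, num_heads, head_size = 16, 8, 64
hidden_size = num_heads * head_size  # 512: the model width, independent of num_tokens
qkv_proj = torch.nn.Linear(hidden_size, 3 * hidden_size)
hidden_states = torch.randn(num_tokens, hidden_size)
q, k, v = qkv_proj(hidden_states).chunk(3, dim=-1)
assert q.shape == (num_tokens, hidden_size)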
@@ -274,13 +286,20 @@ def test_capture_rotary_embedding_in_aclgraph(
         dtype=dtype,
     )

+    def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
+        # Validate that the rotary_embedding custom kernel is indeed in the
+        # captured graph by string matching.
+        graph = str(gm.graph)
+        assert "_C.rotary_embedding" in graph
+        return gm
+
     static_positions = torch.randint(0, max_position_embeddings,
                                      (num_tokens, ))
     static_hidden_states = torch.randn(num_tokens,
                                        num_heads * head_size,
                                        dtype=dtype,
                                        device="npu")
-    compiled_model = torch.compile(model)
+    compiled_model = torch.compile(model, backend=custom_op_checking_backend)
     stream = torch.npu.Stream()
     stream.wait_stream(torch.npu.current_stream())
     with torch.npu.stream(stream):
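The `backend=` hook works because any callable taking an FX GraphModule and its example inputs can serve as a torch.compile backend; returning the module (or its forward) runs the captured graph as-is. A minimal, device-agnostic illustration of the same pattern, with `inspecting_backend` as a hypothetical name, not part of the test above:

import torch

def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()  # lists every captured node, including custom ops
    return gm.forward         # hand the graph back unchanged

fn = torch.compile(torch.nn.Linear(4, 4), backend=inspecting_backend)
fn(torch.randn(2, 4))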