@@ -201,8 +201,8 @@ def test_rotary_embedding_quant_with_leading_dim(
                                rtol=DEFAULT_RTOL)
 
 
-
 class ModelwithRotaryEmbedding(nn.Module):
+
     def __init__(
         self,
        hidden_size: int,
@@ -212,7 +212,8 @@ def __init__(
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
-        dtype: torch.dtype,) -> None:
+        dtype: torch.dtype,
+    ) -> None:
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
        self.rope = RotaryEmbedding(
@@ -231,12 +232,13 @@ def forward(
        hidden_states: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
-        # we simulated a simple attention layer to test if it can seamlessly integrated into aclgraph
-        q,k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
+        # we simulate a simple attention layer to test whether it can be seamlessly captured into an aclgraph
+        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
        query, key = self.rope.forward_native(positions, q, k, offsets)
        o = self.o_proj(query)
        return o
 
+
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
 @pytest.mark.parametrize("num_tokens", BATCH_SIZES)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -272,17 +274,21 @@ def test_capture_rotary_embedding_in_aclgraph(
        dtype=dtype,
    )
 
-    static_positions = torch.randint(0, max_position_embeddings, (num_tokens, ))
-    static_hidden_states = torch.randn(num_tokens, num_heads * head_size,
-                                       dtype=dtype, device="npu")
+    static_positions = torch.randint(0, max_position_embeddings,
+                                     (num_tokens, ))
+    static_hidden_states = torch.randn(num_tokens,
+                                       num_heads * head_size,
+                                       dtype=dtype,
+                                       device="npu")
    compiled_model = torch.compile(model)
    stream = torch.npu.Stream()
    stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(stream):
        # warmup the fx graph before capture
        for i in range(3):
-            static_output = compiled_model(
-                static_positions, static_hidden_states, offsets=None)
+            static_output = compiled_model(static_positions,
+                                           static_hidden_states,
+                                           offsets=None)
    stream.wait_stream(torch.npu.current_stream())
 
    aclgraph = torch.npu.NPUGraph()
@@ -291,9 +297,14 @@ def test_capture_rotary_embedding_in_aclgraph(
        # Capture the model in aclgraph.
        static_output = compiled_model(static_positions, static_hidden_states)
    # Capture the model in aclgraph.
-    random_filled_positions = torch.randint(
-        0, max_position_embeddings, (num_tokens, ), device="npu")
-    random_filled_hidden_states = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="npu")
+    random_filled_positions = torch.randint(0,
+                                            max_position_embeddings,
+                                            (num_tokens, ),
+                                            device="npu")
+    random_filled_hidden_states = torch.randn(num_tokens,
+                                              num_heads * head_size,
+                                              dtype=dtype,
+                                              device="npu")
    static_positions.copy_(random_filled_positions)
    static_hidden_states.copy_(random_filled_hidden_states)
 
@@ -303,4 +314,3 @@ def test_capture_rotary_embedding_in_aclgraph(
                               output_reference,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
-
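
For context, a minimal sketch of the capture-and-replay flow this test exercises is given below. It is not part of the diff: the capture context manager (torch.npu.graph) and aclgraph.replay() are assumed to mirror the torch.cuda graph API that torch_npu follows, and `model` stands in for the ModelwithRotaryEmbedding instance the test constructs; the shapes and dtype are placeholders rather than the test's parameters.

import torch
import torch_npu  # noqa: F401  # assumed to register the torch.npu backend

# Sketch only: `model` is the ModelwithRotaryEmbedding instance built elsewhere.
num_tokens, hidden = 8, 64
static_positions = torch.randint(0, 1024, (num_tokens, ), device="npu")
static_hidden_states = torch.randn(num_tokens, hidden,
                                   dtype=torch.float16, device="npu")
compiled_model = torch.compile(model)

# 1. Warm up on a side stream so torch.compile's first-run work (guards,
#    compilation, allocator growth) is not recorded into the graph.
stream = torch.npu.Stream()
stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(stream):
    for _ in range(3):
        compiled_model(static_positions, static_hidden_states, offsets=None)
torch.npu.current_stream().wait_stream(stream)

# 2. Capture: assuming torch.npu.graph mirrors torch.cuda.graph, every kernel
#    launched inside the context is recorded into the NPUGraph (aclgraph).
aclgraph = torch.npu.NPUGraph()
with torch.npu.graph(aclgraph):
    static_output = compiled_model(static_positions, static_hidden_states)

# 3. Replay: the graph always reads the same buffers, so fresh inputs are
#    copied into the captured tensors in place before replaying.
static_positions.copy_(
    torch.randint(0, 1024, (num_tokens, ), device="npu"))
static_hidden_states.copy_(
    torch.randn(num_tokens, hidden, dtype=torch.float16, device="npu"))
aclgraph.replay()
result = static_output  # outputs land in the tensor captured in step 2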