
Commit 474bdc4

fix format issue
Signed-off-by: ganyi <[email protected]>
1 parent b9e5069 commit 474bdc4

File tree

csrc/torch_binding_meta.cpp
tests/e2e/singlecard/ops/test_rotary_embedding.py

Showing 2 changed files with 23 additions and 14 deletions.

csrc/torch_binding_meta.cpp

Lines changed: 0 additions & 1 deletion
@@ -20,7 +20,6 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
   auto query_hidden_size = query.sym_numel() / num_tokens;
   auto key_hidden_size = key.sym_numel() / num_tokens;
 
-  // Make sure query and key have consistent number of heads
   auto num_heads = query_hidden_size / head_size;
   auto num_kv_heads = key_hidden_size / head_size;
   at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
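
A note on what this meta function is for: it runs under FakeTensor tracing (torch.compile / aclgraph capture), so it never computes rotary values; it only derives output shapes symbolically via sym_numel(). A minimal Python sketch of the same shape logic follows, assuming an op registered as "_C::rotary_embedding" (the real registration name and argument order in vllm-ascend may differ):

    import torch

    # Hedged sketch: the shape-only logic of the C++ meta function above,
    # expressed as a Python fake kernel. The op name "_C::rotary_embedding"
    # and the argument order are assumptions for illustration.
    @torch.library.register_fake("_C::rotary_embedding")
    def rotary_embedding_fake(positions, query, key, head_size,
                              cos_sin_cache, is_neox):
        num_tokens = positions.numel()
        # Infer head counts from the flattened hidden sizes, as the C++
        # code does with sym_numel(); this serves both MHA and GQA shapes,
        # since num_heads and num_kv_heads may differ.
        num_heads = query.numel() // num_tokens // head_size
        num_kv_heads = key.numel() // num_tokens // head_size
        query_dst = query.new_empty((num_tokens, num_heads, head_size))
        key_dst = key.new_empty((num_tokens, num_kv_heads, head_size))
        return query_dst, key_dst

Returning empty tensors of the right symbolic shape is all torch.compile needs to keep tracing past the custom op.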

tests/e2e/singlecard/ops/test_rotary_embedding.py

Lines changed: 23 additions & 13 deletions
@@ -201,8 +201,8 @@ def test_rotary_embedding_quant_with_leading_dim(
                               rtol=DEFAULT_RTOL)
 
 
-
 class ModelwithRotaryEmbedding(nn.Module):
+
     def __init__(
         self,
         hidden_size: int,
@@ -212,7 +212,8 @@ def __init__(
         max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
-        dtype: torch.dtype,) -> None:
+        dtype: torch.dtype,
+    ) -> None:
         super().__init__()
         self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
         self.rope = RotaryEmbedding(
@@ -231,12 +232,13 @@ def forward(
         hidden_states: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        # we simulated a simple attention layer to test if it can seamlessly integrated into aclgraph
-        q,k,v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
+        # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
+        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
         query, key = self.rope.forward_native(positions, q, k, offsets)
         o = self.o_proj(query)
         return o
 
+
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
 @pytest.mark.parametrize("num_tokens", BATCH_SIZES)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
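
For readers piecing the test together from these hunks, the module is just a fused QKV projection feeding vLLM's RotaryEmbedding and an output projection. A hedged reconstruction (the o_proj shape and the RotaryEmbedding argument order are not shown in this diff and are assumptions based on vLLM's public API):

    from typing import Optional

    import torch
    import torch.nn as nn
    # Import path per vLLM's layout; the test may import it differently.
    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


    class ModelwithRotaryEmbedding(nn.Module):

        def __init__(
            self,
            hidden_size: int,
            num_heads: int,
            head_size: int,
            rotary_dim: int,
            max_position_embeddings: int,
            base: int,
            is_neox_style: bool,
            dtype: torch.dtype,
        ) -> None:
            super().__init__()
            # One fused projection yields q, k and v in a single matmul.
            self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
            self.rope = RotaryEmbedding(head_size, rotary_dim,
                                        max_position_embeddings, base,
                                        is_neox_style, dtype)
            # Assumed shape; the diff only shows o_proj being used.
            self.o_proj = nn.Linear(num_heads * head_size, hidden_size)

        def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            offsets: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            # v is unused; the test only exercises the rope path.
            q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
            # forward_native is the pure-PyTorch rope path, traceable by
            # torch.compile without the custom NPU kernel.
            query, key = self.rope.forward_native(positions, q, k, offsets)
            return self.o_proj(query)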
@@ -272,17 +274,21 @@ def test_capture_rotary_embedding_in_aclgraph(
         dtype=dtype,
     )
 
-    static_positions = torch.randint(0, max_position_embeddings, (num_tokens, ))
-    static_hidden_states = torch.randn(num_tokens, num_heads * head_size,
-                                       dtype=dtype, device="npu")
+    static_positions = torch.randint(0, max_position_embeddings,
+                                     (num_tokens, ))
+    static_hidden_states = torch.randn(num_tokens,
+                                       num_heads * head_size,
+                                       dtype=dtype,
+                                       device="npu")
     compiled_model = torch.compile(model)
     stream = torch.npu.Stream()
     stream.wait_stream(torch.npu.current_stream())
     with torch.npu.stream(stream):
         # warmup the fx graph before capture
         for i in range(3):
-            static_output = compiled_model(
-                static_positions, static_hidden_states, offsets=None)
+            static_output = compiled_model(static_positions,
+                                           static_hidden_states,
+                                           offsets=None)
     stream.wait_stream(torch.npu.current_stream())
 
     aclgraph = torch.npu.NPUGraph()
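
The lines above follow the warmup-on-a-side-stream idiom from PyTorch's CUDA Graphs documentation: run the compiled model a few times on a non-default stream so allocator pools, autograd state, and compiled artifacts are initialized before capture. A generic sketch in torch.cuda terms, which the torch.npu calls above mirror (model and static_input are stand-ins):

    import torch

    model = torch.nn.Linear(16, 16, device="cuda")   # stand-in module
    static_input = torch.randn(8, 16, device="cuda")

    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        # A few throwaway iterations trigger lazy allocation and
        # compilation work now, instead of during capture where it would
        # be baked into (or forbidden inside) the graph.
        for _ in range(3):
            static_output = model(static_input)
    torch.cuda.current_stream().wait_stream(s)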
@@ -291,9 +297,14 @@ def test_capture_rotary_embedding_in_aclgraph(
         # Capture the model in aclgraph.
         static_output = compiled_model(static_positions, static_hidden_states)
     # Capture the model in aclgraph.
-    random_filled_positions = torch.randint(
-        0, max_position_embeddings, (num_tokens, ), device="npu")
-    random_filled_hidden_states = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="npu")
+    random_filled_positions = torch.randint(0,
+                                            max_position_embeddings,
+                                            (num_tokens, ),
+                                            device="npu")
+    random_filled_hidden_states = torch.randn(num_tokens,
+                                              num_heads * head_size,
+                                              dtype=dtype,
+                                              device="npu")
     static_positions.copy_(random_filled_positions)
     static_hidden_states.copy_(random_filled_hidden_states)
 
@@ -303,4 +314,3 @@
                                output_reference,
                                atol=DEFAULT_ATOL,
                                rtol=DEFAULT_RTOL)
-
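
The copy_ calls near the end of the test exercise the core graph-replay contract: a captured graph always reads from and writes to the same static buffers, so fresh inputs must be copied into those buffers in place before replay. A hedged end-to-end sketch of that contract, again in torch.cuda terms that torch.npu.NPUGraph is assumed to mirror (stand-in names throughout):

    import torch

    model = torch.nn.Linear(16, 16, device="cuda")
    static_input = torch.randn(8, 16, device="cuda")

    # (side-stream warmup, as in the previous sketch, is assumed to have run)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # Kernels are recorded into g rather than executed eagerly; the
        # output tensor aliases fixed memory in the graph's private pool.
        static_output = model(static_input)

    new_input = torch.randn(8, 16, device="cuda")
    static_input.copy_(new_input)  # feed new data into the static buffer...
    g.replay()                     # ...then replay; results land in static_output

    # Replayed output should match an eager run on the same data, which is
    # what the test's assert_close against output_reference checks.
    torch.testing.assert_close(static_output, model(new_input))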
