Skip to content

Commit 6e40e45

Browse files
committed
Add Python-level meta registration method
Signed-off-by: ganyi <[email protected]>
1 parent ae73b4f commit 6e40e45

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

vllm_ascend/ops/meta_registration.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import torch
2+
from torch.library import Library
3+
4+
lib = Library("_C", "IMPL")
5+
6+
def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
    """Register ``fn`` as the Meta-dispatch-key implementation of an op.

    Args:
        ns: Operator namespace (e.g. ``"_C"``).
        op_name: Base operator name within the namespace.
        fn: Python callable implementing the meta (shape-inference) kernel.
        overload: Optional overload name; appended as ``op_name.overload``.

    Registration is skipped when the op already has a Meta implementation,
    so this is safe to call even if a C++ meta kernel exists.
    """
    qualified_name = f"{op_name}.{overload}" if overload else op_name
    schema_to_find = f"{ns}::{qualified_name}"
    # Query the dispatcher for everything already registered under "Meta"
    # to avoid a duplicate-registration error.
    existing = torch._C._dispatch_get_registrations_for_dispatch_key("Meta")
    if schema_to_find not in existing:
        lib.impl(qualified_name, fn, "Meta")
14+
15+
def rotary_embedding_meta(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        head_size: int,
        cos_sin_cache: torch.Tensor,
        is_neox: bool):
    """Meta (shape-only) implementation of the rotary_embedding custom op.

    Derives the per-token head layout from the flat query/key tensors and
    returns empty output tensors of the shapes the real kernel would produce:
    ``(num_tokens, num_heads, head_size)`` for the query and
    ``(num_tokens, num_kv_heads, head_size)`` for the key.
    ``cos_sin_cache`` and ``is_neox`` do not affect output shapes.
    """
    num_tokens = positions.numel()
    # Use floor division: true division (/) yields floats, and
    # Tensor.view() rejects non-integer sizes.
    query_hidden_size = query.numel() // num_tokens
    key_hidden_size = key.numel() // num_tokens
    num_heads = query_hidden_size // head_size
    num_kv_heads = key_hidden_size // head_size

    query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
    key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
    return query_dst, key_dst
32+
33+
34+
def get_masked_input_and_mask_meta(
        input: torch.Tensor,
        org_vocab_start_index: int,
        org_vocab_end_index: int,
        num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int):
    """Meta (shape-only) implementation of get_masked_input_and_mask.

    Returns empty tensors matching the real kernel's outputs: a masked copy
    of ``input`` (same shape/dtype) and a boolean mask of the same shape.
    The vocab-index arguments only affect values, not shapes, so they are
    unused here.
    """
    masked_input = torch.empty_like(input)
    # Allocate directly as bool instead of empty_like(...).to(torch.bool),
    # which would create and convert an intermediate tensor.
    mask = torch.empty_like(input, dtype=torch.bool)

    return masked_input, mask
46+
47+
48+
49+
# Attach the meta kernels for the custom _C ops so fake-tensor tracing and
# torch.compile can infer output shapes without executing the NPU kernels.
for _name, _meta_fn in (
    ("rotary_embedding", rotary_embedding_meta),
    ("get_masked_input_and_mask", get_masked_input_and_mask_meta),
):
    register_meta_if_necessary("_C", _name, _meta_fn)

vllm_ascend/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ def enable_custom_op():
216216
try:
217217
# register custom ops into torch_library here
218218
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
219+
# register the meta implementation for custom kernel if necessary
220+
import vllm_ascend.ops.meta_registration # type: ignore # noqa: F401
219221
_CUSTOM_OP_ENABLED = True
220222
except ImportError:
221223
_CUSTOM_OP_ENABLED = False

0 commit comments

Comments
 (0)