1
+ import torch
2
+ from torch .library import Library
3
+
4
+ lib = Library ("_C" , "IMPL" )
5
+
6
+ def register_meta_if_necessary (ns :str , op_name : str , fn , overload : str = "" ):
7
+ if overload != "" :
8
+ op_name = op_name + "." + overload
9
+ schema_to_find = ns + "::" + op_name
10
+ meta_impl_list = torch ._C ._dispatch_get_registrations_for_dispatch_key ("Meta" )
11
+ if schema_to_find in meta_impl_list :
12
+ return
13
+ lib .impl (op_name , fn , "Meta" )
14
+
15
+ def rotary_embedding_meta (
16
+ positions : torch .Tensor ,
17
+ query : torch .Tensor ,
18
+ key : torch .Tensor ,
19
+ head_size : int ,
20
+ cos_sin_cache : torch .Tensor ,
21
+ is_neox : bool ):
22
+
23
+ num_tokens = positions .numel ()
24
+ query_hidden_size = query .numel () / num_tokens
25
+ key_hidden_size = key .numel () / num_tokens
26
+ num_heads = query_hidden_size / head_size
27
+ num_kv_heads = key_hidden_size / head_size
28
+
29
+ query_dst = torch .empty_like (query ).view (num_tokens , num_heads , head_size )
30
+ key_dst = torch .empty_like (key ).view (num_tokens , num_kv_heads , head_size )
31
+ return query_dst , key_dst
32
+
33
+
34
+ def get_masked_input_and_mask_meta (
35
+ input : torch .Tensor ,
36
+ org_vocab_start_index : int ,
37
+ org_vocab_end_index : int ,
38
+ num_org_vocab_padding : int ,
39
+ added_vocab_start_index : int ,
40
+ added_vocab_end_index : int ):
41
+
42
+ masked_input = torch .empty_like (input )
43
+ mask = torch .empty_like (input ).to (torch .bool )
44
+
45
+ return masked_input , mask
46
+
47
+
48
+
49
+ register_meta_if_necessary ("_C" , "rotary_embedding" , rotary_embedding_meta )
50
+ register_meta_if_necessary ("_C" , "get_masked_input_and_mask" , get_masked_input_and_mask_meta )
0 commit comments