@@ -210,46 +210,72 @@ def from_config(cls, config):
210
210
211
211
212
212
213
- class RotaryPositionEmbedding (tf .keras .layers .Layer ):
214
- def __init__ (self , max_seq_length , d_model , ** kwargs ):
213
+ import tensorflow as tf
214
+ from tensorflow import keras
215
+ from tensorflow .keras import layers
216
+
217
+ class RotaryEmbedding (keras .layers .Layer ):
218
+ def __init__ (self , dim , max_seq_len = 1024 , temperature = 10000.0 , ** kwargs ):
215
219
super ().__init__ (** kwargs )
216
- self .max_seq_length = max_seq_length
217
- self .d_model = d_model
218
- assert d_model % 2 == 0 , "d_model must be even"
219
-
220
- # Precompute rotation matrices
221
- inv_freq = 1.0 / (10000 ** (tf .range (0 , d_model , 2 , dtype = tf .float32 ) / d_model ))
222
- self .inv_freq = tf .cast (inv_freq , tf .float32 )
223
- positions = tf .range (max_seq_length , dtype = tf .float32 )
224
- self .sin = tf .sin (tf .einsum ('i,j->ij' , positions , inv_freq ))
225
- self .cos = tf .cos (tf .einsum ('i,j->ij' , positions , inv_freq ))
226
-
220
+ self .dim = dim
221
+ self .max_seq_len = max_seq_len
222
+ self .temperature = temperature
223
+
224
+ def build (self , input_shape ):
225
+ super ().build (input_shape )
226
+ inv_freq = 1.0 / (self .temperature ** (tf .range (0 , self .dim , 2 , dtype = tf .float32 ) / self .dim ))
227
+ position = tf .range (self .max_seq_len , dtype = tf .float32 )
228
+ sinusoid = tf .einsum ("i,j->ij" , position , inv_freq )
229
+ self .sin_cache = tf .sin (sinusoid )
230
+ self .cos_cache = tf .cos (sinusoid )
231
+
232
+ def call (self , x , seq_len = None ):
233
+ batch_size = tf .shape (x )[0 ]
234
+ seq_len = tf .shape (x )[1 ] if seq_len is None else seq_len
235
+ sin = self .sin_cache [:seq_len ]
236
+ cos = self .cos_cache [:seq_len ]
237
+ return tf .cast (sin , x .dtype ), tf .cast (cos , x .dtype )
238
+
239
+ def split_alternate (x ):
240
+ shape = tf .shape (x )
241
+ x = tf .reshape (x , [shape [0 ], shape [1 ], shape [2 ] // 2 , 2 ])
242
+ x = tf .transpose (x , [0 , 1 , 3 , 2 ])
243
+ x = tf .reshape (x , [shape [0 ], shape [1 ], - 1 ])
244
+ return x
245
+
246
+ def rotate_half (x ):
247
+ x = split_alternate (x )
248
+ d = x .shape [- 1 ]
249
+ return x [..., d // 2 :]
250
+
251
+ def apply_rotary_pos_emb (x , sin , cos ):
252
+ x_rotated = x * cos + rotate_half (x ) * sin
253
+ return x_rotated
254
+
255
+ class InterleavedRoPE (layers .Layer ):
256
+ def __init__ (self , dim , max_seq_len = 1024 , ** kwargs ):
257
+ super ().__init__ (** kwargs )
258
+ self .dim = dim
259
+ self .max_seq_len = max_seq_len
260
+ self .rotary_emb = RotaryEmbedding (dim , max_seq_len )
261
+
227
262
def call (self , x ):
228
263
batch_size = tf .shape (x )[0 ]
229
264
seq_len = tf .shape (x )[1 ]
230
265
231
- # Compute sine and cosine matrices for current sequence length
232
- sinusoid = tf .einsum ('i,j->ij' , tf .range (seq_len , dtype = tf .float32 ), self .inv_freq )
233
- current_sin = tf .sin (sinusoid )
234
- current_cos = tf .cos (sinusoid )
235
-
236
- # Split dimensions and apply rotation using einsum
237
- x = tf .reshape (x , [batch_size , seq_len , self .d_model // 2 , 2 ])
238
- rotated = tf .stack ([
239
- x [..., 0 ] * current_cos - x [..., 1 ] * current_sin ,
240
- x [..., 0 ] * current_sin + x [..., 1 ] * current_cos
241
- ], axis = - 1 )
242
-
243
- # Reshape back and apply dropout
244
- return tf .reshape (rotated , [batch_size , seq_len , self .d_model ])
266
+ sin , cos = self .rotary_emb (x , seq_len )
267
+ x = apply_rotary_pos_emb (x , sin , cos )
268
+ return x
269
+
270
+
245
271
246
272
247
273
248
274
249
275
# GPT2 configurables
250
276
251
277
# Optimal for accuracy thus far:
252
- max_seq_length = 1024
278
+ max_seq_length = 1024 * 2
253
279
254
280
inp = tf .keras .layers .Input (shape = (), dtype = tf .string )
255
281
gp2_tokenizer = TokenizerLayer (max_seq_length = max_seq_length )
@@ -267,9 +293,9 @@ def call(self, x):
267
293
input_length = max_seq_length ,
268
294
mask_zero = True )(tokens )
269
295
270
- position_embedding = RotaryPositionEmbedding (
296
+ position_embedding = InterleavedRoPE (
297
+ dim = EMBEDDING_DIM ,
271
298
max_seq_length = max_seq_length ,
272
- d_model = EMBEDDING_DIM ,
273
299
# initializer="uniform",
274
300
)(embedded )
275
301
0 commit comments