
Commit ee85cae

Update phishing_email_detection_gpt2.py
Try interleaved RoPE (iRoPE) layer.
1 parent 4629394 commit ee85cae

1 file changed (+56 -30 lines)


phishing_email_detection_gpt2.py (+56 -30)
@@ -210,46 +210,72 @@ def from_config(cls, config):
 
 
 
-class RotaryPositionEmbedding(tf.keras.layers.Layer):
-    def __init__(self, max_seq_length, d_model, **kwargs):
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+
+
+class RotaryEmbedding(keras.layers.Layer):
+    def __init__(self, dim, max_seq_len=1024, temperature=10000.0, **kwargs):
         super().__init__(**kwargs)
-        self.max_seq_length = max_seq_length
-        self.d_model = d_model
-        assert d_model % 2 == 0, "d_model must be even"
-
-        # Precompute rotation matrices
-        inv_freq = 1.0 / (10000 ** (tf.range(0, d_model, 2, dtype=tf.float32) / d_model))
-        self.inv_freq = tf.cast(inv_freq, tf.float32)
-        positions = tf.range(max_seq_length, dtype=tf.float32)
-        self.sin = tf.sin(tf.einsum('i,j->ij', positions, inv_freq))
-        self.cos = tf.cos(tf.einsum('i,j->ij', positions, inv_freq))
+        self.dim = dim
+        self.max_seq_len = max_seq_len
+        self.temperature = temperature
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        # Precompute the sin/cos caches for every position up to max_seq_len.
+        inv_freq = 1.0 / (self.temperature ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
+        position = tf.range(self.max_seq_len, dtype=tf.float32)
+        sinusoid = tf.einsum("i,j->ij", position, inv_freq)  # [max_seq_len, dim // 2]
+        self.sin_cache = tf.sin(sinusoid)
+        self.cos_cache = tf.cos(sinusoid)
+
+    def call(self, x, seq_len=None):
+        seq_len = tf.shape(x)[1] if seq_len is None else seq_len
+        sin = self.sin_cache[:seq_len]
+        cos = self.cos_cache[:seq_len]
+        return tf.cast(sin, x.dtype), tf.cast(cos, x.dtype)
+
+
+def split_alternate(x):
+    # De-interleave the feature axis: [x0, x1, x2, x3, ...] -> [x0, x2, ..., x1, x3, ...]
+    shape = tf.shape(x)
+    x = tf.reshape(x, [shape[0], shape[1], shape[2] // 2, 2])
+    x = tf.transpose(x, [0, 1, 3, 2])
+    x = tf.reshape(x, [shape[0], shape[1], -1])
+    return x
+
+
+def rotate_half(x):
+    # Map the two (already de-interleaved) halves (a, b) -> (-b, a).
+    x1, x2 = tf.split(x, 2, axis=-1)
+    return tf.concat([-x2, x1], axis=-1)
+
+
+def apply_rotary_pos_emb(x, sin, cos):
+    # De-interleave x, tile sin/cos from [seq, dim // 2] to [seq, dim] so they
+    # broadcast over the batch, then apply the pairwise rotation.
+    x = split_alternate(x)
+    sin = tf.concat([sin, sin], axis=-1)
+    cos = tf.concat([cos, cos], axis=-1)
+    x_rotated = x * cos + rotate_half(x) * sin
+    return x_rotated
+
+
+class InterleavedRoPE(layers.Layer):
+    def __init__(self, dim, max_seq_len=1024, **kwargs):
+        super().__init__(**kwargs)
+        self.dim = dim
+        self.max_seq_len = max_seq_len
+        self.rotary_emb = RotaryEmbedding(dim, max_seq_len)
 
     def call(self, x):
-        batch_size = tf.shape(x)[0]
         seq_len = tf.shape(x)[1]
 
-        # Compute sine and cosine matrices for current sequence length
-        sinusoid = tf.einsum('i,j->ij', tf.range(seq_len, dtype=tf.float32), self.inv_freq)
-        current_sin = tf.sin(sinusoid)
-        current_cos = tf.cos(sinusoid)
-
-        # Split dimensions and apply rotation using einsum
-        x = tf.reshape(x, [batch_size, seq_len, self.d_model//2, 2])
-        rotated = tf.stack([
-            x[..., 0] * current_cos - x[..., 1] * current_sin,
-            x[..., 0] * current_sin + x[..., 1] * current_cos
-        ], axis=-1)
-
-        # Reshape back and apply dropout
-        return tf.reshape(rotated, [batch_size, seq_len, self.d_model])
+        sin, cos = self.rotary_emb(x, seq_len)
+        x = apply_rotary_pos_emb(x, sin, cos)
+        return x
 
 
 # GPT2 configurables
 
 # Optimal for accuracy thus far:
-max_seq_length = 1024
+max_seq_length = 1024 * 2
 
 inp = tf.keras.layers.Input(shape=(), dtype=tf.string)
 gp2_tokenizer = TokenizerLayer(max_seq_length=max_seq_length)
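
For a quick sense of the shapes involved: RotaryEmbedding returns one sin table and one cos table of shape [seq_len, dim // 2], i.e. a single rotation angle per feature pair per position. A minimal sketch, assuming the definitions from the hunk above are in scope (the sizes here are arbitrary):

rope = RotaryEmbedding(dim=16, max_seq_len=8)
x = tf.random.normal([2, 8, 16])  # [batch, seq_len, dim]
sin, cos = rope(x)
print(sin.shape, cos.shape)       # both (8, 8), i.e. [seq_len, dim // 2]
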
@@ -267,9 +293,9 @@ def call(self, x):
                                input_length=max_seq_length,
                                mask_zero=True)(tokens)
 
-position_embedding = RotaryPositionEmbedding(
-    max_seq_length=max_seq_length,
-    d_model=EMBEDDING_DIM,
+position_embedding = InterleavedRoPE(
+    dim=EMBEDDING_DIM,
+    max_seq_len=max_seq_length,
     # initializer="uniform",
 )(embedded)

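Because RoPE only rotates pairs of features, it preserves the L2 norm of every position's embedding vector, which gives a cheap sanity check on the helpers introduced in this commit. A minimal sketch, assuming the classes above are in scope (sizes and tolerance are arbitrary):

import numpy as np
import tensorflow as tf

layer = InterleavedRoPE(dim=16, max_seq_len=32)
x = tf.random.normal([2, 32, 16])
y = layer(x)

# A pure rotation leaves per-position norms unchanged.
np.testing.assert_allclose(
    tf.norm(x, axis=-1).numpy(),
    tf.norm(y, axis=-1).numpy(),
    rtol=1e-5,
)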