@@ -6,16 +6,14 @@
 class FFN(nn.Module):
     def __init__(self, dim: int, n_embed: int, r: int):
         super().__init__()
-        # lin1
-        self.c_fc = Linear(n_embed, dim, r=r, bias=True)
-        # lin2
-        self.c_proj = Linear(dim, n_embed, r=r, bias=True)
+        self.linear_in = Linear(n_embed, dim, r=r, bias=True)
+        self.linear_out = Linear(dim, n_embed, r=r, bias=True)
         self.act = nn.functional.gelu

     def forward(self, hidden_states):
-        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.linear_in(hidden_states)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.linear_out(hidden_states)
         return hidden_states

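The `Linear` layer used here takes a rank argument `r`, so it is presumably a LoRA-style adapted linear defined elsewhere in the repo rather than `torch.nn.Linear`. A minimal sketch of what such a layer might look like, purely as an assumption (names, initialization, and the frozen-weight convention are illustrative, not the repo's actual implementation):

```python
import math

import torch
import torch.nn as nn


class Linear(nn.Module):
    """Hypothetical LoRA-style linear: a frozen dense weight plus a
    trainable low-rank update of rank r (sketch, not the repo's code)."""

    def __init__(self, in_features: int, out_features: int, r: int, bias: bool = True):
        super().__init__()
        # Frozen base weight; only the adapters below receive gradients.
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self.lora_a = nn.Parameter(torch.empty(r, in_features))
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))  # zero init: adapter starts as a no-op
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))

    def forward(self, x):
        # y = x W^T + b, plus the low-rank correction x A^T B^T
        base = nn.functional.linear(x, self.weight, self.bias)
        return base + nn.functional.linear(nn.functional.linear(x, self.lora_a), self.lora_b)
```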
@@ -27,10 +25,10 @@ def __init__(self, n_embed: int, r: int):
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim

-        # qkv
-        self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
+        # query key value
+        self.qkv_projection = Linear(n_embed, n_embed * 3, r=r, bias=True)
         # out
-        self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)
+        self.output_projection = Linear(n_embed, n_embed, r=r, bias=True)

     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
@@ -43,7 +41,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size):
     def forward(self, hidden_states):
         batch_size, seq_length, _ = hidden_states.size()

-        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)
+        query, key, value = self.qkv_projection(hidden_states).split(self.split_size, dim=2)

         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
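For context on the unchanged `split` call above: `qkv_projection` fuses the query, key, and value projections into a single matmul, and `split(self.split_size, dim=2)` then carves the last dimension back into three `n_embed`-wide tensors. A standalone illustration using a plain `nn.Linear` as a stand-in for the adapted layer:

```python
import torch
import torch.nn as nn

n_embed = 768
fused = nn.Linear(n_embed, n_embed * 3)      # stand-in for qkv_projection
hidden_states = torch.randn(2, 10, n_embed)  # (batch, seq, n_embed)
query, key, value = fused(hidden_states).split(n_embed, dim=2)
print(query.shape, key.shape, value.shape)   # each: torch.Size([2, 10, 768])
```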
@@ -61,7 +59,7 @@ def forward(self, hidden_states):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

-        attn_output = self.c_proj(attn_output)
+        attn_output = self.output_projection(attn_output)

         return attn_output

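As a quick, hypothetical smoke test of the renamed modules (sizes are illustrative; `FFN` is the class from the patched module, so this assumes its `Linear` dependency is importable):

```python
import torch

torch.manual_seed(0)
ffn = FFN(dim=3072, n_embed=768, r=8)  # typical GPT-2 sizes, chosen for illustration
x = torch.randn(2, 10, 768)            # (batch, seq, n_embed)
out = ffn(x)
assert out.shape == x.shape            # linear_out projects back to n_embed
```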