
Commit cf755ec

Merge pull request #271 from lakshith-403/LoRA

LoRA minor updates

2 parents f3465ac + 3349afd

2 files changed: +22 −25 lines

labml_nn/lora/experiment.py

Lines changed: 13 additions & 14 deletions
@@ -76,16 +76,16 @@ def _load_pretrained_weights(self):
         for i in range(12):
             mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
             mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
-            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
-            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
-            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
-            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
+            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
+            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
+            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
+            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
             mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
             mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
-            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
-            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
-            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
-            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'
+            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
+            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
+            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
+            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.linear_out.bias'

         # Move the parameters based on mapping
         new_state_dict = {}
@@ -94,10 +94,10 @@ def _load_pretrained_weights(self):
             new_state_dict[new_key] = state_dict[old_key]

         # GPT-2 hugging face uses 1D Convolution layers. We need to transpose those weights since we use linear layers
-        convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
-                        [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
-                        [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
-                        [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])
+        convo_layers = ([f'blocks.{i}.ffn.linear_in.weight' for i in range(12)] +
+                        [f'blocks.{i}.ffn.linear_out.weight' for i in range(12)] +
+                        [f'blocks.{i}.attn.qkv_projection.weight' for i in range(12)] +
+                        [f'blocks.{i}.attn.output_projection.weight' for i in range(12)])

         for layer in convo_layers:
             new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
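
Note on the transpose step above: Hugging Face's GPT-2 stores these projections as Conv1D modules whose weight has shape (in_features, out_features), while the LoRA Linear layers here follow nn.Linear's (out_features, in_features) layout. A minimal standalone sketch of why the transpose gives identical outputs (the 768/2304 sizes are GPT-2-small's c_attn, assumed here only for illustration):

import torch
import torch.nn as nn

# HF Conv1D computes y = x @ W + b with W of shape (in_features, out_features);
# nn.Linear computes y = x @ W.T + b with W of shape (out_features, in_features).
w_conv = torch.randn(768, 2304)            # HF layout, e.g. c_attn.weight
linear = nn.Linear(768, 2304, bias=False)  # linear-layer layout
with torch.no_grad():
    # same operation as torch.transpose(w_conv, 0, 1) in the loader above
    linear.weight.copy_(w_conv.T)

x = torch.randn(2, 10, 768)
assert torch.allclose(x @ w_conv, linear(x), atol=1e-4)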
@@ -134,8 +134,7 @@ def run(self):
         """

         for _ in monit.loop(self.epochs):
-            for i, batch in monit.enum('Train', self.data_loader):
-                inputs = batch[0]
+            for (inputs, ) in monit.iterate('Train', self.data_loader):
                 inputs = inputs.to(self.device)
                 labels = inputs.clone()

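
The rewritten loop unpacks each batch as (inputs, ), which assumes the data loader yields 1-tuples, e.g. a DataLoader over a single-tensor TensorDataset; monit.iterate only wraps the same iteration with progress tracking. A hypothetical plain-PyTorch sketch of that batch shape (no labml involved, dummy sizes assumed):

import torch
from torch.utils.data import DataLoader, TensorDataset

# A TensorDataset built from one tensor yields 1-tuples, so (inputs, ) unpacks directly.
tokens = torch.randint(0, 50257, (8, 32))          # dummy token ids
loader = DataLoader(TensorDataset(tokens), batch_size=4)

for (inputs, ) in loader:
    labels = inputs.clone()                         # causal LM: targets are the inputs
    print(inputs.shape, labels.shape)               # torch.Size([4, 32]) twice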
labml_nn/lora/gpt2.py

Lines changed: 9 additions & 11 deletions
@@ -6,16 +6,14 @@
 class FFN(nn.Module):
     def __init__(self, dim: int, n_embed: int, r: int):
         super().__init__()
-        # lin1
-        self.c_fc = Linear(n_embed, dim, r=r, bias=True)
-        # lin2
-        self.c_proj = Linear(dim, n_embed, r=r, bias=True)
+        self.linear_in = Linear(n_embed, dim, r=r, bias=True)
+        self.linear_out = Linear(dim, n_embed, r=r, bias=True)
         self.act = nn.functional.gelu

     def forward(self, hidden_states):
-        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.linear_in(hidden_states)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.linear_out(hidden_states)
         return hidden_states


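
The Linear layers above (note the r argument) come from the repo's LoRA module. Conceptually, a LoRA linear keeps the pretrained weight frozen and learns only a low-rank update. A minimal sketch of that idea, not the repo's exact implementation (class name, initialization, and the alpha scaling are assumptions):

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    # y = x @ W.T + b + (alpha / r) * x @ A.T @ B.T, with W and b frozen
    def __init__(self, in_features: int, out_features: int, r: int, alpha: int = None, bias: bool = True):
        super().__init__()
        # Pretrained weight/bias, frozen; in the experiment these come from the GPT-2 state dict
        self.weight = nn.Parameter(torch.zeros(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False) if bias else None
        # Trainable low-rank factors: A maps down to rank r, B maps back up; B starts at zero
        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = (alpha if alpha is not None else r) / r

    def forward(self, x):
        frozen = nn.functional.linear(x, self.weight, self.bias)
        update = nn.functional.linear(nn.functional.linear(x, self.lora_a), self.lora_b)
        return frozen + self.scaling * update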
@@ -27,10 +25,10 @@ def __init__(self, n_embed: int, r: int):
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim

-        # qkv
-        self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
+        # query key value
+        self.qkv_projection = Linear(n_embed, n_embed * 3, r=r, bias=True)
         # out
-        self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)
+        self.output_projection = Linear(n_embed, n_embed, r=r, bias=True)

     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
@@ -43,7 +41,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size):
     def forward(self, hidden_states):
         batch_size, seq_length, _ = hidden_states.size()

-        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)
+        query, key, value = self.qkv_projection(hidden_states).split(self.split_size, dim=2)

         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
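
qkv_projection produces queries, keys, and values in a single matrix multiply; split(self.split_size, dim=2) then cuts the last dimension into three embed_dim-sized chunks, and _split_heads reshapes each into per-head form. A standalone shape check with GPT-2-small sizes assumed (the view/permute mirrors what a typical _split_heads does, not necessarily this file's exact code):

import torch

batch_size, seq_length, n_embed, num_heads = 2, 16, 768, 12
head_dim = n_embed // num_heads                             # 64

qkv = torch.randn(batch_size, seq_length, 3 * n_embed)      # output of the combined projection
query, key, value = qkv.split(n_embed, dim=2)               # three tensors of shape (2, 16, 768)

# (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
query = query.view(batch_size, seq_length, num_heads, head_dim).permute(0, 2, 1, 3)
print(query.shape)                                           # torch.Size([2, 12, 16, 64])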
@@ -61,7 +59,7 @@ def forward(self, hidden_states):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

-        attn_output = self.c_proj(attn_output)
+        attn_output = self.output_projection(attn_output)

         return attn_output
