diff --git a/NEZHA-PyTorch/modeling_nezha.py b/NEZHA-PyTorch/modeling_nezha.py
index 1df38e53..82f04607 100644
--- a/NEZHA-PyTorch/modeling_nezha.py
+++ b/NEZHA-PyTorch/modeling_nezha.py
@@ -322,6 +322,27 @@ def _generate_relative_positions_embeddings(length, depth, max_relative_position
     return embeddings
 
 
+class RelativePositionsEmbeddings(nn.Module):
+    """
+    Given a 1-D tensor of relative position indices, return the corresponding
+    rows of a fixed sinusoidal relative position embedding table.
+    """
+    def __init__(self, depth, max_relative_position):
+        super(RelativePositionsEmbeddings, self).__init__()
+        vocab_size = max_relative_position * 2 + 1
+        embeddings_table = np.zeros([vocab_size, depth])
+        for pos in range(vocab_size):
+            for i in range(depth // 2):
+                embeddings_table[pos, 2 * i] = np.sin(pos / np.power(10000, 2 * i / depth))
+                embeddings_table[pos, 2 * i + 1] = np.cos(pos / np.power(10000, 2 * i / depth))
+
+        self.embeddings_table_tensor = nn.Parameter(torch.tensor(embeddings_table, dtype=torch.float))
+
+    def forward(self, relative_positions):
+        embeddings = torch.index_select(self.embeddings_table_tensor, 0, relative_positions)
+        return embeddings
+
+
 ### Test: print(_generate_relative_positions_embeddings(6, 32, 4)[0, 0, :])
 
 class NeZhaSelfAttention(nn.Module):
@@ -338,10 +359,32 @@ def __init__(self, config):
         self.query = nn.Linear(config.hidden_size, self.all_head_size)
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
-        self.relative_positions_embeddings = _generate_relative_positions_embeddings(
-            length=512, depth=self.attention_head_size, max_relative_position=config.max_relative_position).cuda()
+
+        # self.relative_positions_embeddings = _generate_relative_positions_embeddings(
+        #     length=512, depth=self.attention_head_size, max_relative_position=config.max_relative_position).cuda()
+
+        self.relative_positions_embeddings = RelativePositionsEmbeddings(
+            depth=self.attention_head_size, max_relative_position=config.max_relative_position)
+        self.max_relative_position = config.max_relative_position
+
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
+    def process_relative_position_embeddings(self, seq_length, device):
+        """
+        Given the input sequence length, generate the [seq_length, seq_length, depth]
+        relative position embeddings.
+        """
+        depth = self.attention_head_size
+        final_mat = _generate_relative_positions_matrix(seq_length, self.max_relative_position).to(device)
+        flat_relative_positions_matrix = final_mat.view(-1)
+        embeddings = self.relative_positions_embeddings(flat_relative_positions_matrix)
+        my_shape = list(final_mat.size())
+        my_shape.append(depth)
+        embeddings = embeddings.view(my_shape)
+        relative_position_embeddings = embeddings.clone().detach()
+
+        return relative_position_embeddings
+
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
@@ -363,9 +406,14 @@ def forward(self, hidden_states, attention_mask):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         batch_size, num_attention_heads, from_seq_length, to_seq_length = attention_scores.size()
 
-        relations_keys = self.relative_positions_embeddings.detach().clone()[:to_seq_length, :to_seq_length, :].to(
-            device)
+        # start: generate relative position embeddings for the keys
+        relations_keys = self.process_relative_position_embeddings(to_seq_length, device)
+        # end: generate relative position embeddings for the keys
+
+        # relations_keys = self.relative_positions_embeddings.detach().clone()[:to_seq_length, :to_seq_length, :].to(
+        #     device)
         # relations_keys = embeddings.clone().detach().to(device)
+
         # query_layer_t = query_layer.permute(2, 0, 1, 3)
         query_layer_r = query_layer_t.contiguous().view(from_seq_length, batch_size * num_attention_heads,
                                                         self.attention_head_size)
@@ -387,8 +435,12 @@ def forward(self, hidden_states, attention_mask):
 
         context_layer = torch.matmul(attention_probs, value_layer)
 
-        relations_values = self.relative_positions_embeddings.clone()[:to_seq_length, :to_seq_length, :].to(
-            device)
+        # start: generate relative position embeddings for the values
+        relations_values = self.process_relative_position_embeddings(to_seq_length, device)
+        # end: generate relative position embeddings for the values
+
+        # relations_values = self.relative_positions_embeddings.clone()[:to_seq_length, :to_seq_length, :].to(
+        #     device)
         attention_probs_t = attention_probs.permute(2, 0, 1, 3)
         attentions_probs_r = attention_probs_t.contiguous().view(from_seq_length, batch_size * num_attention_heads,
                                                                  to_seq_length)
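
The standalone sketch below (not part of the diff) illustrates the new code path end to end: build the relative-position index matrix, gather rows from the fixed sinusoidal table via RelativePositionsEmbeddings, and reshape to [seq_length, seq_length, depth] as process_relative_position_embeddings does. The helper relative_positions_matrix is a hypothetical stand-in for _generate_relative_positions_matrix, which is not shown in this diff; it is assumed to produce pairwise distances clipped and shifted into [0, 2 * max_relative_position].

import numpy as np
import torch
import torch.nn as nn


def relative_positions_matrix(length, max_relative_position):
    # Hypothetical stand-in for _generate_relative_positions_matrix:
    # pairwise distances j - i, clipped and shifted into [0, 2 * max_relative_position].
    range_vec = torch.arange(length)
    distance_mat = range_vec[None, :] - range_vec[:, None]
    distance_mat = torch.clamp(distance_mat, -max_relative_position, max_relative_position)
    return distance_mat + max_relative_position


class RelativePositionsEmbeddings(nn.Module):
    # Same construction as in the diff: a fixed sinusoidal table with
    # 2 * max_relative_position + 1 rows, one per clipped relative distance.
    def __init__(self, depth, max_relative_position):
        super(RelativePositionsEmbeddings, self).__init__()
        vocab_size = max_relative_position * 2 + 1
        embeddings_table = np.zeros([vocab_size, depth])
        for pos in range(vocab_size):
            for i in range(depth // 2):
                embeddings_table[pos, 2 * i] = np.sin(pos / np.power(10000, 2 * i / depth))
                embeddings_table[pos, 2 * i + 1] = np.cos(pos / np.power(10000, 2 * i / depth))
        self.embeddings_table_tensor = nn.Parameter(torch.tensor(embeddings_table, dtype=torch.float))

    def forward(self, relative_positions):
        # Gather one embedding row per flattened relative-position index.
        return torch.index_select(self.embeddings_table_tensor, 0, relative_positions)


if __name__ == "__main__":
    depth, max_relative_position, seq_length = 32, 4, 6
    rel_emb = RelativePositionsEmbeddings(depth, max_relative_position)

    # Mirror process_relative_position_embeddings: flatten the index matrix,
    # gather rows from the table, reshape to [seq_length, seq_length, depth].
    final_mat = relative_positions_matrix(seq_length, max_relative_position)
    embeddings = rel_emb(final_mat.view(-1)).view(seq_length, seq_length, depth).detach()
    print(embeddings.shape)  # torch.Size([6, 6, 32])

One design observation: the diff stores the fixed table as an nn.Parameter and then detaches the gathered embeddings, so the table is never trained; registering it with register_buffer would keep the same behavior while keeping the table out of the optimizer's parameter list.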