From 71107e21b20ad9baed0aa13c2f67bd5dc7421878 Mon Sep 17 00:00:00 2001
From: WoodySG2018 <42363175+WoodySG2018@users.noreply.github.com>
Date: Wed, 3 Jul 2019 17:23:36 +0800
Subject: [PATCH 1/2] multi layer RNN

Coded a multi-layer RNN decoder. It uses 2 layers by default; set
"num_layers" when constructing DecoderWithAttention to change the number of
layers. Please check.
---
 models.py | 42 ++++++++++++++++++++++++++++++------------
 1 file changed, 30 insertions(+), 12 deletions(-)

diff --git a/models.py b/models.py
index a02b246e6..9a8b1bed2 100644
--- a/models.py
+++ b/models.py
@@ -85,13 +85,19 @@ def forward(self, encoder_out, decoder_hidden):

         return attention_weighted_encoding, alpha

+def LSTMCell(input_size, hidden_size, **kwargs):
+    m = nn.LSTMCell(input_size, hidden_size, **kwargs)
+    for name, param in m.named_parameters():
+        if 'weight' in name or 'bias' in name:
+            param.data.uniform_(-0.1, 0.1)
+    return m

 class DecoderWithAttention(nn.Module):
     """
     Decoder.
     """

-    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
+    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5, num_layers=2):
         """
         :param attention_dim: size of attention network
         :param embed_dim: embedding size
@@ -108,12 +114,13 @@ def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_di
         self.decoder_dim = decoder_dim
         self.vocab_size = vocab_size
         self.dropout = dropout
-
+        self.num_layers = num_layers
+
         self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network

         self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
         self.dropout = nn.Dropout(p=self.dropout)
-        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
+        self.decode_step = nn.ModuleList([LSTMCell(embed_dim + encoder_dim if layer == 0 else decoder_dim, decoder_dim) for layer in range(self.num_layers)])  # stack of decoding LSTMCells
         self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
         self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
         self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
@@ -154,8 +161,9 @@ def init_hidden_state(self, encoder_out):
         :return: hidden state, cell state
         """
         mean_encoder_out = encoder_out.mean(dim=1)
-        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
-        c = self.init_c(mean_encoder_out)
+        h = [self.init_h(mean_encoder_out) for _ in range(self.num_layers)]  # list of (batch_size, decoder_dim), one per layer
+        c = [self.init_c(mean_encoder_out) for _ in range(self.num_layers)]
+
         return h, c

     def forward(self, encoder_out, encoded_captions, caption_lengths):
@@ -184,8 +192,8 @@ def forward(self, encoder_out, encoded_captions, caption_lengths):
         # Embedding
         embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

-        # Initialize LSTM state
-        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)
+        # Initialize per-layer LSTM hidden and cell states
+        prev_h, prev_c = self.init_hidden_state(encoder_out)  # lists of (batch_size, decoder_dim)

         # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
         # So, decoding lengths are actual lengths - 1
@@ -201,12 +209,22 @@ def forward(self, encoder_out, encoded_captions, caption_lengths):
         for t in range(max(decode_lengths)):
             batch_size_t = sum([l > t for l in decode_lengths])
             attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
-                                                                h[:batch_size_t])
-            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
+                                                                prev_h[-1][:batch_size_t])
+            gate = self.sigmoid(self.f_beta(prev_h[-1][:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
             attention_weighted_encoding = gate * attention_weighted_encoding
-            h, c = self.decode_step(
-                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
-                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
+
+            input = torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1)
+            for i, rnn in enumerate(self.decode_step):
+                # run layer i's LSTMCell on the current input and its previous state
+                h, c = rnn(input, (prev_h[i][:batch_size_t], prev_c[i][:batch_size_t]))  # (batch_size_t, decoder_dim)
+
+                # the hidden state becomes the input to the next layer
+                input = self.dropout(h)
+
+                # save this layer's state for the next time step
+                prev_h[i] = h
+                prev_c[i] = c
+
             preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
             predictions[:batch_size_t, t, :] = preds
             alphas[:batch_size_t, t, :] = alpha

From bbe6dcb16de7ace98d07a0848599903e062bc73a Mon Sep 17 00:00:00 2001
From: "Du, Weina | Andy | RASIA"
Date: Mon, 22 Jul 2019 18:28:23 +0800
Subject: [PATCH 2/2] revised caption.py and eval.py to adapt to the multi-layer RNN

---
 caption.py | 20 ++++++++++++++------
 eval.py    | 24 ++++++++++++++++--------
 2 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/caption.py b/caption.py
index 145499014..d3e5d9620 100644
--- a/caption.py
+++ b/caption.py
@@ -81,16 +81,22 @@ def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=

         embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

-        awe, alpha = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)
+        awe, alpha = decoder.attention(encoder_out, h[-1])  # (s, encoder_dim), (s, num_pixels)

         alpha = alpha.view(-1, enc_image_size, enc_image_size)  # (s, enc_image_size, enc_image_size)

-        gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
+        gate = decoder.sigmoid(decoder.f_beta(h[-1]))  # gating scalar, (s, encoder_dim)
         awe = gate * awe

-        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)
+        input = torch.cat([embeddings, awe], dim=1)
+        for j, rnn in enumerate(decoder.decode_step):
+            # one LSTMCell step per layer; the output feeds the next layer
+            at_h, at_c = rnn(input, (h[j], c[j]))  # (s, decoder_dim)
+            input = decoder.dropout(at_h)
+            h[j] = at_h
+            c[j] = at_c

-        scores = decoder.fc(h)  # (s, vocab_size)
+        scores = decoder.fc(h[-1])  # (s, vocab_size)
         scores = F.log_softmax(scores, dim=1)

         # Add
@@ -129,8 +135,10 @@ def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=
             break
         seqs = seqs[incomplete_inds]
         seqs_alpha = seqs_alpha[incomplete_inds]
-        h = h[prev_word_inds[incomplete_inds]]
-        c = c[prev_word_inds[incomplete_inds]]
+        for j in range(len(h)):
+            h[j] = h[j][prev_word_inds[incomplete_inds]]
+            c[j] = c[j][prev_word_inds[incomplete_inds]]
+
         encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
         top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
         k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
diff --git a/eval.py b/eval.py
index 3e9359a09..bf9c1b5ca 100644
--- a/eval.py
+++ b/eval.py
@@ -100,14 +100,21 @@ def evaluate(beam_size):

             embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

-            awe, _ = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)
+            awe, _ = decoder.attention(encoder_out, h[-1])  # (s, encoder_dim), (s, num_pixels)

-            gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
+            gate = decoder.sigmoid(decoder.f_beta(h[-1]))  # gating scalar, (s, encoder_dim)
             awe = gate * awe

-            h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)
+            input = torch.cat([embeddings, awe], dim=1)
+            for j, rnn in enumerate(decoder.decode_step):
+                # one LSTMCell step per layer; the output feeds the next layer
+                at_h, at_c = rnn(input, (h[j], c[j]))  # (s, decoder_dim)
+                input = decoder.dropout(at_h)
+                h[j] = at_h
+                c[j] = at_c
+
+            scores = decoder.fc(h[-1])  # (s, vocab_size)

-            scores = decoder.fc(h)  # (s, vocab_size)
             scores = F.log_softmax(scores, dim=1)

             # Add
@@ -142,8 +149,9 @@ def evaluate(beam_size):
             if k == 0:
                 break
             seqs = seqs[incomplete_inds]
-            h = h[prev_word_inds[incomplete_inds]]
-            c = c[prev_word_inds[incomplete_inds]]
+            for j in range(len(h)):
+                h[j] = h[j][prev_word_inds[incomplete_inds]]
+                c[j] = c[j][prev_word_inds[incomplete_inds]]
             encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
             top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
             k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
@@ -153,8 +161,8 @@ def evaluate(beam_size):
                 break
             step += 1

-        i = complete_seqs_scores.index(max(complete_seqs_scores))
-        seq = complete_seqs[i]
+        j = complete_seqs_scores.index(max(complete_seqs_scores))
+        seq = complete_seqs[j]

         # References
         img_caps = allcaps[0].tolist()
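Reviewer note (not part of the patches above): the core of PATCH 1/2 is replacing the single decoding LSTMCell with a per-layer stack whose hidden and cell states are kept as Python lists and updated bottom-up at each time step. Below is a minimal, self-contained sketch of that per-timestep update; the sizes are illustrative assumptions and random tensors stand in for the real embeddings and attention output, so none of this is code from the repository.

import torch
import torch.nn as nn

# Illustrative sizes (assumptions; the tutorial's defaults may differ).
embed_dim, encoder_dim, decoder_dim, num_layers, batch = 16, 32, 24, 2, 4

# One LSTMCell per layer: layer 0 consumes the word embedding concatenated with
# the gated, attention-weighted encoding; deeper layers consume the hidden state
# of the layer below, so their input (and hidden) size is decoder_dim.
decode_step = nn.ModuleList([
    nn.LSTMCell(embed_dim + encoder_dim if layer == 0 else decoder_dim, decoder_dim)
    for layer in range(num_layers)
])
dropout = nn.Dropout(p=0.5)

# Per-layer hidden and cell states, kept as lists of (batch, decoder_dim) tensors,
# mirroring what init_hidden_state returns after the patch.
h = [torch.zeros(batch, decoder_dim) for _ in range(num_layers)]
c = [torch.zeros(batch, decoder_dim) for _ in range(num_layers)]

# One decoding time step: a stand-in for the concatenation of the word embedding
# and the gated attention-weighted encoding.
x = torch.randn(batch, embed_dim + encoder_dim)
for i, cell in enumerate(decode_step):
    h[i], c[i] = cell(x, (h[i], c[i]))  # update layer i's state
    x = dropout(h[i])                   # layer i's hidden state feeds layer i + 1

# The top layer's hidden state, h[-1], is what drives the attention query,
# the f_beta gate, and the vocabulary projection in the patched code.
print(h[-1].shape)  # torch.Size([4, 24])

With num_layers=1 this collapses to the original single-cell update, except that h and c are one-element lists; that is why caption.py and eval.py in PATCH 2/2 switch from h to h[-1] and re-index every layer's state during beam search.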
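Reviewer note (not part of the patches above): because h and c are now lists of per-layer tensors, the beam-search bookkeeping must gather every layer's state with the same beam indices whenever hypotheses are pruned or reordered. A small sketch of that re-indexing follows; prev_word_inds, incomplete_inds and the sizes are made-up stand-ins for the variables of the same names in caption.py and eval.py.

import torch

num_layers, k, decoder_dim = 2, 5, 24  # illustrative sizes (assumptions)

# Per-layer states for k live beams.
h = [torch.randn(k, decoder_dim) for _ in range(num_layers)]
c = [torch.randn(k, decoder_dim) for _ in range(num_layers)]

# Stand-ins for the beam bookkeeping: which beam each surviving hypothesis
# came from, and which hypotheses have not produced <end> yet.
prev_word_inds = torch.tensor([0, 0, 2, 3, 4])
incomplete_inds = [0, 2, 3]

# Every layer's hidden and cell state is gathered with the same indices,
# so all layers keep describing the same set of hypotheses.
for j in range(len(h)):
    h[j] = h[j][prev_word_inds[incomplete_inds]]
    c[j] = c[j][prev_word_inds[incomplete_inds]]

print(h[0].shape, h[1].shape)  # torch.Size([3, 24]) torch.Size([3, 24])

Gathering with mismatched indices across layers would silently mix states from different hypotheses, which is why both files loop over the layers rather than indexing a single tensor.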