implemented training switch in some functions.
Yusuke Oda committed Apr 20, 2017
1 parent 09bad4a commit 8f779c1
Showing 18 changed files with 151 additions and 117 deletions.
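The change applied across the files below follows one pattern: each encoder and decoder now receives its dropout rate once, at construction (stored as dropout_rate_), and the graph-building entry points take a boolean is_training instead of a per-call dropout ratio; the flag selects between the stored rate and 0.0f. Modules without dropout, such as BilinearAttention, accept the flag but ignore it. A minimal, self-contained sketch of that pattern; the SimpleEncoder class is hypothetical and not part of nmtkit, whose real modules apply the same idea to dynet RNN builders via rnn_.set_dropout(...):

// Sketch of the training switch introduced in this commit (hypothetical class).
class SimpleEncoder {
 public:
  explicit SimpleEncoder(float dropout_rate) : dropout_rate_(dropout_rate) {}

  // Old style: the caller passed a dropout ratio into every prepare() call.
  // New style: the caller passes a switch; the module picks the rate itself.
  float effectiveDropout(bool is_training) const {
    return is_training ? dropout_rate_ : 0.0f;
  }

 private:
  float dropout_rate_;  // fixed once, at construction time
};

int main() {
  const SimpleEncoder enc(0.3f);
  const float train_rate = enc.effectiveDropout(true);   // 0.3f during training
  const float eval_rate = enc.effectiveDropout(false);   // 0.0f during evaluation
  return (train_rate > eval_rate) ? 0 : 1;
}

One consequence visible in bin/train.cc below: call sites no longer need the configured dropout ratio at all, only whether they are building a training graph (true) or an evaluation graph (false).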

bin/train.cc (7 changes: 2 additions & 5 deletions)
@@ -327,7 +327,7 @@ float evaluateLogPerplexity(
const nmtkit::Batch batch = converter.convert(samples);
dynet::ComputationGraph cg;
dynet::expr::Expression total_loss_expr = encdec.buildTrainGraph(
- batch, 0.0, &cg);
+ batch, &cg, false);
num_outputs += batch.target_ids.size() - 1;
total_loss += static_cast<float>(
dynet::as_scalar(cg.forward(total_loss_expr)));
@@ -514,10 +514,7 @@ int main(int argc, char * argv[]) {
// Decaying factors
float lr_decay = 1.0f;
const float lr_decay_ratio = config.get<float>("Train.lr_decay_ratio");

- const float dropout_ratio = config.get<float>("Train.dropout_ratio");
const unsigned max_iteration = config.get<unsigned>("Train.max_iteration");

const string eval_type = config.get<string>("Train.evaluation_type");
const unsigned eval_interval = config.get<unsigned>(
"Train.evaluation_interval");
@@ -551,7 +548,7 @@ int main(int argc, char * argv[]) {
const nmtkit::Batch batch = batch_converter.convert(samples);
dynet::ComputationGraph cg;
dynet::expr::Expression total_loss_expr = encdec.buildTrainGraph(
- batch, dropout_ratio, &cg);
+ batch, &cg, true);
cg.forward(total_loss_expr);
cg.backward(total_loss_expr);
trainer->update(lr_decay);
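The two buildTrainGraph call sites above differ only in the final flag: evaluateLogPerplexity passes false so the loss is computed deterministically, while the training loop passes true to keep dropout active. A standalone sketch of the mechanism the flag controls; this is generic inverted dropout, not nmtkit or DyNet code:

#include <cstdlib>

// Standalone inverted dropout: at training time, kept activations are scaled
// by 1 / (1 - rate) so that evaluation can simply pass values through.
float dropout(float x, float rate, bool is_training) {
  if (!is_training || rate <= 0.0f) {
    return x;  // evaluation path: deterministic, no masking, no scaling
  }
  const bool keep = (static_cast<float>(std::rand()) / RAND_MAX) >= rate;
  return keep ? x / (1.0f - rate) : 0.0f;
}

int main() {
  const float train_out = dropout(1.0f, 0.3f, true);   // 0.0f or ~1.43f, at random
  const float eval_out = dropout(1.0f, 0.3f, false);   // always exactly 1.0f
  return (eval_out == 1.0f && train_out >= 0.0f) ? 0 : 1;
}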

nmtkit/backward_encoder.cc (19 changes: 11 additions & 8 deletions)
@@ -11,30 +11,33 @@ namespace DE = dynet::expr;
namespace nmtkit {

BackwardEncoder::BackwardEncoder(
- unsigned num_layers,
- unsigned vocab_size,
- unsigned embed_size,
- unsigned hidden_size,
+ const unsigned num_layers,
+ const unsigned vocab_size,
+ const unsigned embed_size,
+ const unsigned hidden_size,
+ const float dropout_rate,
dynet::Model * model)
: num_layers_(num_layers)
, vocab_size_(vocab_size)
, embed_size_(embed_size)
, hidden_size_(hidden_size)
+ , dropout_rate_(dropout_rate)
, rnn_(num_layers, embed_size, hidden_size, *model) {
p_lookup_ = model->add_lookup_parameters(vocab_size_, {embed_size_});
}

void BackwardEncoder::prepare(
- const float dropout_ratio,
- dynet::ComputationGraph * cg) {
- rnn_.set_dropout(dropout_ratio);
+ dynet::ComputationGraph * cg,
+ const bool is_training) {
+ rnn_.set_dropout(is_training ? dropout_rate_ : 0.0f);
rnn_.new_graph(*cg);
rnn_.start_new_sequence();
}

vector<DE::Expression> BackwardEncoder::compute(
const vector<vector<unsigned>> & input_ids,
- dynet::ComputationGraph * cg) {
+ dynet::ComputationGraph * cg,
+ const bool /* is_training */) {
vector<DE::Expression> outputs;
for (auto it = input_ids.rbegin(); it != input_ids.rend(); ++it) {
const DE::Expression embed = DE::lookup(*cg, p_lookup_, *it);

nmtkit/bahdanau_decoder.cc (30 changes: 17 additions & 13 deletions)
@@ -10,13 +10,14 @@ namespace DE = dynet::expr;
namespace nmtkit {

BahdanauDecoder::BahdanauDecoder(
- unsigned num_layers,
- unsigned vocab_size,
- unsigned in_embed_size,
- unsigned out_embed_size,
- unsigned hidden_size,
- unsigned seed_size,
- unsigned context_size,
+ const unsigned num_layers,
+ const unsigned vocab_size,
+ const unsigned in_embed_size,
+ const unsigned out_embed_size,
+ const unsigned hidden_size,
+ const unsigned seed_size,
+ const unsigned context_size,
+ const float dropout_rate,
dynet::Model * model)
: num_layers_(num_layers)
, vocab_size_(vocab_size)
@@ -25,6 +26,7 @@ BahdanauDecoder::BahdanauDecoder(
, hidden_size_(hidden_size)
, seed_size_(seed_size)
, context_size_(context_size)
+ , dropout_rate_(dropout_rate)
, dec2out_({in_embed_size + context_size + hidden_size, out_embed_size}, model)
, rnn_(num_layers, in_embed_size + context_size, hidden_size, *model)
, p_lookup_(model->add_lookup_parameters(vocab_size, {in_embed_size}))
@@ -37,8 +39,8 @@ BahdanauDecoder::BahdanauDecoder(

Decoder::State BahdanauDecoder::prepare(
const vector<DE::Expression> & seed,
- const float dropout_ratio,
- dynet::ComputationGraph * cg) {
+ dynet::ComputationGraph * cg,
+ const bool is_training) {
NMTKIT_CHECK_EQ(2 * num_layers_, seed.size(), "Invalid number of initial states.");
vector<DE::Expression> states;
for (unsigned i = 0; i < num_layers_; ++i) {
@@ -48,7 +50,7 @@ Decoder::State BahdanauDecoder::prepare(
for (unsigned i = 0; i < num_layers_; ++i) {
states.emplace_back(DE::tanh(states[i]));
}
- rnn_.set_dropout(dropout_ratio);
+ rnn_.set_dropout(is_training ? dropout_rate_ : 0.0f);
rnn_.new_graph(*cg);
rnn_.start_new_sequence(states);
dec2out_.prepare(cg);
@@ -59,9 +61,10 @@ Decoder::State BahdanauDecoder::oneStep(
const Decoder::State & state,
const vector<unsigned> & input_ids,
Attention * attention,
- dynet::ComputationGraph * cg,
dynet::expr::Expression * atten_probs,
- dynet::expr::Expression * output) {
+ dynet::expr::Expression * output,
+ dynet::ComputationGraph * cg,
+ const bool is_training) {
NMTKIT_CHECK_EQ(
1, state.positions.size(), "Invalid number of RNN positions.");
NMTKIT_CHECK_EQ(
@@ -73,7 +76,8 @@ Decoder::State BahdanauDecoder::oneStep(

// Calculation
const DE::Expression in_embed = DE::lookup(*cg, p_lookup_, input_ids);
- const vector<DE::Expression> atten_info = attention->compute(prev_h);
+ const vector<DE::Expression> atten_info = attention->compute(
+ prev_h, is_training);
const DE::Expression next_h = rnn_.add_input(
prev_pos, DE::concatenate({in_embed, atten_info[1]}));
// Note: In the original implementation, the MaxOut function is used for the

nmtkit/bidirectional_encoder.cc (22 changes: 13 additions & 9 deletions)
@@ -11,25 +11,28 @@ namespace DE = dynet::expr;
namespace nmtkit {

BidirectionalEncoder::BidirectionalEncoder(
- unsigned num_layers,
- unsigned vocab_size,
- unsigned embed_size,
- unsigned hidden_size,
+ const unsigned num_layers,
+ const unsigned vocab_size,
+ const unsigned embed_size,
+ const unsigned hidden_size,
+ const float dropout_rate,
dynet::Model * model)
: num_layers_(num_layers)
, vocab_size_(vocab_size)
, embed_size_(embed_size)
, hidden_size_(hidden_size)
+ , dropout_rate_(dropout_rate)
, rnn_fw_(num_layers, embed_size, hidden_size, *model)
, rnn_bw_(num_layers, embed_size, hidden_size, *model) {
p_lookup_ = model->add_lookup_parameters(vocab_size_, {embed_size_});
}

void BidirectionalEncoder::prepare(
- const float dropout_ratio,
- dynet::ComputationGraph * cg) {
- rnn_fw_.set_dropout(dropout_ratio);
- rnn_bw_.set_dropout(dropout_ratio);
+ dynet::ComputationGraph * cg,
+ const bool is_training) {
+ const float dr = is_training ? dropout_rate_ : 0.0f;
+ rnn_fw_.set_dropout(dr);
+ rnn_bw_.set_dropout(dr);
rnn_fw_.new_graph(*cg);
rnn_bw_.new_graph(*cg);
rnn_fw_.start_new_sequence();
@@ -38,7 +41,8 @@ void BidirectionalEncoder::prepare(

vector<DE::Expression> BidirectionalEncoder::compute(
const vector<vector<unsigned>> & input_ids,
- dynet::ComputationGraph * cg) {
+ dynet::ComputationGraph * cg,
+ const bool /* is_training */) {
const int input_len = input_ids.size();

// Embedding lookup

nmtkit/bilinear_attention.cc (10 changes: 6 additions & 4 deletions)
@@ -12,8 +12,8 @@ namespace DE = dynet::expr;
namespace nmtkit {

BilinearAttention::BilinearAttention(
- unsigned memory_size,
- unsigned controller_size,
+ const unsigned memory_size,
+ const unsigned controller_size,
dynet::Model * model) {
NMTKIT_CHECK(
memory_size > 0, "memory_size should be greater than 0.");
@@ -25,7 +25,8 @@ BilinearAttention::BilinearAttention(

void BilinearAttention::prepare(
const vector<DE::Expression> & memories,
- dynet::ComputationGraph * cg) {
+ dynet::ComputationGraph * cg,
+ const bool /* is_training */) {
// Concatenated memory matrix.
// Shape: {memory_size, seq_length}
i_concat_mem_ = DE::concatenate_cols(memories);
@@ -40,7 +41,8 @@ void BilinearAttention::prepare(
}

vector<DE::Expression> BilinearAttention::compute(
- const DE::Expression & controller) {
+ const DE::Expression & controller,
+ const bool /* is_training */) {
// Computes attention.
// Shape: {seq_length, 1}
DE::Expression atten_probs_inner = DE::softmax(i_converted_mem_ * controller);
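A note on the const bool /* is_training */ parameters above: BilinearAttention has no dropout, so the flag is unused, but prepare() and compute() belong to the shared Attention interface (BilinearAttention derives from Attention, and decoders call compute through an Attention pointer), so the parameter must stay. Commenting out only the name keeps compilers from warning about an unused parameter. A small illustration of the idiom with hypothetical types, not nmtkit code:

// Hypothetical interface/implementation pair illustrating the commented-out
// parameter-name idiom.
struct Scorer {
  virtual ~Scorer() {}
  virtual float score(float x, bool is_training) = 0;
};

struct FixedScorer : Scorer {
  // The flag is required by the base-class signature but unused here, so the
  // name is commented out to keep -Wunused-parameter quiet.
  float score(float x, bool /* is_training */) override {
    return 2.0f * x;
  }
};

int main() {
  FixedScorer scorer;
  Scorer & s = scorer;
  return (s.score(2.0f, false) == 4.0f) ? 0 : 1;
}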

nmtkit/bilinear_attention.h (4 changes: 2 additions & 2 deletions)
@@ -27,8 +27,8 @@ class BilinearAttention : public Attention {
// controller_size: Number of units in the controller input.
// model: Model object for training.
BilinearAttention(
- unsigned memory_size,
- unsigned controller_size,
+ const unsigned memory_size,
+ const unsigned controller_size,
dynet::Model * model);

~BilinearAttention() override {}

nmtkit/default_decoder.cc (28 changes: 16 additions & 12 deletions)
@@ -10,19 +10,21 @@ namespace DE = dynet::expr;
namespace nmtkit {

DefaultDecoder::DefaultDecoder(
- unsigned num_layers,
- unsigned vocab_size,
- unsigned embed_size,
- unsigned hidden_size,
- unsigned seed_size,
- unsigned context_size,
+ const unsigned num_layers,
+ const unsigned vocab_size,
+ const unsigned embed_size,
+ const unsigned hidden_size,
+ const unsigned seed_size,
+ const unsigned context_size,
+ const float dropout_rate,
dynet::Model * model)
: num_layers_(num_layers)
, vocab_size_(vocab_size)
, embed_size_(embed_size)
, hidden_size_(hidden_size)
, seed_size_(seed_size)
, context_size_(context_size)
+ , dropout_rate_(dropout_rate)
, rnn_(num_layers, embed_size + context_size, hidden_size, *model)
, p_lookup_(model->add_lookup_parameters(vocab_size, {embed_size}))
{
@@ -34,8 +36,8 @@ DefaultDecoder::DefaultDecoder(

Decoder::State DefaultDecoder::prepare(
const vector<DE::Expression> & seed,
- const float dropout_ratio,
- dynet::ComputationGraph * cg) {
+ dynet::ComputationGraph * cg,
+ const bool is_training) {
NMTKIT_CHECK_EQ(2 * num_layers_, seed.size(), "Invalid number of initial states.");
vector<DE::Expression> states;
for (unsigned i = 0; i < num_layers_; ++i) {
@@ -45,7 +47,7 @@ Decoder::State DefaultDecoder::prepare(
for (unsigned i = 0; i < num_layers_; ++i) {
states.emplace_back(DE::tanh(states[i]));
}
- rnn_.set_dropout(dropout_ratio);
+ rnn_.set_dropout(is_training ? dropout_rate_ : 0.0f);
rnn_.new_graph(*cg);
rnn_.start_new_sequence(states);
return {{rnn_.state()}, {rnn_.back()}};
@@ -55,9 +57,10 @@ Decoder::State DefaultDecoder::oneStep(
const Decoder::State & state,
const vector<unsigned> & input_ids,
Attention * attention,
- dynet::ComputationGraph * cg,
dynet::expr::Expression * atten_probs,
- dynet::expr::Expression * output) {
+ dynet::expr::Expression * output,
+ dynet::ComputationGraph * cg,
+ const bool is_training) {
NMTKIT_CHECK_EQ(
1, state.positions.size(), "Invalid number of RNN positions.");
NMTKIT_CHECK_EQ(
@@ -69,7 +72,8 @@ Decoder::State DefaultDecoder::oneStep(

// Calculation
const DE::Expression embed = DE::lookup(*cg, p_lookup_, input_ids);
const vector<DE::Expression> atten_info = attention->compute(prev_h);
const vector<DE::Expression> atten_info = attention->compute(
prev_h, is_training);
const DE::Expression next_h = rnn_.add_input(
prev_pos, DE::concatenate({embed, atten_info[1]}));

(Diffs for the remaining 11 changed files were not loaded.)
