Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix BLEU score calculation during training #4

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/bin/train.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <dynet/tensor.h>
#include <dynet/training.h>
#include <mteval/EvaluatorFactory.h>
#include <mteval/Dictionary.h>
#include <nmtkit/batch_converter.h>
#include <nmtkit/character_vocabulary.h>
#include <nmtkit/encoder_decoder.h>
Expand All @@ -30,6 +31,7 @@
#include <nmtkit/init.h>
#include <nmtkit/monotone_sampler.h>
#include <nmtkit/sorted_random_sampler.h>
#include <nmtkit/test_sampler.h>
#include <nmtkit/word_vocabulary.h>
#include <spdlog/spdlog.h>

Expand Down Expand Up @@ -270,7 +272,7 @@ void saveArchive(
// The log perplexity score.
float evaluateLogPerplexity(
nmtkit::EncoderDecoder & encdec,
nmtkit::MonotoneSampler & sampler,
nmtkit::TestSampler & sampler,
nmtkit::BatchConverter & converter) {
unsigned num_outputs = 0;
float total_loss = 0.0f;
Expand Down Expand Up @@ -302,15 +304,16 @@ float evaluateLogPerplexity(
float evaluateBLEU(
const nmtkit::Vocabulary & trg_vocab,
nmtkit::EncoderDecoder & encdec,
nmtkit::MonotoneSampler & sampler,
nmtkit::TestSampler & sampler,
const unsigned max_length) {
const auto evaluator = MTEval::EvaluatorFactory::create("BLEU");
const unsigned bos_id = trg_vocab.getID("<s>");
const unsigned eos_id = trg_vocab.getID("</s>");
MTEval::Dictionary dict;
MTEval::Statistics stats;
sampler.rewind();
while (sampler.hasSamples()) {
vector<nmtkit::Sample> samples = sampler.getSamples();
vector<nmtkit::TestSample> samples = sampler.getTestSamples();
nmtkit::InferenceGraph ig = encdec.infer(
samples[0].source, bos_id, eos_id, max_length, 1, 0.0f);
const auto hyp_nodes = ig.findOneBestPath(bos_id, eos_id);
Expand All @@ -319,7 +322,10 @@ float evaluateBLEU(
for (unsigned i = 1; i < hyp_nodes.size() - 1; ++i) {
hyp_ids.emplace_back(hyp_nodes[i]->label().word_id);
}
MTEval::Sample eval_sample {hyp_ids, {samples[0].target}};
const string string_hyp = trg_vocab.convertToSentence(hyp_ids);
const MTEval::Sentence sent_hyp = dict.getSentence(string_hyp);
const MTEval::Sentence sent_ref = dict.getSentence(samples[0].target_string);
MTEval::Sample eval_sample {sent_hyp, {sent_ref}};
stats += evaluator->map(eval_sample);
}
return evaluator->integrate(stats);
Expand Down Expand Up @@ -381,10 +387,8 @@ int main(int argc, char * argv[]) {

// Maximum lengths
const unsigned train_max_length = config.get<unsigned>("Batch.max_length");
const unsigned test_max_length = 1024;
const float train_max_length_ratio = config.get<float>(
"Batch.max_length_ratio");
const float test_max_length_ratio = 1e10;

// Creates samplers and batch converter.
nmtkit::SortedRandomSampler train_sampler(
Expand All @@ -398,15 +402,15 @@ int main(int argc, char * argv[]) {
config.get<unsigned>("Global.random_seed"));
const unsigned corpus_size = train_sampler.getNumSamples();
logger->info("Loaded 'train' corpus.");
nmtkit::MonotoneSampler dev_sampler(
nmtkit::TestSampler dev_sampler(
config.get<string>("Corpus.dev_source"),
config.get<string>("Corpus.dev_target"),
*src_vocab, *trg_vocab, test_max_length, test_max_length_ratio, 1);
*src_vocab, *trg_vocab, 1);
logger->info("Loaded 'dev' corpus.");
nmtkit::MonotoneSampler test_sampler(
nmtkit::TestSampler test_sampler(
config.get<string>("Corpus.test_source"),
config.get<string>("Corpus.test_target"),
*src_vocab, *trg_vocab, test_max_length, test_max_length_ratio, 1);
*src_vocab, *trg_vocab, 1);
logger->info("Loaded 'test' corpus.");
const auto fmt_corpus_size = boost::format(
"Cleaned corpus size: train=%d dev=%d test=%d")
Expand Down
2 changes: 2 additions & 0 deletions src/include/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@ nobase_include_HEADERS = \
nmtkit/single_text_formatter.h \
nmtkit/softmax_predictor.h \
nmtkit/sorted_random_sampler.h \
nmtkit/test_corpus.h \
nmtkit/test_sampler.h \
nmtkit/vocabulary.h \
nmtkit/word_vocabulary.h
14 changes: 14 additions & 0 deletions src/include/nmtkit/basic_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ struct Sample {
std::vector<unsigned> target;
};

// A sample that carries both the ID-converted and the raw string form of a
// source/target sentence pair, so evaluation code (e.g. BLEU) can score
// against the original surface strings instead of re-detokenized IDs.
struct TestSample {
  // Source sentence as word IDs.
  std::vector<unsigned> source;

  // Target sentence as word IDs.
  std::vector<unsigned> target;

  // Source sentence as its raw string.
  std::string source_string;

  // Target sentence as its raw string.
  std::string target_string;
};

struct Batch {
// Source word ID table with shape (max_source_length, batch_size).
std::vector<std::vector<unsigned>> source_ids;
Expand Down
4 changes: 4 additions & 0 deletions src/include/nmtkit/character_vocabulary.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ class CharacterVocabulary : public Vocabulary {
unsigned getFrequency(const unsigned id) const override;
std::vector<unsigned> convertToIDs(
const std::string & sentence) const override;
std::vector<std::string> convertToTokens(
const std::string & sentence) const override;
std::string convertToSentence(
const std::vector<unsigned> & word_ids) const override;
std::vector<std::string> convertToTokenizedSentence(
const std::vector<unsigned> & word_ids) const override;
unsigned size() const override;

private:
Expand Down
116 changes: 116 additions & 0 deletions src/include/nmtkit/test_corpus.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#ifndef NMTKIT_TEST_CORPUS_H_
#define NMTKIT_TEST_CORPUS_H_

#include <istream>
#include <string>
#include <vector>
#include <nmtkit/basic_types.h>
#include <nmtkit/vocabulary.h>
#include <nmtkit/corpus.h>

namespace nmtkit {

// Static utility class to load test corpora, keeping both the word-ID and
// the raw-string representation of every sentence (see TestSample).
class TestCorpus : public Corpus {
  TestCorpus() = delete;
  TestCorpus(const TestCorpus &) = delete;
  TestCorpus(TestCorpus &&) = delete;
  TestCorpus & operator=(const TestCorpus &) = delete;
  TestCorpus & operator=(TestCorpus &&) = delete;

 public:
  // Reads one line from the input stream.
  //
  // Arguments:
  //   is: Target input stream.
  //   line: Placeholder to store new string. Old data will be deleted before
  //         storing the new value.
  //
  // Returns:
  //   true if reading completed successfully, false otherwise (e.g. EOF).
  static bool readLine(std::istream * is, std::string * line);

  // Reads one line from the input stream and converts it into word IDs,
  // also returning the raw sentence string.
  //
  // Arguments:
  //   vocab: Vocabulary object to be used to convert words into word IDs.
  //   is: Target input stream.
  //   word_ids: Placeholder to store new words. Old data will be deleted
  //             automatically before storing new samples.
  //   sent_string: Placeholder to store the raw sentence string. Old data
  //                will be deleted automatically before storing the new value.
  //
  // Returns:
  //   true if reading completed successfully, false otherwise (e.g. EOF).
  static bool readTokens(
      const Vocabulary & vocab,
      std::istream * is,
      std::vector<unsigned> * word_ids,
      std::string * sent_string);

  // Loads all samples in the tokenized corpus.
  //
  // Arguments:
  //   filepath: Location of the corpus file.
  //   vocab: Vocabulary object for the corpus language.
  //   result: Placeholder to store new samples (word IDs). Old data will be
  //           deleted automatically before storing new samples.
  //   string_result: Placeholder to store the raw sentence strings. Old data
  //                  will be deleted automatically before storing new samples.
  static void loadSingleSentences(
      const std::string & filepath,
      const Vocabulary & vocab,
      std::vector<std::vector<unsigned>> * result,
      std::vector<std::string> * string_result);

  // Loads a tokenized parallel corpus into separate source/target vectors.
  // Unlike Corpus::loadParallelSentences, this overload takes no length
  // filters: every sentence pair is kept.
  //
  // Arguments:
  //   src_filepath: Location of the source corpus file.
  //   trg_filepath: Location of the target corpus file.
  //   src_vocab: Vocabulary object for the source language.
  //   trg_vocab: Vocabulary object for the target language.
  //   src_result: Placeholder to store new source samples. Old data will be
  //               deleted automatically before storing new samples.
  //   trg_result: Placeholder to store new target samples. Old data will be
  //               deleted automatically before storing new samples.
  //   src_string_result: Placeholder to store the raw source sentence
  //                      strings. Old data will be deleted automatically
  //                      before storing new samples.
  //   trg_string_result: Placeholder to store the raw target sentence
  //                      strings. Old data will be deleted automatically
  //                      before storing new samples.
  static void loadParallelSentences(
      const std::string & src_filepath,
      const std::string & trg_filepath,
      const Vocabulary & src_vocab,
      const Vocabulary & trg_vocab,
      std::vector<std::vector<unsigned>> * src_result,
      std::vector<std::vector<unsigned>> * trg_result,
      std::vector<std::string> * src_string_result,
      std::vector<std::string> * trg_string_result);

  // Loads a tokenized parallel corpus directly into TestSample objects.
  //
  // Arguments:
  //   src_filepath: Location of the source corpus file.
  //   trg_filepath: Location of the target corpus file.
  //   src_vocab: Vocabulary object for the source language.
  //   trg_vocab: Vocabulary object for the target language.
  //   result: Placeholder to store new source/target samples. Old data will
  //           be deleted automatically before storing new samples.
  static void loadParallelSentences(
      const std::string & src_filepath,
      const std::string & trg_filepath,
      const Vocabulary & src_vocab,
      const Vocabulary & trg_vocab,
      std::vector<TestSample> * result);
};

} // namespace nmtkit

#endif // NMTKIT_TEST_CORPUS_H_
43 changes: 43 additions & 0 deletions src/include/nmtkit/test_sampler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef NMTKIT_TEST_SAMPLER_H_
#define NMTKIT_TEST_SAMPLER_H_

#include <string>
#include <vector>
#include <nmtkit/sampler.h>
#include <nmtkit/vocabulary.h>

namespace nmtkit {

// Sampler for test/dev corpora. Iterates the corpus in order and keeps the
// raw sentence strings alongside the word-ID sequences (see TestSample),
// e.g. so BLEU can be computed against the original references.
class TestSampler : public Sampler {
  TestSampler() = delete;
  TestSampler(const TestSampler &) = delete;
  TestSampler(TestSampler &&) = delete;
  TestSampler & operator=(const TestSampler &) = delete;
  TestSampler & operator=(TestSampler &&) = delete;

 public:
  // Creates a sampler over a parallel corpus.
  //
  // Arguments:
  //   src_filepath: Location of the source corpus file.
  //   trg_filepath: Location of the target corpus file.
  //   src_vocab: Vocabulary object for the source language.
  //   trg_vocab: Vocabulary object for the target language.
  //   batch_size: Number of samples returned per call.
  TestSampler(
      const std::string & src_filepath,
      const std::string & trg_filepath,
      const Vocabulary & src_vocab,
      const Vocabulary & trg_vocab,
      unsigned batch_size);

  ~TestSampler() override {}

  // Rewinds iteration to the beginning of the corpus.
  void rewind() override;

  // Retrieves the next batch as ID-only Sample objects.
  std::vector<Sample> getSamples() override;

  // Retrieves the next batch as TestSample objects, which additionally carry
  // the raw source/target strings.
  // NOTE(review): presumably shares the same cursor (current_) as
  // getSamples() — confirm against the implementation before interleaving
  // calls to both.
  std::vector<TestSample> getTestSamples();

  // Returns the total number of samples in the corpus.
  unsigned getNumSamples() override;

  // Returns whether any samples remain before the end of the corpus.
  bool hasSamples() const override;

 private:
  // Raw sentence strings, parallel to the ID vectors below.
  std::vector<std::string> src_samples_string_;
  std::vector<std::string> trg_samples_string_;
  // Word-ID sequences for each sentence.
  std::vector<std::vector<unsigned>> src_samples_;
  std::vector<std::vector<unsigned>> trg_samples_;
  // Number of samples returned per call.
  unsigned batch_size_;
  // Index of the next sample to be returned.
  unsigned current_;
};

} // namespace nmtkit

#endif // NMTKIT_TEST_SAMPLER_H_
20 changes: 20 additions & 0 deletions src/include/nmtkit/vocabulary.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ class Vocabulary {
// List of word IDs that represents given sentence.
virtual std::vector<unsigned> convertToIDs(
const std::string & sentence) const = 0;

// Converts a sentence into a list of words (tokens).
//
// Arguments:
// sentence: A sentence string.
//
// Returns:
// List of sentence words.
virtual std::vector<std::string> convertToTokens(
const std::string & sentence) const = 0;

// Converts a list of word IDs into a sentence.
//
Expand All @@ -65,6 +75,16 @@ class Vocabulary {
//   Generated sentence string.
virtual std::string convertToSentence(
const std::vector<unsigned> & word_ids) const = 0;

// Converts a list of word IDs into a list of words (tokens).
//
// Arguments:
// word_ids: A list of word IDs.
//
// Returns:
//   Generated list of sentence words.
virtual std::vector<std::string> convertToTokenizedSentence(
const std::vector<unsigned> & word_ids) const = 0;

// Retrieves the size of the vocabulary.
//
Expand Down
4 changes: 4 additions & 0 deletions src/include/nmtkit/word_vocabulary.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ class WordVocabulary : public Vocabulary {
unsigned getFrequency(const unsigned id) const override;
std::vector<unsigned> convertToIDs(
const std::string & sentence) const override;
std::vector<std::string> convertToTokens(
const std::string & sentence) const override;
std::string convertToSentence(
const std::vector<unsigned> & word_ids) const override;
std::vector<std::string> convertToTokenizedSentence(
const std::vector<unsigned> & word_ids) const override;
unsigned size() const override;

private:
Expand Down
2 changes: 2 additions & 0 deletions src/lib/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ libnmtkit_la_SOURCES = \
single_text_formatter.cc \
softmax_predictor.cc \
sorted_random_sampler.cc \
test_corpus.cc \
test_sampler.cc \
vocabulary.cc \
word_vocabulary.cc

18 changes: 18 additions & 0 deletions src/lib/character_vocabulary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ vector<unsigned> CharacterVocabulary::convertToIDs(
return ids;
}

// Splits the sentence into letters and round-trips each one through
// getID()/getWord(), so the returned tokens are the vocabulary's canonical
// surface forms (letters outside the vocabulary are normalized by getID —
// presumably to the unknown symbol; verify against getID's contract).
vector<string> CharacterVocabulary::convertToTokens(
    const string & sentence) const {
  const vector<string> letters = ::convertToLetters(sentence);
  vector<string> result;
  result.reserve(letters.size());
  for (const string & letter : letters) {
    result.push_back(getWord(getID(letter)));
  }
  return result;
}

string CharacterVocabulary::convertToSentence(
const vector<unsigned> & word_ids) const {
vector<string> letters;
Expand All @@ -131,6 +140,15 @@ string CharacterVocabulary::convertToSentence(
return boost::join(letters, "");
}

// Maps each word ID back to its surface string, returning the letters as a
// token list instead of joining them into one sentence string
// (cf. convertToSentence).
vector<string> CharacterVocabulary::convertToTokenizedSentence(
    const vector<unsigned> & word_ids) const {
  vector<string> result;
  result.reserve(word_ids.size());
  for (const unsigned id : word_ids) {
    result.push_back(getWord(id));
  }
  return result;
}

// Returns the number of entries (letters) in the vocabulary.
// NOTE(review): itos_.size() is size_t; implicit narrowing to unsigned —
// fine for realistic vocabulary sizes, but confirm no -Wconversion warning.
unsigned CharacterVocabulary::size() const {
  return itos_.size();
}
Expand Down
Loading