diff --git a/src/lib/bpe_vocabulary.cc b/src/lib/bpe_vocabulary.cc index b16da93..1b7ea45 100644 --- a/src/lib/bpe_vocabulary.cc +++ b/src/lib/bpe_vocabulary.cc @@ -32,20 +32,18 @@ struct Change { // vocab: vector bigram frequency. // stats: sum of bigram frequency. // indices: index of stats (key=bigram) -void getPairStatistics(vector, int>> * vocab, +void getPairStatistics( + const vector, int>> & vocab, map, int> * stats, map, map> * indices) { - for (unsigned i = 0; i < vocab->size(); i++) { - const vector word = (*vocab)[i].first; - const int freq = (*vocab)[i].second; - string prev_char = word[0]; + for (unsigned i = 0; i < vocab.size(); i++) { + const vector & word = vocab[i].first; + const int freq = vocab[i].second; for (unsigned j = 1; j < word.size(); j++) { - const string current_char = word[j]; - const vector key = {prev_char, current_char}; + const vector key = {word[j-1], word[j]}; (*stats)[key] += freq; - (*indices)[key][i] += 1; - prev_char = current_char; + ++(*indices)[key][i]; } } } @@ -57,10 +55,10 @@ void getPairStatistics(vector, int>> * vocab, // // Returns: // most frequent bigram -vector findMax(const map, int> * stats) { - int current_max = -1e5; +vector findMax(const map, int> & stats) { + int current_max = -1000000; vector current_argmax; - for (auto elm : (*stats)) { + for (const auto & elm : stats) { if (elm.second > current_max) { current_max = elm.second; current_argmax = elm.first; @@ -78,30 +76,30 @@ vector findMax(const map, int> * stats) { // // Returns: // vector of replaceable pairs -vector replacePair(const vector * replace_words, - vector, int>> * vocab, - const map * indices) { - string first = (*replace_words)[0]; - string second = (*replace_words)[1]; - string pair_str = boost::join((*replace_words), ""); +vector replacePair( + const vector & replace_words, + const map & indices, + vector, int>> * vocab) { + + const string before = replace_words[0] + " " + replace_words[1]; + const string after = replace_words[0] + replace_words[1]; vector changes; - for (const auto index : (*indices)) { - unsigned j = index.first; - int freq = index.second; - if (freq < 1) { - continue; + for (const auto & index : indices) { + if (index.second < 1) { + continue; } - vector word = (*vocab)[j].first; - freq = (*vocab)[j].second; - string new_word = boost::join(word, " "); - boost::replace_all(new_word, first + " " + second, pair_str); - vector vector_new_word; + const unsigned j = index.first; + const vector & subwords = (*vocab)[j].first; + const int freq = (*vocab)[j].second; + string new_word = boost::join(subwords, " "); + boost::replace_all(new_word, before, after); + vector new_subwords; boost::split( - vector_new_word, new_word, boost::is_space(), boost::algorithm::token_compress_on); - (*vocab)[j] = pair, int>(vector_new_word, freq); - changes.emplace_back( Change{j, vector_new_word, word, freq} ); + new_subwords, new_word, boost::is_space(), boost::algorithm::token_compress_on); + (*vocab)[j] = std::make_pair(new_subwords, freq); + changes.emplace_back( Change {j, new_subwords, subwords, freq} ); } return changes; @@ -110,16 +108,22 @@ vector replacePair(const vector * replace_words, // Find index of the specific word from the vector // // Arguments: -// word: vector of words -// search_word: search query word -// start_index: start finding from this index +// words: vector of words +// query: search query word +// start: start finding from this index // // Returns: // index of the specific word -int findIndex(vector * word, string * search_word, unsigned start_index) { - auto iter = find(word->begin() + start_index, word->end(), (*search_word)); - size_t index = distance(word->begin(), iter); - return index; +int findIndex( + const vector & words, const string & query, const unsigned start) { + for (unsigned i = start; i < words.size(); ++i) { + if (words[i] == query) { + return i; + } + } + NMTKIT_FATAL("query not found."); + //const auto & iter = find(words.begin() + start_index, words.end(), query); + //return static_cast(distance(word->begin(), iter)); } // Update stats based on changes @@ -129,65 +133,63 @@ int findIndex(vector * word, string * search_word, unsigned start_index) // changes: return value of replacePair() // stats: sum of bigram frequency. // indices: index of stats (key=bigram) -void updatePairStatistics(const vector * replace_words, - const vector * changes, +void updatePairStatistics( + const vector & replace_words, + const vector & changes, map, int> * stats, map, map> * indices) { - stats->erase((*replace_words)); - indices->erase((*replace_words)); - string first = (*replace_words)[0]; - string second = (*replace_words)[1]; - string new_pair = first + second; - - for (unsigned i = 0; i < changes->size(); i++) { - unsigned j = (*changes)[i].index; - vector new_word = (*changes)[i].new_word; - vector old_word = (*changes)[i].old_word; - int freq = (*changes)[i].freq; + stats->erase(replace_words); + indices->erase(replace_words); + const string & first = replace_words[0]; + const string & second = replace_words[1]; + const string new_pair = first + second; + for (const Change & change : changes) { unsigned k = 0; - while(true) { - k = findIndex(&old_word, &first, k); - if (k == old_word.size()) { + while (true) { + k = findIndex(change.old_word, first, k); + if (k == change.old_word.size()) { break; } - if (k < old_word.size() - 1 and old_word[k+1] == second) { + if (k < change.old_word.size() - 1 and change.old_word[k+1] == second) { if (k != 0) { - vector prev = {old_word[k-1], old_word[k]}; - (*stats)[prev] -= freq; - (*indices)[prev][j] -= 1; + vector prev = {change.old_word[k-1], change.old_word[k]}; + (*stats)[prev] -= change.freq; + --(*indices)[prev][change.index]; } - if (k < old_word.size() - 2) { - if (old_word[k+2] != first or k >= old_word.size() - 3 or old_word[k+3] != second) { - vector nex = {old_word[k+1], old_word[k+2]}; - (*stats)[nex] -= freq; - (*indices)[nex][j] -= 1; + if (k < change.old_word.size() - 2) { + if (change.old_word[k+2] != first or + k >= change.old_word.size() - 3 or + change.old_word[k+3] != second) { + vector nex = {change.old_word[k+1], change.old_word[k+2]}; + (*stats)[nex] -= change.freq; + --(*indices)[nex][change.index]; } } k += 2; } else { - k += 1; + ++k; } } k = 0; while(true) { - k = findIndex(&new_word, &new_pair, k); - if (k == new_word.size()) { + k = findIndex(change.new_word, new_pair, k); + if (k == change.new_word.size()) { break; } if (k != 0) { - vector prev = {new_word[k-1], new_word[k]}; - (*stats)[prev] += freq; - (*indices)[prev][j] += 1; + vector prev = {change.new_word[k-1], change.new_word[k]}; + (*stats)[prev] += change.freq; + ++(*indices)[prev][change.index]; } - if (k < new_word.size() - 1 and new_word[k+1] != new_pair) { - vector nex = {new_word[k], new_word[k+1]}; - (*stats)[nex] += freq; - (*indices)[nex][j] += 1; + if (k < change.new_word.size() - 1 and change.new_word[k+1] != new_pair) { + vector nex = {change.new_word[k], change.new_word[k+1]}; + (*stats)[nex] += change.freq; + ++(*indices)[nex][change.index]; } - k += 1; + ++k; } } } @@ -202,7 +204,7 @@ void pruneStats( map, int> * stats, map, int> * big_stats, const int threshold) { - map, int>::iterator it = stats->begin(); + auto it = stats->begin(); while (it != stats->end()) { vector item = it->first; int freq = it->second; @@ -210,12 +212,10 @@ void pruneStats( stats->erase(it++); if (freq < 0) { (*big_stats)[item] += freq; - } - else { + } else { (*big_stats)[item] = freq; } - } - else { + } else { ++it; } } @@ -228,12 +228,10 @@ void pruneStats( // // Returns: // pairs of character bigram -vector> getPairs(vector * word) { +vector> getPairs(const vector & word) { vector> pairs; - string prev_char = (*word)[0]; - for (unsigned i = 1; i < word->size(); i++) { - pairs.emplace_back(pair(prev_char, (*word)[i])); - prev_char = (*word)[i]; + for (unsigned i = 1; i < word.size(); ++i) { + pairs.emplace_back(std::make_pair(word[i-1], word[i])); } return pairs; } @@ -246,26 +244,26 @@ vector> getPairs(vector * word) { // bpe_cache: BPE converted words // Returns: // BPE words -vector encode(const string * orig, - const map, unsigned> * bpe_codes, +vector encode( + const string & orig, + const map, unsigned> & bpe_codes, map> * bpe_cache) { // if exists in bpe_cache - const auto &entry = bpe_cache->find(*orig); + const auto & entry = bpe_cache->find(orig); if (entry != bpe_cache->end()) { return entry->second; } - vector word = UTF8::getLetters(*orig); + vector word = UTF8::getLetters(orig); word.emplace_back(""); - vector> pairs = getPairs(&word); + vector> pairs = getPairs(word); while (true) { unsigned min_bigram = UINT_MAX; unsigned argmin_bigram = 0; for (unsigned i = 0; i < pairs.size(); i++) { - const auto &entry = bpe_codes->find(pairs[i]); - if (entry != bpe_codes->end() and - entry->second < min_bigram) { + const auto & entry = bpe_codes.find(pairs[i]); + if (entry != bpe_codes.end() and entry->second < min_bigram) { min_bigram = entry->second; argmin_bigram = i; } @@ -278,7 +276,7 @@ vector encode(const string * orig, vector new_word; unsigned i = 0; while (i < word.size()) { - unsigned j = findIndex(&word, &first, i); + unsigned j = findIndex(word, first, i); if (j == word.size()) { copy(word.begin() + i, word.end(), back_inserter(new_word)); break; @@ -298,11 +296,11 @@ vector encode(const string * orig, if (word.size() == 1) { break; } else { - pairs = getPairs(&word); + pairs = getPairs(word); } } - (*bpe_cache)[*orig] = word; + (*bpe_cache)[orig] = word; return word; } @@ -381,20 +379,20 @@ BPEVocabulary::BPEVocabulary(const string & corpus_filename, unsigned size) { map, int> stats; map, map> indices; - getPairStatistics(&vector_vocab, &stats, &indices); + getPairStatistics(vector_vocab, &stats, &indices); map, int> big_stats = stats; - int threshold = stats[findMax(&stats)] / 10; + int threshold = stats[findMax(stats)] / 10; unsigned num_letter_vocab = stoi_.size(); for (unsigned i = 0; i < size - num_letter_vocab; i++) { vector most_frequent_index; if (!stats.empty()) { - most_frequent_index = findMax(&stats); + most_frequent_index = findMax(stats); } if (stats.empty() or (i != 0 and stats[most_frequent_index] < threshold)) { pruneStats(&stats, &big_stats, threshold); stats = big_stats; - most_frequent_index = findMax(&stats); + most_frequent_index = findMax(stats); threshold = stats[most_frequent_index] * i/(i+10000.0); pruneStats(&stats, &big_stats, threshold); } @@ -413,8 +411,8 @@ BPEVocabulary::BPEVocabulary(const string & corpus_filename, unsigned size) { } vector changes = - replacePair(&most_frequent_index, &vector_vocab, &indices[most_frequent_index]); - updatePairStatistics(&most_frequent_index, &changes, &stats, &indices); + replacePair(most_frequent_index, indices[most_frequent_index], &vector_vocab); + updatePairStatistics(most_frequent_index, changes, &stats, &indices); stats[most_frequent_index] = 0; if (i % 100 == 0) { @@ -446,7 +444,7 @@ vector BPEVocabulary::convertToIDs(const string & sentence) const { words, sentence, boost::is_space(), boost::algorithm::token_compress_on); vector ids; for (const string & word : words) { - vector new_words = encode(&word, &bpe_codes_, &bpe_cache_); + vector new_words = encode(word, bpe_codes_, &bpe_cache_); for (const string & new_word : new_words) { ids.emplace_back(getID(new_word)); }