Skip to content

Commit

Permalink
pointer to reference
Browse files Browse the repository at this point in the history
  • Loading branch information
Yusuke Oda committed Apr 7, 2017
1 parent 4f22b52 commit 49b397f
Showing 1 changed file with 100 additions and 102 deletions.
202 changes: 100 additions & 102 deletions src/lib/bpe_vocabulary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,18 @@ struct Change {
// vocab: vector bigram frequency.
// stats: sum of bigram frequency.
// indices: index of stats (key=bigram)
void getPairStatistics(vector<pair<vector<string>, int>> * vocab,
void getPairStatistics(
const vector<pair<vector<string>, int>> & vocab,
map<vector<string>, int> * stats,
map<vector<string>, map<unsigned, int>> * indices) {

for (unsigned i = 0; i < vocab->size(); i++) {
const vector<string> word = (*vocab)[i].first;
const int freq = (*vocab)[i].second;
string prev_char = word[0];
for (unsigned i = 0; i < vocab.size(); i++) {
const vector<string> & word = vocab[i].first;
const int freq = vocab[i].second;
for (unsigned j = 1; j < word.size(); j++) {
const string current_char = word[j];
const vector<string> key = {prev_char, current_char};
const vector<string> key = {word[j-1], word[j]};
(*stats)[key] += freq;
(*indices)[key][i] += 1;
prev_char = current_char;
++(*indices)[key][i];
}
}
}
Expand All @@ -57,10 +55,10 @@ void getPairStatistics(vector<pair<vector<string>, int>> * vocab,
//
// Returns:
// most frequent bigram
vector<string> findMax(const map<vector<string>, int> * stats) {
int current_max = -1e5;
vector<string> findMax(const map<vector<string>, int> & stats) {
int current_max = -1000000;
vector<string> current_argmax;
for (auto elm : (*stats)) {
for (const auto & elm : stats) {
if (elm.second > current_max) {
current_max = elm.second;
current_argmax = elm.first;
Expand All @@ -78,30 +76,30 @@ vector<string> findMax(const map<vector<string>, int> * stats) {
//
// Returns:
// vector of replaceable pairs
vector<Change> replacePair(const vector<string> * replace_words,
vector<pair<vector<string>, int>> * vocab,
const map<unsigned, int> * indices) {
string first = (*replace_words)[0];
string second = (*replace_words)[1];
string pair_str = boost::join((*replace_words), "");
vector<Change> replacePair(
const vector<string> & replace_words,
const map<unsigned, int> & indices,
vector<pair<vector<string>, int>> * vocab) {

const string before = replace_words[0] + " " + replace_words[1];
const string after = replace_words[0] + replace_words[1];
vector<Change> changes;

for (const auto index : (*indices)) {
unsigned j = index.first;
int freq = index.second;
if (freq < 1) {
continue;
for (const auto & index : indices) {
if (index.second < 1) {
continue;
}

vector<string> word = (*vocab)[j].first;
freq = (*vocab)[j].second;
string new_word = boost::join(word, " ");
boost::replace_all(new_word, first + " " + second, pair_str);
vector<string> vector_new_word;
const unsigned j = index.first;
const vector<string> & subwords = (*vocab)[j].first;
const int freq = (*vocab)[j].second;
string new_word = boost::join(subwords, " ");
boost::replace_all(new_word, before, after);
vector<string> new_subwords;
boost::split(
vector_new_word, new_word, boost::is_space(), boost::algorithm::token_compress_on);
(*vocab)[j] = pair<vector<string>, int>(vector_new_word, freq);
changes.emplace_back( Change{j, vector_new_word, word, freq} );
new_subwords, new_word, boost::is_space(), boost::algorithm::token_compress_on);
(*vocab)[j] = std::make_pair(new_subwords, freq);
changes.emplace_back( Change {j, new_subwords, subwords, freq} );
}

return changes;
Expand All @@ -110,16 +108,22 @@ vector<Change> replacePair(const vector<string> * replace_words,
// Find index of the specific word from the vector
//
// Arguments:
// word: vector of words
// search_word: search query word
// start_index: start finding from this index
// words: vector of words
// query: search query word
// start: start finding from this index
//
// Returns:
// index of the specific word
int findIndex(vector<string> * word, string * search_word, unsigned start_index) {
auto iter = find(word->begin() + start_index, word->end(), (*search_word));
size_t index = distance(word->begin(), iter);
return index;
int findIndex(
const vector<string> & words, const string & query, const unsigned start) {
for (unsigned i = start; i < words.size(); ++i) {
if (words[i] == query) {
return i;
}
}
NMTKIT_FATAL("query not found.");
//const auto & iter = find(words.begin() + start_index, words.end(), query);
//return static_cast<int>(distance(word->begin(), iter));
}

// Update stats based on changes
Expand All @@ -129,65 +133,63 @@ int findIndex(vector<string> * word, string * search_word, unsigned start_index)
// changes: return value of replacePair()
// stats: sum of bigram frequency.
// indices: index of stats (key=bigram)
void updatePairStatistics(const vector<string> * replace_words,
const vector<Change> * changes,
void updatePairStatistics(
const vector<string> & replace_words,
const vector<Change> & changes,
map<vector<string>, int> * stats,
map<vector<string>, map<unsigned, int>> * indices) {
stats->erase((*replace_words));
indices->erase((*replace_words));
string first = (*replace_words)[0];
string second = (*replace_words)[1];
string new_pair = first + second;

for (unsigned i = 0; i < changes->size(); i++) {
unsigned j = (*changes)[i].index;
vector<string> new_word = (*changes)[i].new_word;
vector<string> old_word = (*changes)[i].old_word;
int freq = (*changes)[i].freq;
stats->erase(replace_words);
indices->erase(replace_words);
const string & first = replace_words[0];
const string & second = replace_words[1];
const string new_pair = first + second;

for (const Change & change : changes) {
unsigned k = 0;
while(true) {
k = findIndex(&old_word, &first, k);
if (k == old_word.size()) {
while (true) {
k = findIndex(change.old_word, first, k);
if (k == change.old_word.size()) {
break;
}
if (k < old_word.size() - 1 and old_word[k+1] == second) {
if (k < change.old_word.size() - 1 and change.old_word[k+1] == second) {
if (k != 0) {
vector<string> prev = {old_word[k-1], old_word[k]};
(*stats)[prev] -= freq;
(*indices)[prev][j] -= 1;
vector<string> prev = {change.old_word[k-1], change.old_word[k]};
(*stats)[prev] -= change.freq;
--(*indices)[prev][change.index];
}
if (k < old_word.size() - 2) {
if (old_word[k+2] != first or k >= old_word.size() - 3 or old_word[k+3] != second) {
vector<string> nex = {old_word[k+1], old_word[k+2]};
(*stats)[nex] -= freq;
(*indices)[nex][j] -= 1;
if (k < change.old_word.size() - 2) {
if (change.old_word[k+2] != first or
k >= change.old_word.size() - 3 or
change.old_word[k+3] != second) {
vector<string> nex = {change.old_word[k+1], change.old_word[k+2]};
(*stats)[nex] -= change.freq;
--(*indices)[nex][change.index];
}
}
k += 2;
}
else {
k += 1;
++k;
}
}

k = 0;
while(true) {
k = findIndex(&new_word, &new_pair, k);
if (k == new_word.size()) {
k = findIndex(change.new_word, new_pair, k);
if (k == change.new_word.size()) {
break;
}
if (k != 0) {
vector<string> prev = {new_word[k-1], new_word[k]};
(*stats)[prev] += freq;
(*indices)[prev][j] += 1;
vector<string> prev = {change.new_word[k-1], change.new_word[k]};
(*stats)[prev] += change.freq;
++(*indices)[prev][change.index];
}
if (k < new_word.size() - 1 and new_word[k+1] != new_pair) {
vector<string> nex = {new_word[k], new_word[k+1]};
(*stats)[nex] += freq;
(*indices)[nex][j] += 1;
if (k < change.new_word.size() - 1 and change.new_word[k+1] != new_pair) {
vector<string> nex = {change.new_word[k], change.new_word[k+1]};
(*stats)[nex] += change.freq;
++(*indices)[nex][change.index];
}
k += 1;
++k;
}
}
}
Expand All @@ -202,20 +204,18 @@ void pruneStats(
map<vector<string>, int> * stats,
map<vector<string>, int> * big_stats,
const int threshold) {
map<vector<string>, int>::iterator it = stats->begin();
auto it = stats->begin();
while (it != stats->end()) {
vector<string> item = it->first;
int freq = it->second;
if (freq < threshold) {
stats->erase(it++);
if (freq < 0) {
(*big_stats)[item] += freq;
}
else {
} else {
(*big_stats)[item] = freq;
}
}
else {
} else {
++it;
}
}
Expand All @@ -228,12 +228,10 @@ void pruneStats(
//
// Returns:
// pairs of character bigram
vector<pair<string, string>> getPairs(vector<string> * word) {
vector<pair<string, string>> getPairs(const vector<string> & word) {
vector<pair<string, string>> pairs;
string prev_char = (*word)[0];
for (unsigned i = 1; i < word->size(); i++) {
pairs.emplace_back(pair<string, string>(prev_char, (*word)[i]));
prev_char = (*word)[i];
for (unsigned i = 1; i < word.size(); ++i) {
pairs.emplace_back(std::make_pair(word[i-1], word[i]));
}
return pairs;
}
Expand All @@ -246,26 +244,26 @@ vector<pair<string, string>> getPairs(vector<string> * word) {
// bpe_cache: BPE converted words
// Returns:
// BPE words
vector<string> encode(const string * orig,
const map<pair<string, string>, unsigned> * bpe_codes,
vector<string> encode(
const string & orig,
const map<pair<string, string>, unsigned> & bpe_codes,
map<string, vector<string>> * bpe_cache) {
// if exists in bpe_cache
const auto &entry = bpe_cache->find(*orig);
const auto & entry = bpe_cache->find(orig);
if (entry != bpe_cache->end()) {
return entry->second;
}

vector<string> word = UTF8::getLetters(*orig);
vector<string> word = UTF8::getLetters(orig);
word.emplace_back("</w>");
vector<pair<string, string>> pairs = getPairs(&word);
vector<pair<string, string>> pairs = getPairs(word);

while (true) {
unsigned min_bigram = UINT_MAX;
unsigned argmin_bigram = 0;
for (unsigned i = 0; i < pairs.size(); i++) {
const auto &entry = bpe_codes->find(pairs[i]);
if (entry != bpe_codes->end() and
entry->second < min_bigram) {
const auto & entry = bpe_codes.find(pairs[i]);
if (entry != bpe_codes.end() and entry->second < min_bigram) {
min_bigram = entry->second;
argmin_bigram = i;
}
Expand All @@ -278,7 +276,7 @@ vector<string> encode(const string * orig,
vector<string> new_word;
unsigned i = 0;
while (i < word.size()) {
unsigned j = findIndex(&word, &first, i);
unsigned j = findIndex(word, first, i);
if (j == word.size()) {
copy(word.begin() + i, word.end(), back_inserter(new_word));
break;
Expand All @@ -298,11 +296,11 @@ vector<string> encode(const string * orig,
if (word.size() == 1) {
break;
} else {
pairs = getPairs(&word);
pairs = getPairs(word);
}
}

(*bpe_cache)[*orig] = word;
(*bpe_cache)[orig] = word;
return word;
}

Expand Down Expand Up @@ -381,20 +379,20 @@ BPEVocabulary::BPEVocabulary(const string & corpus_filename, unsigned size) {

map<vector<string>, int> stats;
map<vector<string>, map<unsigned, int>> indices;
getPairStatistics(&vector_vocab, &stats, &indices);
getPairStatistics(vector_vocab, &stats, &indices);
map<vector<string>, int> big_stats = stats;
int threshold = stats[findMax(&stats)] / 10;
int threshold = stats[findMax(stats)] / 10;

unsigned num_letter_vocab = stoi_.size();
for (unsigned i = 0; i < size - num_letter_vocab; i++) {
vector<string> most_frequent_index;
if (!stats.empty()) {
most_frequent_index = findMax(&stats);
most_frequent_index = findMax(stats);
}
if (stats.empty() or (i != 0 and stats[most_frequent_index] < threshold)) {
pruneStats(&stats, &big_stats, threshold);
stats = big_stats;
most_frequent_index = findMax(&stats);
most_frequent_index = findMax(stats);
threshold = stats[most_frequent_index] * i/(i+10000.0);
pruneStats(&stats, &big_stats, threshold);
}
Expand All @@ -413,8 +411,8 @@ BPEVocabulary::BPEVocabulary(const string & corpus_filename, unsigned size) {
}

vector<Change> changes =
replacePair(&most_frequent_index, &vector_vocab, &indices[most_frequent_index]);
updatePairStatistics(&most_frequent_index, &changes, &stats, &indices);
replacePair(most_frequent_index, indices[most_frequent_index], &vector_vocab);
updatePairStatistics(most_frequent_index, changes, &stats, &indices);
stats[most_frequent_index] = 0;

if (i % 100 == 0) {
Expand Down Expand Up @@ -446,7 +444,7 @@ vector<unsigned> BPEVocabulary::convertToIDs(const string & sentence) const {
words, sentence, boost::is_space(), boost::algorithm::token_compress_on);
vector<unsigned> ids;
for (const string & word : words) {
vector<string> new_words = encode(&word, &bpe_codes_, &bpe_cache_);
vector<string> new_words = encode(word, bpe_codes_, &bpe_cache_);
for (const string & new_word : new_words) {
ids.emplace_back(getID(new_word));
}
Expand Down

0 comments on commit 49b397f

Please sign in to comment.