Skip to content

Commit 4f22b52

Browse files
authored
Merge pull request #8 from MorinoseiMorizo/master
Fix bugs on BPEVocabulary
2 parents 8f132b4 + e9e3ddb commit 4f22b52

File tree

1 file changed

+80
-78
lines changed

1 file changed

+80
-78
lines changed

src/lib/bpe_vocabulary.cc

+80-78
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,19 @@ struct Change {
3232
// vocab: vector of (word split into symbols, word frequency) pairs.
3333
// stats: sum of bigram frequency.
3434
// indices: for each bigram (key), a map from vocab index to the number of times the bigram occurs in that word
35-
void getPairStatistics(vector<pair<vector<string>, int>> vocab,
36-
map<vector<string>, int> & stats,
37-
map<vector<string>, map<unsigned, int>> & indices) {
35+
void getPairStatistics(vector<pair<vector<string>, int>> * vocab,
36+
map<vector<string>, int> * stats,
37+
map<vector<string>, map<unsigned, int>> * indices) {
3838

39-
for (unsigned i = 0; i < vocab.size(); i++) {
40-
const vector<string> word = vocab[i].first;
41-
const int freq = vocab[i].second;
39+
for (unsigned i = 0; i < vocab->size(); i++) {
40+
const vector<string> word = (*vocab)[i].first;
41+
const int freq = (*vocab)[i].second;
4242
string prev_char = word[0];
4343
for (unsigned j = 1; j < word.size(); j++) {
4444
const string current_char = word[j];
4545
const vector<string> key = {prev_char, current_char};
46-
stats[key] += freq;
47-
indices[key][i] += 1;
46+
(*stats)[key] += freq;
47+
(*indices)[key][i] += 1;
4848
prev_char = current_char;
4949
}
5050
}
@@ -57,10 +57,10 @@ void getPairStatistics(vector<pair<vector<string>, int>> vocab,
5757
//
5858
// Returns:
5959
// most frequent bigram
60-
vector<string> findMax(const map<vector<string>, int> stats) {
60+
vector<string> findMax(const map<vector<string>, int> * stats) {
6161
int current_max = -1e5;
6262
vector<string> current_argmax;
63-
for (auto elm : stats) {
63+
for (auto elm : (*stats)) {
6464
if (elm.second > current_max) {
6565
current_max = elm.second;
6666
current_argmax = elm.first;
@@ -78,29 +78,29 @@ vector<string> findMax(const map<vector<string>, int> stats) {
7878
//
7979
// Returns:
8080
// vector of replaceable pairs
81-
vector<Change> replacePair(const vector<string> replace_words,
82-
vector<pair<vector<string>, int>> & vocab,
83-
const map<unsigned, int> indices) {
84-
string first = replace_words[0];
85-
string second = replace_words[1];
86-
string pair_str = boost::join(replace_words, "");
81+
vector<Change> replacePair(const vector<string> * replace_words,
82+
vector<pair<vector<string>, int>> * vocab,
83+
const map<unsigned, int> * indices) {
84+
string first = (*replace_words)[0];
85+
string second = (*replace_words)[1];
86+
string pair_str = boost::join((*replace_words), "");
8787
vector<Change> changes;
8888

89-
for (const auto index : indices) {
89+
for (const auto index : (*indices)) {
9090
unsigned j = index.first;
9191
int freq = index.second;
9292
if (freq < 1) {
9393
continue;
9494
}
9595

96-
vector<string> word = vocab[j].first;
97-
freq = vocab[j].second;
96+
vector<string> word = (*vocab)[j].first;
97+
freq = (*vocab)[j].second;
9898
string new_word = boost::join(word, " ");
9999
boost::replace_all(new_word, first + " " + second, pair_str);
100100
vector<string> vector_new_word;
101101
boost::split(
102102
vector_new_word, new_word, boost::is_space(), boost::algorithm::token_compress_on);
103-
vocab[j] = pair<vector<string>, int>(vector_new_word, freq);
103+
(*vocab)[j] = pair<vector<string>, int>(vector_new_word, freq);
104104
changes.emplace_back( Change{j, vector_new_word, word, freq} );
105105
}
106106

@@ -116,9 +116,9 @@ vector<Change> replacePair(const vector<string> replace_words,
116116
//
117117
// Returns:
118118
// index of the first occurrence of search_word at or after start_index (equals the word's size if not found)
119-
int findIndex(vector<string> word, string search_word, unsigned start_index) {
120-
auto iter = find(word.begin() + start_index, word.end(), search_word);
121-
size_t index = distance(word.begin(), iter);
119+
int findIndex(vector<string> * word, string * search_word, unsigned start_index) {
120+
auto iter = find(word->begin() + start_index, word->end(), (*search_word));
121+
size_t index = distance(word->begin(), iter);
122122
return index;
123123
}
124124

@@ -129,39 +129,39 @@ int findIndex(vector<string> word, string search_word, unsigned start_index) {
129129
// changes: return value of replacePair()
130130
// stats: sum of bigram frequency.
131131
// indices: index of stats (key=bigram)
132-
void updatePairStatistics(const vector<string> replace_words,
133-
const vector<Change> changes,
134-
map<vector<string>, int> & stats,
135-
map<vector<string>, map<unsigned, int>> & indices) {
136-
stats.erase(replace_words);
137-
indices.erase(replace_words);
138-
string first = replace_words[0];
139-
string second = replace_words[1];
132+
void updatePairStatistics(const vector<string> * replace_words,
133+
const vector<Change> * changes,
134+
map<vector<string>, int> * stats,
135+
map<vector<string>, map<unsigned, int>> * indices) {
136+
stats->erase((*replace_words));
137+
indices->erase((*replace_words));
138+
string first = (*replace_words)[0];
139+
string second = (*replace_words)[1];
140140
string new_pair = first + second;
141141

142-
for (unsigned i = 0; i < changes.size(); i++) {
143-
unsigned j = changes[i].index;
144-
vector<string> new_word = changes[i].new_word;
145-
vector<string> old_word = changes[i].old_word;
146-
int freq = changes[i].freq;
142+
for (unsigned i = 0; i < changes->size(); i++) {
143+
unsigned j = (*changes)[i].index;
144+
vector<string> new_word = (*changes)[i].new_word;
145+
vector<string> old_word = (*changes)[i].old_word;
146+
int freq = (*changes)[i].freq;
147147

148148
unsigned k = 0;
149149
while(true) {
150-
k = findIndex(old_word, first, k);
150+
k = findIndex(&old_word, &first, k);
151151
if (k == old_word.size()) {
152152
break;
153153
}
154154
if (k < old_word.size() - 1 and old_word[k+1] == second) {
155155
if (k != 0) {
156156
vector<string> prev = {old_word[k-1], old_word[k]};
157-
stats[prev] -= freq;
158-
indices[prev][j] -= 1;
157+
(*stats)[prev] -= freq;
158+
(*indices)[prev][j] -= 1;
159159
}
160160
if (k < old_word.size() - 2) {
161161
if (old_word[k+2] != first or k >= old_word.size() - 3 or old_word[k+3] != second) {
162162
vector<string> nex = {old_word[k+1], old_word[k+2]};
163-
stats[nex] -= freq;
164-
indices[nex][j] -= 1;
163+
(*stats)[nex] -= freq;
164+
(*indices)[nex][j] -= 1;
165165
}
166166
}
167167
k += 2;
@@ -173,19 +173,19 @@ void updatePairStatistics(const vector<string> replace_words,
173173

174174
k = 0;
175175
while(true) {
176-
k = findIndex(new_word, new_pair, k);
176+
k = findIndex(&new_word, &new_pair, k);
177177
if (k == new_word.size()) {
178178
break;
179179
}
180180
if (k != 0) {
181181
vector<string> prev = {new_word[k-1], new_word[k]};
182-
stats[prev] += freq;
183-
indices[prev][j] += 1;
182+
(*stats)[prev] += freq;
183+
(*indices)[prev][j] += 1;
184184
}
185185
if (k < new_word.size() - 1 and new_word[k+1] != new_pair) {
186186
vector<string> nex = {new_word[k], new_word[k+1]};
187-
stats[nex] += freq;
188-
indices[nex][j] += 1;
187+
(*stats)[nex] += freq;
188+
(*indices)[nex][j] += 1;
189189
}
190190
k += 1;
191191
}
@@ -199,20 +199,20 @@ void updatePairStatistics(const vector<string> replace_words,
199199
// big_stats: sum of bigram frequency (not pruned).
200200
// threshold: words whose frequency is less than this threshold will be pruned
201201
void pruneStats(
202-
map<vector<string>, int> & stats,
203-
map<vector<string>, int> & big_stats,
202+
map<vector<string>, int> * stats,
203+
map<vector<string>, int> * big_stats,
204204
const int threshold) {
205-
map<vector<string>, int>::iterator it = stats.begin();
206-
while (it != stats.end()) {
205+
map<vector<string>, int>::iterator it = stats->begin();
206+
while (it != stats->end()) {
207207
vector<string> item = it->first;
208208
int freq = it->second;
209209
if (freq < threshold) {
210-
stats.erase(it++);
210+
stats->erase(it++);
211211
if (freq < 0) {
212-
big_stats[item] += freq;
212+
(*big_stats)[item] += freq;
213213
}
214214
else {
215-
big_stats[item] = freq;
215+
(*big_stats)[item] = freq;
216216
}
217217
}
218218
else {
@@ -228,12 +228,12 @@ void pruneStats(
228228
//
229229
// Returns:
230230
// pairs of character bigram
231-
vector<pair<string, string>> getPairs(vector<string> word) {
231+
vector<pair<string, string>> getPairs(vector<string> * word) {
232232
vector<pair<string, string>> pairs;
233-
string prev_char = word[0];
234-
for (unsigned i = 1; i < word.size(); i++) {
235-
pairs.emplace_back(pair<string, string>(prev_char, word[i]));
236-
prev_char = word[i];
233+
string prev_char = (*word)[0];
234+
for (unsigned i = 1; i < word->size(); i++) {
235+
pairs.emplace_back(pair<string, string>(prev_char, (*word)[i]));
236+
prev_char = (*word)[i];
237237
}
238238
return pairs;
239239
}
@@ -246,25 +246,27 @@ vector<pair<string, string>> getPairs(vector<string> word) {
246246
// bpe_cache: BPE converted words
247247
// Returns:
248248
// BPE words
249-
vector<string> encode(string orig, map<pair<string, string>, unsigned> bpe_codes,
249+
vector<string> encode(const string * orig,
250+
const map<pair<string, string>, unsigned> * bpe_codes,
250251
map<string, vector<string>> * bpe_cache) {
251252
// if exists in bpe_cache
252-
const auto &entry = bpe_cache->find(orig);
253+
const auto &entry = bpe_cache->find(*orig);
253254
if (entry != bpe_cache->end()) {
254255
return entry->second;
255256
}
256257

257-
vector<string> word = UTF8::getLetters(orig);
258+
vector<string> word = UTF8::getLetters(*orig);
258259
word.emplace_back("</w>");
259-
vector<pair<string, string>> pairs = getPairs(word);
260+
vector<pair<string, string>> pairs = getPairs(&word);
260261

261262
while (true) {
262263
unsigned min_bigram = UINT_MAX;
263264
unsigned argmin_bigram = 0;
264265
for (unsigned i = 0; i < pairs.size(); i++) {
265-
if (bpe_codes.find(pairs[i]) != bpe_codes.end() and
266-
bpe_codes[pairs[i]] < min_bigram) {
267-
min_bigram = bpe_codes[pairs[i]];
266+
const auto &entry = bpe_codes->find(pairs[i]);
267+
if (entry != bpe_codes->end() and
268+
entry->second < min_bigram) {
269+
min_bigram = entry->second;
268270
argmin_bigram = i;
269271
}
270272
}
@@ -276,7 +278,7 @@ vector<string> encode(string orig, map<pair<string, string>, unsigned> bpe_codes
276278
vector<string> new_word;
277279
unsigned i = 0;
278280
while (i < word.size()) {
279-
unsigned j = findIndex(word, first, i);
281+
unsigned j = findIndex(&word, &first, i);
280282
if (j == word.size()) {
281283
copy(word.begin() + i, word.end(), back_inserter(new_word));
282284
break;
@@ -296,11 +298,11 @@ vector<string> encode(string orig, map<pair<string, string>, unsigned> bpe_codes
296298
if (word.size() == 1) {
297299
break;
298300
} else {
299-
pairs = getPairs(word);
301+
pairs = getPairs(&word);
300302
}
301303
}
302304

303-
(*bpe_cache)[orig] = word;
305+
(*bpe_cache)[*orig] = word;
304306
return word;
305307
}
306308

@@ -379,22 +381,22 @@ BPEVocabulary::BPEVocabulary(const string & corpus_filename, unsigned size) {
379381

380382
map<vector<string>, int> stats;
381383
map<vector<string>, map<unsigned, int>> indices;
382-
getPairStatistics(vector_vocab, stats, indices);
384+
getPairStatistics(&vector_vocab, &stats, &indices);
383385
map<vector<string>, int> big_stats = stats;
384-
int threshold = stats[findMax(stats)] / 10;
386+
int threshold = stats[findMax(&stats)] / 10;
385387

386388
unsigned num_letter_vocab = stoi_.size();
387389
for (unsigned i = 0; i < size - num_letter_vocab; i++) {
388390
vector<string> most_frequent_index;
389391
if (!stats.empty()) {
390-
most_frequent_index = findMax(stats);
392+
most_frequent_index = findMax(&stats);
391393
}
392394
if (stats.empty() or (i != 0 and stats[most_frequent_index] < threshold)) {
393-
pruneStats(stats, big_stats, threshold);
395+
pruneStats(&stats, &big_stats, threshold);
394396
stats = big_stats;
395-
vector<string> most_frequent_index = findMax(stats);
397+
most_frequent_index = findMax(&stats);
396398
threshold = stats[most_frequent_index] * i/(i+10000.0);
397-
pruneStats(stats, big_stats, threshold);
399+
pruneStats(&stats, &big_stats, threshold);
398400
}
399401

400402
// Store entries
@@ -411,12 +413,12 @@ BPEVocabulary::BPEVocabulary(const string & corpus_filename, unsigned size) {
411413
}
412414

413415
vector<Change> changes =
414-
replacePair(most_frequent_index, vector_vocab, indices[most_frequent_index]);
415-
updatePairStatistics(most_frequent_index, changes, stats, indices);
416+
replacePair(&most_frequent_index, &vector_vocab, &indices[most_frequent_index]);
417+
updatePairStatistics(&most_frequent_index, &changes, &stats, &indices);
416418
stats[most_frequent_index] = 0;
417419

418420
if (i % 100 == 0) {
419-
pruneStats(stats, big_stats, threshold);
421+
pruneStats(&stats, &big_stats, threshold);
420422
}
421423
}
422424
// end making BPE codes
@@ -444,7 +446,7 @@ vector<unsigned> BPEVocabulary::convertToIDs(const string & sentence) const {
444446
words, sentence, boost::is_space(), boost::algorithm::token_compress_on);
445447
vector<unsigned> ids;
446448
for (const string & word : words) {
447-
vector<string> new_words = encode(word, bpe_codes_, &bpe_cache_);
449+
vector<string> new_words = encode(&word, &bpe_codes_, &bpe_cache_);
448450
for (const string & new_word : new_words) {
449451
ids.emplace_back(getID(new_word));
450452
}

0 commit comments

Comments
 (0)