@@ -32,19 +32,19 @@ struct Change {
32
32
// vocab: vector of (word as a token sequence, frequency) pairs.
33
33
// stats: sum of bigram frequency.
34
34
// indices: for each bigram (key), a map from vocab index to the number of times the bigram occurs in that word
35
- void getPairStatistics (vector<pair<vector<string>, int >> vocab,
36
- map<vector<string>, int > & stats,
37
- map<vector<string>, map<unsigned , int >> & indices) {
35
+ void getPairStatistics (vector<pair<vector<string>, int >> * vocab,
36
+ map<vector<string>, int > * stats,
37
+ map<vector<string>, map<unsigned , int >> * indices) {
38
38
39
- for (unsigned i = 0 ; i < vocab. size (); i++) {
40
- const vector<string> word = vocab[i].first ;
41
- const int freq = vocab[i].second ;
39
+ for (unsigned i = 0 ; i < vocab-> size (); i++) {
40
+ const vector<string> word = (* vocab) [i].first ;
41
+ const int freq = (* vocab) [i].second ;
42
42
string prev_char = word[0 ];
43
43
for (unsigned j = 1 ; j < word.size (); j++) {
44
44
const string current_char = word[j];
45
45
const vector<string> key = {prev_char, current_char};
46
- stats[key] += freq;
47
- indices[key][i] += 1 ;
46
+ (* stats) [key] += freq;
47
+ (* indices) [key][i] += 1 ;
48
48
prev_char = current_char;
49
49
}
50
50
}
@@ -57,10 +57,10 @@ void getPairStatistics(vector<pair<vector<string>, int>> vocab,
57
57
//
58
58
// Returns:
59
59
// most frequent bigram
60
- vector<string> findMax (const map<vector<string>, int > stats) {
60
+ vector<string> findMax (const map<vector<string>, int > * stats) {
61
61
int current_max = -1e5 ;
62
62
vector<string> current_argmax;
63
- for (auto elm : stats) {
63
+ for (auto elm : (* stats) ) {
64
64
if (elm.second > current_max) {
65
65
current_max = elm.second ;
66
66
current_argmax = elm.first ;
@@ -78,29 +78,29 @@ vector<string> findMax(const map<vector<string>, int> stats) {
78
78
//
79
79
// Returns:
80
80
// vector of replaceable pairs
81
- vector<Change> replacePair (const vector<string> replace_words,
82
- vector<pair<vector<string>, int >> & vocab,
83
- const map<unsigned , int > indices) {
84
- string first = replace_words[0 ];
85
- string second = replace_words[1 ];
86
- string pair_str = boost::join (replace_words, " " );
81
+ vector<Change> replacePair (const vector<string> * replace_words,
82
+ vector<pair<vector<string>, int >> * vocab,
83
+ const map<unsigned , int > * indices) {
84
+ string first = (* replace_words) [0 ];
85
+ string second = (* replace_words) [1 ];
86
+ string pair_str = boost::join ((* replace_words) , " " );
87
87
vector<Change> changes;
88
88
89
- for (const auto index : indices) {
89
+ for (const auto index : (* indices) ) {
90
90
unsigned j = index .first ;
91
91
int freq = index .second ;
92
92
if (freq < 1 ) {
93
93
continue ;
94
94
}
95
95
96
- vector<string> word = vocab[j].first ;
97
- freq = vocab[j].second ;
96
+ vector<string> word = (* vocab) [j].first ;
97
+ freq = (* vocab) [j].second ;
98
98
string new_word = boost::join (word, " " );
99
99
boost::replace_all (new_word, first + " " + second, pair_str);
100
100
vector<string> vector_new_word;
101
101
boost::split (
102
102
vector_new_word, new_word, boost::is_space (), boost::algorithm::token_compress_on);
103
- vocab[j] = pair<vector<string>, int >(vector_new_word, freq);
103
+ (* vocab) [j] = pair<vector<string>, int >(vector_new_word, freq);
104
104
changes.emplace_back ( Change{j, vector_new_word, word, freq} );
105
105
}
106
106
@@ -116,9 +116,9 @@ vector<Change> replacePair(const vector<string> replace_words,
116
116
//
117
117
// Returns:
118
118
// index of the specific word
119
- int findIndex (vector<string> word, string search_word, unsigned start_index) {
120
- auto iter = find (word. begin () + start_index, word. end (), search_word);
121
- size_t index = distance (word. begin (), iter);
119
+ int findIndex (vector<string> * word, string * search_word, unsigned start_index) {
120
+ auto iter = find (word-> begin () + start_index, word-> end (), (* search_word) );
121
+ size_t index = distance (word-> begin (), iter);
122
122
return index ;
123
123
}
124
124
@@ -129,39 +129,39 @@ int findIndex(vector<string> word, string search_word, unsigned start_index) {
129
129
// changes: return value of replacePair()
130
130
// stats: sum of bigram frequency.
131
131
// indices: for each bigram (key), a map from vocab index to the number of times the bigram occurs in that word
132
- void updatePairStatistics (const vector<string> replace_words,
133
- const vector<Change> changes,
134
- map<vector<string>, int > & stats,
135
- map<vector<string>, map<unsigned , int >> & indices) {
136
- stats. erase (replace_words);
137
- indices. erase (replace_words);
138
- string first = replace_words[0 ];
139
- string second = replace_words[1 ];
132
+ void updatePairStatistics (const vector<string> * replace_words,
133
+ const vector<Change> * changes,
134
+ map<vector<string>, int > * stats,
135
+ map<vector<string>, map<unsigned , int >> * indices) {
136
+ stats-> erase ((* replace_words) );
137
+ indices-> erase ((* replace_words) );
138
+ string first = (* replace_words) [0 ];
139
+ string second = (* replace_words) [1 ];
140
140
string new_pair = first + second;
141
141
142
- for (unsigned i = 0 ; i < changes. size (); i++) {
143
- unsigned j = changes[i].index ;
144
- vector<string> new_word = changes[i].new_word ;
145
- vector<string> old_word = changes[i].old_word ;
146
- int freq = changes[i].freq ;
142
+ for (unsigned i = 0 ; i < changes-> size (); i++) {
143
+ unsigned j = (* changes) [i].index ;
144
+ vector<string> new_word = (* changes) [i].new_word ;
145
+ vector<string> old_word = (* changes) [i].old_word ;
146
+ int freq = (* changes) [i].freq ;
147
147
148
148
unsigned k = 0 ;
149
149
while (true ) {
150
- k = findIndex (old_word, first, k);
150
+ k = findIndex (& old_word, & first, k);
151
151
if (k == old_word.size ()) {
152
152
break ;
153
153
}
154
154
if (k < old_word.size () - 1 and old_word[k+1 ] == second) {
155
155
if (k != 0 ) {
156
156
vector<string> prev = {old_word[k-1 ], old_word[k]};
157
- stats[prev] -= freq;
158
- indices[prev][j] -= 1 ;
157
+ (* stats) [prev] -= freq;
158
+ (* indices) [prev][j] -= 1 ;
159
159
}
160
160
if (k < old_word.size () - 2 ) {
161
161
if (old_word[k+2 ] != first or k >= old_word.size () - 3 or old_word[k+3 ] != second) {
162
162
vector<string> nex = {old_word[k+1 ], old_word[k+2 ]};
163
- stats[nex] -= freq;
164
- indices[nex][j] -= 1 ;
163
+ (* stats) [nex] -= freq;
164
+ (* indices) [nex][j] -= 1 ;
165
165
}
166
166
}
167
167
k += 2 ;
@@ -173,19 +173,19 @@ void updatePairStatistics(const vector<string> replace_words,
173
173
174
174
k = 0 ;
175
175
while (true ) {
176
- k = findIndex (new_word, new_pair, k);
176
+ k = findIndex (& new_word, & new_pair, k);
177
177
if (k == new_word.size ()) {
178
178
break ;
179
179
}
180
180
if (k != 0 ) {
181
181
vector<string> prev = {new_word[k-1 ], new_word[k]};
182
- stats[prev] += freq;
183
- indices[prev][j] += 1 ;
182
+ (* stats) [prev] += freq;
183
+ (* indices) [prev][j] += 1 ;
184
184
}
185
185
if (k < new_word.size () - 1 and new_word[k+1 ] != new_pair) {
186
186
vector<string> nex = {new_word[k], new_word[k+1 ]};
187
- stats[nex] += freq;
188
- indices[nex][j] += 1 ;
187
+ (* stats) [nex] += freq;
188
+ (* indices) [nex][j] += 1 ;
189
189
}
190
190
k += 1 ;
191
191
}
@@ -199,20 +199,20 @@ void updatePairStatistics(const vector<string> replace_words,
199
199
// big_stats: sum of bigram frequency (not pruned).
200
200
// threshold: bigrams whose frequency is less than this threshold will be pruned
201
201
void pruneStats (
202
- map<vector<string>, int > & stats,
203
- map<vector<string>, int > & big_stats,
202
+ map<vector<string>, int > * stats,
203
+ map<vector<string>, int > * big_stats,
204
204
const int threshold) {
205
- map<vector<string>, int >::iterator it = stats. begin ();
206
- while (it != stats. end ()) {
205
+ map<vector<string>, int >::iterator it = stats-> begin ();
206
+ while (it != stats-> end ()) {
207
207
vector<string> item = it->first ;
208
208
int freq = it->second ;
209
209
if (freq < threshold) {
210
- stats. erase (it++);
210
+ stats-> erase (it++);
211
211
if (freq < 0 ) {
212
- big_stats[item] += freq;
212
+ (* big_stats) [item] += freq;
213
213
}
214
214
else {
215
- big_stats[item] = freq;
215
+ (* big_stats) [item] = freq;
216
216
}
217
217
}
218
218
else {
@@ -228,12 +228,12 @@ void pruneStats(
228
228
//
229
229
// Returns:
230
230
// pairs of character bigram
231
- vector<pair<string, string>> getPairs (vector<string> word) {
231
+ vector<pair<string, string>> getPairs (vector<string> * word) {
232
232
vector<pair<string, string>> pairs;
233
- string prev_char = word[0 ];
234
- for (unsigned i = 1 ; i < word. size (); i++) {
235
- pairs.emplace_back (pair<string, string>(prev_char, word[i]));
236
- prev_char = word[i];
233
+ string prev_char = (* word) [0 ];
234
+ for (unsigned i = 1 ; i < word-> size (); i++) {
235
+ pairs.emplace_back (pair<string, string>(prev_char, (* word) [i]));
236
+ prev_char = (* word) [i];
237
237
}
238
238
return pairs;
239
239
}
@@ -246,25 +246,27 @@ vector<pair<string, string>> getPairs(vector<string> word) {
246
246
// bpe_cache: BPE converted words
247
247
// Returns:
248
248
// BPE words
249
- vector<string> encode (string orig, map<pair<string, string>, unsigned > bpe_codes,
249
+ vector<string> encode (const string * orig,
250
+ const map<pair<string, string>, unsigned > * bpe_codes,
250
251
map<string, vector<string>> * bpe_cache) {
251
252
// if exists in bpe_cache
252
- const auto &entry = bpe_cache->find (orig);
253
+ const auto &entry = bpe_cache->find (* orig);
253
254
if (entry != bpe_cache->end ()) {
254
255
return entry->second ;
255
256
}
256
257
257
- vector<string> word = UTF8::getLetters (orig);
258
+ vector<string> word = UTF8::getLetters (* orig);
258
259
word.emplace_back (" </w>" );
259
- vector<pair<string, string>> pairs = getPairs (word);
260
+ vector<pair<string, string>> pairs = getPairs (& word);
260
261
261
262
while (true ) {
262
263
unsigned min_bigram = UINT_MAX;
263
264
unsigned argmin_bigram = 0 ;
264
265
for (unsigned i = 0 ; i < pairs.size (); i++) {
265
- if (bpe_codes.find (pairs[i]) != bpe_codes.end () and
266
- bpe_codes[pairs[i]] < min_bigram) {
267
- min_bigram = bpe_codes[pairs[i]];
266
+ const auto &entry = bpe_codes->find (pairs[i]);
267
+ if (entry != bpe_codes->end () and
268
+ entry->second < min_bigram) {
269
+ min_bigram = entry->second ;
268
270
argmin_bigram = i;
269
271
}
270
272
}
@@ -276,7 +278,7 @@ vector<string> encode(string orig, map<pair<string, string>, unsigned> bpe_codes
276
278
vector<string> new_word;
277
279
unsigned i = 0 ;
278
280
while (i < word.size ()) {
279
- unsigned j = findIndex (word, first, i);
281
+ unsigned j = findIndex (& word, & first, i);
280
282
if (j == word.size ()) {
281
283
copy (word.begin () + i, word.end (), back_inserter (new_word));
282
284
break ;
@@ -296,11 +298,11 @@ vector<string> encode(string orig, map<pair<string, string>, unsigned> bpe_codes
296
298
if (word.size () == 1 ) {
297
299
break ;
298
300
} else {
299
- pairs = getPairs (word);
301
+ pairs = getPairs (& word);
300
302
}
301
303
}
302
304
303
- (*bpe_cache)[orig] = word;
305
+ (*bpe_cache)[* orig] = word;
304
306
return word;
305
307
}
306
308
@@ -379,22 +381,22 @@ BPEVocabulary::BPEVocabulary(const string & corpus_filename, unsigned size) {
379
381
380
382
map<vector<string>, int > stats;
381
383
map<vector<string>, map<unsigned , int >> indices;
382
- getPairStatistics (vector_vocab, stats, indices);
384
+ getPairStatistics (& vector_vocab, & stats, & indices);
383
385
map<vector<string>, int > big_stats = stats;
384
- int threshold = stats[findMax (stats)] / 10 ;
386
+ int threshold = stats[findMax (& stats)] / 10 ;
385
387
386
388
unsigned num_letter_vocab = stoi_.size ();
387
389
for (unsigned i = 0 ; i < size - num_letter_vocab; i++) {
388
390
vector<string> most_frequent_index;
389
391
if (!stats.empty ()) {
390
- most_frequent_index = findMax (stats);
392
+ most_frequent_index = findMax (& stats);
391
393
}
392
394
if (stats.empty () or (i != 0 and stats[most_frequent_index] < threshold)) {
393
- pruneStats (stats, big_stats, threshold);
395
+ pruneStats (& stats, & big_stats, threshold);
394
396
stats = big_stats;
395
- vector<string> most_frequent_index = findMax (stats);
397
+ most_frequent_index = findMax (& stats);
396
398
threshold = stats[most_frequent_index] * i/(i+10000.0 );
397
- pruneStats (stats, big_stats, threshold);
399
+ pruneStats (& stats, & big_stats, threshold);
398
400
}
399
401
400
402
// Store entries
@@ -411,12 +413,12 @@ BPEVocabulary::BPEVocabulary(const string & corpus_filename, unsigned size) {
411
413
}
412
414
413
415
vector<Change> changes =
414
- replacePair (most_frequent_index, vector_vocab, indices[most_frequent_index]);
415
- updatePairStatistics (most_frequent_index, changes, stats, indices);
416
+ replacePair (& most_frequent_index, & vector_vocab, & indices[most_frequent_index]);
417
+ updatePairStatistics (& most_frequent_index, & changes, & stats, & indices);
416
418
stats[most_frequent_index] = 0 ;
417
419
418
420
if (i % 100 == 0 ) {
419
- pruneStats (stats, big_stats, threshold);
421
+ pruneStats (& stats, & big_stats, threshold);
420
422
}
421
423
}
422
424
// end making BPE codes
@@ -444,7 +446,7 @@ vector<unsigned> BPEVocabulary::convertToIDs(const string & sentence) const {
444
446
words, sentence, boost::is_space (), boost::algorithm::token_compress_on);
445
447
vector<unsigned > ids;
446
448
for (const string & word : words) {
447
- vector<string> new_words = encode (word, bpe_codes_, &bpe_cache_);
449
+ vector<string> new_words = encode (& word, & bpe_codes_, &bpe_cache_);
448
450
for (const string & new_word : new_words) {
449
451
ids.emplace_back (getID (new_word));
450
452
}
0 commit comments