@@ -139,6 +139,28 @@ class NMTSampler:
    def __init__(self, models, dataset, params, params_prediction, params_training, model_tokenize_f, model_detokenize_f, general_tokenize_f,
                 general_detokenize_f, mapping=None, word2index_x=None, word2index_y=None, index2word_y=None,
                 excluded_words=None, unk_id=1, eos_symbol='/', online=False, verbose=0):
+        """
+        Builds an NMTSampler: an object containing models and dataset, for the interactive-predictive and adaptive framework.
+        :param list models: Models to use for sampling.
+        :param dataset: Dataset instance, containing the vocabularies used by the models.
+        :param dict params: All hyperparameters of the model.
+        :param dict params_prediction: Hyperparameters regarding prediction and search.
+        :param dict params_training: Hyperparameters regarding incremental training.
+        :param function model_tokenize_f: Function used for tokenizing the input sentence. E.g. BPE.
+        :param function model_detokenize_f: Function used for detokenizing the output sentence. E.g. BPE revert.
+        :param function general_tokenize_f: Function used for tokenizing the input sentence. E.g. Moses tokenizer.
+        :param function general_detokenize_f: Function used for detokenizing the output sentence. E.g. Moses detokenizer.
+        :param dict mapping: Source-target dictionary (for unk_replace heuristics 1 and 2).
+        :param dict word2index_x: Mapping from word strings into indices for the source language.
+        :param dict word2index_y: Mapping from word strings into indices for the target language.
+        :param dict index2word_y: Mapping from indices into word strings for the target language.
+        :param dict excluded_words: Words that won't be generated in the middle of two isles. Currently unused.
+        :param int unk_id: Unknown word index.
+        :param str eos_symbol: End-of-sentence symbol.
+        :param bool online: Whether to apply online learning after accepting each hypothesis.
+        :param int verbose: Verbosity level.
+        """
+
        self.models = models
        self.dataset = dataset
        self.params = params
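
For orientation, here is a minimal construction sketch of the class this hunk documents (not part of the diff). The paths, the identity tokenizers and the parameter dicts are placeholders, and the dataset ids ('source_text'/'target_text') are assumptions; only the keyword names come from the signature above, and the loading helpers are the usual keras_wrapper ones (the script itself loads models with updateModel).

from keras_wrapper.dataset import loadDataset
from keras_wrapper.cnn_model import loadModel

dataset = loadDataset('datasets/Dataset.pkl')                      # illustrative path
model = loadModel('trained_models/model.h5', -1, full_path=True)   # illustrative path

identity = lambda s, **kwargs: s        # stand-in for BPE / Moses (de)tokenizers
params = {}                             # full model hyperparameter dict in the real script
params_prediction = {'beam_size': 6}    # other search keys omitted for brevity
params_training = {}                    # only relevant when online=True

sampler = NMTSampler(models=[model], dataset=dataset,
                     params=params,
                     params_prediction=params_prediction,
                     params_training=params_training,
                     model_tokenize_f=identity, model_detokenize_f=identity,
                     general_tokenize_f=identity, general_detokenize_f=identity,
                     # dataset ids below are assumed; adapt to the dataset's actual input/output names
                     word2index_x=dataset.vocabulary['source_text']['words2idx'],
                     word2index_y=dataset.vocabulary['target_text']['words2idx'],
                     index2word_y=dataset.vocabulary['target_text']['idx2words'],
                     online=False, verbose=1)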
@@ -165,7 +187,7 @@ def __init__(self, models, dataset, params, params_prediction, params_training,
                                                       excluded_words=self.excluded_words,
                                                       verbose=self.verbose)

-        # Compile Theano sampling function by generating a fake sample # TODO: Find a better way of doing this
+        # Compile Theano sampling function by generating a fake sample. # TODO: Find a better way of doing this
        logger.info('Compiling sampler...')
        self.generate_sample('i')
        logger.info('Done.')
@@ -186,7 +208,18 @@ def __init__(self, models, dataset, params, params_prediction, params_training,

    def generate_sample(self, source_sentence, validated_prefix=None, max_N=5, isle_indices=None,
                        filtered_idx2word=None, unk_indices=None, unk_words=None):
-        print("In params prediction beam_size: ", self.params_prediction['beam_size'])
+        """
+        Generate a sample via constrained search. Options labeled with <<isles>> are untested
+        and likely require some modifications to work correctly.
+        :param source_sentence: Source sentence.
+        :param validated_prefix: Prefix to keep in the output.
+        :param max_N: Maximum number of words to generate between validated segments. <<isles>>
+        :param isle_indices: Indices of the validated segments. <<isles>>
+        :param filtered_idx2word: List of candidate words to be the next one to generate (after generating fixed_words).
+        :param unk_indices: Positions of the unknown words.
+        :param unk_words: Unknown words.
+        :return: Hypothesis for the source sentence.
+        """
        logger.log(2, 'Beam size: %d' % (self.params_prediction['beam_size']))
        generate_sample_start_time = time.time()
        if unk_indices is None:
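
The parameters above suggest the interactive-predictive call pattern; a hedged usage sketch, continuing the hypothetical sampler from the earlier note (the sentences and the prefix are made up):

source = 'una frase de ejemplo'            # illustrative source sentence
hyp = sampler.generate_sample(source)      # free, unconstrained hypothesis

# The user validates/corrects the beginning of the hypothesis; decode again
# keeping that prefix fixed in the output.
hyp = sampler.generate_sample(source, validated_prefix='An example ')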
@@ -295,10 +328,6 @@ def generate_sample(self, source_sentence, validated_prefix=None, max_N=5, isle_
        decoding_predictions_end_time = time.time()
        logger.log(2, 'decoding_predictions time: %.6f' % (decoding_predictions_end_time - decoding_predictions_start_time))

-        # for (words_idx, starting_pos), words in unk_in_isles:
-        #     for pos_unk_word, pos_hypothesis in enumerate(range(starting_pos, starting_pos + len(words_idx))):
-        #         hypothesis[pos_hypothesis] = words[pos_unk_word]
-
        # UNK words management
        unk_management_start_time = time.time()
        unk_indices = list(unk_words_dict)
@@ -330,7 +359,12 @@ def generate_sample(self, source_sentence, validated_prefix=None, max_N=5, isle_
        return hypothesis

    def learn_from_sample(self, source_sentence, target_sentence):
-
+        """
+        Incrementally adapt the model with the validated sample.
+        :param source_sentence: Source sentence (x).
+        :param target_sentence: Target sentence (y).
+        :return:
+        """
        # Tokenize input
        tokenized_input = self.general_tokenize_f(source_sentence, escape=False)
        tokenized_input = self.model_tokenize_f(tokenized_input)
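
Together with the `online` flag documented in __init__, this method supports the adaptive loop sketched below, under the same assumptions as the earlier notes (in the real tool the accepted translation comes from user post-editing, not from a fixed reference):

online = True                                               # mirrors the constructor's online flag
pairs = [('una frase de ejemplo', 'An example sentence')]   # toy (source, accepted translation) pair
for source, reference in pairs:
    hyp = sampler.generate_sample(source)
    accepted = reference                                    # user-validated translation
    if online:
        sampler.learn_from_sample(source, accepted)         # incremental update on the pair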
@@ -499,7 +533,6 @@ def main():
                       for i in range(len(args.models))]
    models = [updateModel(model, path, -1, full_path=True) for (model, path) in zip(model_instances, args.models)]

-    # Set additional inputs to models if using a custom loss function
    # parameters['USE_CUSTOM_LOSS'] = True if 'PAS' in parameters['OPTIMIZER'] else False
    # if parameters.get('N_BEST_OPTIMIZER', False):
    #     logger.info('Using N-best optimizer')