 from utils.base import *
 from utils.init import init_model_params
 from utils.contpara import get_model_parameters
-from utils.h5serial import h5save, h5load
-from utils.fmt.base import tostr, save_states, load_states, pad_id
+from utils.state.holder import Holder
+from utils.state.pyrand import PyRandomState
+from utils.state.thrand import THRandomState
+from utils.fmt.base import tostr, pad_id
 from utils.fmt.base4torch import parse_cuda, load_emb
 from utils.mulang import data_sampler
 
@@ -30,7 +32,7 @@
 
 from transformer.NMT import NMT
 
-def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):
+def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):
 
 	sum_loss = part_loss = 0.0
 	sum_wd = part_wd = 0
@@ -77,17 +79,12 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok
 					if num_checkpoint > 1:
 						_fend = "_%d.h5" % (_cur_checkid)
 						_chkpf = chkpf[:-3] + _fend
-						if chkpof is not None:
-							_chkpof = chkpof[:-3] + _fend
 						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
 					else:
 						_chkpf = chkpf
-						_chkpof = chkpof
 					save_model(model, _chkpf, multi_gpu, print_func=logger.info)
-					if chkpof is not None:
-						h5save(optm.state_dict(), _chkpof)
 					if statesf is not None:
-						save_states(statesf, tl[cur_b - 1:])
+						save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info)
 				_cur_rstep -= 1
 				if _cur_rstep <= 0:
 					break
@@ -111,17 +108,12 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok
 			if num_checkpoint > 1:
 				_fend = "_%d.h5" % (_cur_checkid)
 				_chkpf = chkpf[:-3] + _fend
-				if chkpof is not None:
-					_chkpof = chkpof[:-3] + _fend
 				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
 			else:
 				_chkpf = chkpf
-				_chkpof = chkpof
 			save_model(model, _chkpf, multi_gpu, print_func=logger.info)
-			if chkpof is not None:
-				h5save(optm.state_dict(), _chkpof)
 			if statesf is not None:
-				save_states(statesf, tl[cur_b - 1:])
+				save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info)
 		cur_b += 1
 	if part_wd != 0.0:
 		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,))
@@ -181,7 +173,7 @@ def load_fixing(module):
 batch_report = cnfg.batch_report
 report_eva = cnfg.report_eva
 use_ams = cnfg.use_ams
-save_optm_state = cnfg.save_optm_state
+cnt_states = cnfg.train_statesf
 save_auto_clean = cnfg.save_auto_clean
 overwrite_eva = cnfg.overwrite_eva
 save_every = cnfg.save_every
@@ -193,14 +185,11 @@ def load_fixing(module):
 mkdir(wkdir)
 
 chkpf = None
-chkpof = None
 statesf = None
 if save_every is not None:
 	chkpf = wkdir + "checkpoint.h5"
-	if save_optm_state:
-		chkpof = wkdir + "checkpoint.optm.h5"
-	if cnfg.save_train_state:
-		statesf = wkdir + "checkpoint.states"
+if cnfg.save_train_state:
+	statesf = wkdir + "train.states.t7"
 
 logger = get_logger(wkdir + "train.log")
 
@@ -217,10 +206,6 @@ def load_fixing(module):
 nword = td["nword"][:].tolist()
 nwordi, ntask, nwordt = nword[0], nword[1], nword[-1]
 
-logger.info("Design models with seed: %d" % torch.initial_seed())
-mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
-
-fine_tune_m = cnfg.fine_tune_m
 task_weight, task_weight_T = cnfg.task_weight, cnfg.task_weight_T
 if task_weight_T is None or task_weight_T == 1.0:
 	tl = [(str(i), _task,) for _nd, _task in zip(ntrain, td["taskorder"][:].tolist()) for i in range(_nd)]
@@ -234,6 +219,11 @@ def load_fixing(module):
 	train_sampler = data_sampler(ntrain if task_weight is None else task_weight, task_weight_T, ntrain, train_taskorder, nsample=sum(ntrain))
 nvalid = [(str(i), _task,) for _nd, _task in zip(nvalid, vd["taskorder"][:].tolist()) for i in range(_nd)]
 
+logger.info("Design models with seed: %d" % torch.initial_seed())
+mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
+
+fine_tune_m = cnfg.fine_tune_m
+
 mymodel = init_model_params(mymodel)
 mymodel.apply(init_fixing)
 if fine_tune_m is not None:
@@ -267,13 +257,10 @@ def load_fixing(module):
 optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams)
 optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none)
 
-fine_tune_state = cnfg.fine_tune_state
-if fine_tune_state is not None:
-	logger.info("Load optimizer state from: " + fine_tune_state)
-	optimizer.load_state_dict(h5load(fine_tune_state))
-
 lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale)
 
+state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)})
+
 num_checkpoint = cnfg.num_checkpoint
 cur_checkid = 0
 
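Note: Holder, PyRandomState and THRandomState come from the new utils/state/ package, which is not part of this diff. Below is a minimal sketch of how such a state aggregator could behave, assuming every member exposes state_dict()/load_state_dict() and that extra keyword arguments are passed through untouched; all names and semantics here are illustrative assumptions, not the repository's actual implementation.

```python
import random

import torch

class PyRandomState:
	# assumed wrapper around Python's random module RNG
	def state_dict(self):
		return {"state": random.getstate()}
	def load_state_dict(self, d):
		random.setstate(d["state"])

class THRandomState:
	# assumed wrapper around torch (and optionally CUDA) RNG state
	def __init__(self, use_cuda=False):
		self.use_cuda = use_cuda and torch.cuda.is_available()
	def state_dict(self):
		d = {"cpu": torch.get_rng_state()}
		if self.use_cuda:
			d["cuda"] = torch.cuda.get_rng_state_all()
		return d
	def load_state_dict(self, d):
		torch.set_rng_state(d["cpu"])
		if self.use_cuda and ("cuda" in d):
			torch.cuda.set_rng_state_all(d["cuda"])

class Holder:
	# aggregates the state_dicts of named members (optimizer, scheduler, RNGs, ...)
	def __init__(self, **kwargs):
		self.members = kwargs
	def state_dict(self, update=True, **extra):
		# "update" is accepted for signature compatibility with the call sites in
		# this diff; its real semantics are not visible here, so it is ignored.
		d = {k: v.state_dict() for k, v in self.members.items()}
		d.update(extra)
		return d
	def load_state_dict(self, d):
		for k, v in self.members.items():
			if k in d:
				v.load_state_dict(d[k])
		# hand back the keys no member consumed, e.g. remain_steps / checkpoint_id
		return {k: v for k, v in d.items() if k not in self.members}
```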
@@ -286,15 +273,22 @@ def load_fixing(module):
 	save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info)
 	logger.info("Initial model saved")
 else:
-	cnt_states = cnfg.train_statesf
 	if cnt_states is not None:
-		logger.info("Continue last epoch")
-		tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler)
+		logger.info("Loading training states")
+		_remain_states = state_holder.load_state_dict(torch.load(cnt_states))
+		remain_steps, cur_checkid = _remain_states["remain_steps"], _remain_states["checkpoint_id"]
+		if "training_list" in _remain_states:
+			_ctl = _remain_states["training_list"]
+		else:
+			shuffle(tl)
+			_ctl = tl
+		tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, _ctl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, state_holder, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler)
+		_ctl = _remain_states = None
 		vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp)
 		logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec,))
 		save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None)
-		if save_optm_state:
-			h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec,))
+		if statesf is not None:
+			save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 		logger.info("New best model saved")
 
 if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0:
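Note: save_states here is no longer the helper dropped from the utils.fmt.base import at the top of this diff; the new call sites pass a full state dict, a target path and a print_func, so the name presumably resolves through "from utils.base import *". A plausible sketch under that assumption follows; the real helper is not shown in this diff.

```python
import torch

# Sketch only: assumed to serialize the aggregated training state with torch.save
# and to report the destination through the supplied logging callable.
def save_states(state_dict, statesf, print_func=None):
	torch.save(state_dict, statesf)
	if print_func is not None:
		print_func("training states saved to: " + statesf)

# Resuming then mirrors the hunk above:
#	_remain_states = state_holder.load_state_dict(torch.load(cnt_states))
#	remain_steps, cur_checkid = _remain_states["remain_steps"], _remain_states["checkpoint_id"]
```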
@@ -319,14 +313,14 @@ def load_fixing(module):
 	else:
 		tl = train_sampler.generate()
 	free_cache(use_cuda)
-	terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler)
+	terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, state_holder, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler)
 	vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp)
 	logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec,))
 
 	if (vprec <= minerr) or (vloss <= minloss):
 		save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None)
-		if save_optm_state:
-			h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,))
+		if statesf is not None:
+			save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 		logger.info("New best model saved")
 
 		namin = 0
@@ -340,15 +334,18 @@ def load_fixing(module):
 		if terr < tminerr:
 			tminerr = terr
 			save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None)
-			if save_optm_state:
-				h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,))
+			if statesf is not None:
+				save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 		elif epoch_save:
 			save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info)
+			if statesf is not None:
+				save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 
 		namin += 1
 		if namin >= earlystop:
 			if done_tokens > 0:
 				optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer)
+				lrsch.step()
 				done_tokens = 0
 			logger.info("early stop")
 			break
@@ -368,10 +365,11 @@ def load_fixing(module):
 
 if done_tokens > 0:
 	optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer)
+	lrsch.step()
 
 save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info)
-if save_optm_state:
-	h5save(optimizer.state_dict(), wkdir + "last.optm.h5")
+if statesf is not None:
+	save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 logger.info("model saved")
 
 td.close()
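Usage note: with this change, resumable training is driven by two configuration switches instead of the removed save_optm_state / fine_tune_state options. The attribute names below are taken from the diff; the values are only an example.

```python
# in the configuration module imported as cnfg
save_train_state = True	# write <wkdir>/train.states.t7 alongside the model checkpoints
train_statesf = None	# set to a saved train.states.t7 to resume; optimizer, scheduler and
						# RNG states are then restored through the Holder rather than from a
						# separate *.optm.h5 file
```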