This repository was archived by the owner on Aug 10, 2023. It is now read-only.

Commit eb77366: December 2021 update
1 parent e844c5a

39 files changed: +534 -468 lines

adv/train/mulang/train_m2o.py: +40 -42
@@ -11,8 +11,10 @@
 from utils.base import *
 from utils.init import init_model_params
 from utils.contpara import get_model_parameters
-from utils.h5serial import h5save, h5load
-from utils.fmt.base import tostr, save_states, load_states, pad_id
+from utils.state.holder import Holder
+from utils.state.pyrand import PyRandomState
+from utils.state.thrand import THRandomState
+from utils.fmt.base import tostr, pad_id
 from utils.fmt.base4torch import parse_cuda, load_emb
 from utils.mulang import data_sampler
 
@@ -30,7 +32,7 @@
 
 from transformer.NMT import NMT
 
-def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):
+def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None):
 
 	sum_loss = part_loss = 0.0
 	sum_wd = part_wd = 0
@@ -77,17 +79,12 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok
 					if num_checkpoint > 1:
 						_fend = "_%d.h5" % (_cur_checkid)
 						_chkpf = chkpf[:-3] + _fend
-						if chkpof is not None:
-							_chkpof = chkpof[:-3] + _fend
 						_cur_checkid = (_cur_checkid + 1) % num_checkpoint
 					else:
 						_chkpf = chkpf
-						_chkpof = chkpof
 					save_model(model, _chkpf, multi_gpu, print_func=logger.info)
-					if chkpof is not None:
-						h5save(optm.state_dict(), _chkpof)
 					if statesf is not None:
-						save_states(statesf, tl[cur_b - 1:])
+						save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info)
 				_cur_rstep -= 1
 				if _cur_rstep <= 0:
 					break
@@ -111,17 +108,12 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok
 			if num_checkpoint > 1:
 				_fend = "_%d.h5" % (_cur_checkid)
 				_chkpf = chkpf[:-3] + _fend
-				if chkpof is not None:
-					_chkpof = chkpof[:-3] + _fend
 				_cur_checkid = (_cur_checkid + 1) % num_checkpoint
 			else:
 				_chkpf = chkpf
-				_chkpof = chkpof
 			save_model(model, _chkpf, multi_gpu, print_func=logger.info)
-			if chkpof is not None:
-				h5save(optm.state_dict(), _chkpof)
 			if statesf is not None:
-				save_states(statesf, tl[cur_b - 1:])
+				save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info)
 		cur_b += 1
 	if part_wd != 0.0:
 		logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,))
@@ -181,7 +173,7 @@ def load_fixing(module):
 batch_report = cnfg.batch_report
 report_eva = cnfg.report_eva
 use_ams = cnfg.use_ams
-save_optm_state = cnfg.save_optm_state
+cnt_states = cnfg.train_statesf
 save_auto_clean = cnfg.save_auto_clean
 overwrite_eva = cnfg.overwrite_eva
 save_every = cnfg.save_every
@@ -193,14 +185,11 @@ def load_fixing(module):
 mkdir(wkdir)
 
 chkpf = None
-chkpof = None
 statesf = None
 if save_every is not None:
 	chkpf = wkdir + "checkpoint.h5"
-	if save_optm_state:
-		chkpof = wkdir + "checkpoint.optm.h5"
-	if cnfg.save_train_state:
-		statesf = wkdir + "checkpoint.states"
+if cnfg.save_train_state:
+	statesf = wkdir + "train.states.t7"
 
 logger = get_logger(wkdir + "train.log")
 
@@ -217,10 +206,6 @@ def load_fixing(module):
 nword = td["nword"][:].tolist()
 nwordi, ntask, nwordt = nword[0], nword[1], nword[-1]
 
-logger.info("Design models with seed: %d" % torch.initial_seed())
-mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
-
-fine_tune_m = cnfg.fine_tune_m
 task_weight, task_weight_T = cnfg.task_weight, cnfg.task_weight_T
 if task_weight_T is None or task_weight_T == 1.0:
 	tl = [(str(i), _task,) for _nd, _task in zip(ntrain, td["taskorder"][:].tolist()) for i in range(_nd)]
@@ -234,6 +219,11 @@ def load_fixing(module):
 	train_sampler = data_sampler(ntrain if task_weight is None else task_weight, task_weight_T, ntrain, train_taskorder, nsample=sum(ntrain))
 nvalid = [(str(i), _task,) for _nd, _task in zip(nvalid, vd["taskorder"][:].tolist()) for i in range(_nd)]
 
+logger.info("Design models with seed: %d" % torch.initial_seed())
+mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
+
+fine_tune_m = cnfg.fine_tune_m
+
 mymodel = init_model_params(mymodel)
 mymodel.apply(init_fixing)
 if fine_tune_m is not None:
@@ -267,13 +257,10 @@ def load_fixing(module):
 optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams)
 optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none)
 
-fine_tune_state = cnfg.fine_tune_state
-if fine_tune_state is not None:
-	logger.info("Load optimizer state from: " + fine_tune_state)
-	optimizer.load_state_dict(h5load(fine_tune_state))
-
 lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale)
 
+state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)})
+
 num_checkpoint = cnfg.num_checkpoint
 cur_checkid = 0
 
@@ -286,15 +273,22 @@ def load_fixing(module):
 	save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info)
 	logger.info("Initial model saved")
 else:
-	cnt_states = cnfg.train_statesf
 	if cnt_states is not None:
-		logger.info("Continue last epoch")
-		tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler)
+		logger.info("Loading training states")
+		_remain_states = state_holder.load_state_dict(torch.load(cnt_states))
+		remain_steps, cur_checkid = _remain_states["remain_steps"], _remain_states["checkpoint_id"]
+		if "training_list" in _remain_states:
+			_ctl = _remain_states["training_list"]
+		else:
+			shuffle(tl)
+			_ctl = tl
+		tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, _ctl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, state_holder, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler)
+		_ctl = _remain_states = None
 		vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp)
 		logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec,))
 		save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None)
-		if save_optm_state:
-			h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec,))
+		if statesf is not None:
+			save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 		logger.info("New best model saved")
 
 if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0:
@@ -319,14 +313,14 @@ def load_fixing(module):
 	else:
 		tl = train_sampler.generate()
 	free_cache(use_cuda)
-	terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler)
+	terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, state_holder, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler)
 	vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp)
 	logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec,))
 
 	if (vprec <= minerr) or (vloss <= minloss):
 		save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None)
-		if save_optm_state:
-			h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,))
+		if statesf is not None:
+			save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 		logger.info("New best model saved")
 
 		namin = 0
@@ -340,15 +334,18 @@ def load_fixing(module):
 		if terr < tminerr:
 			tminerr = terr
 			save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None)
-			if save_optm_state:
-				h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,))
+			if statesf is not None:
+				save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 		elif epoch_save:
 			save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info)
+			if statesf is not None:
+				save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 
 		namin += 1
 		if namin >= earlystop:
 			if done_tokens > 0:
 				optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer)
+				lrsch.step()
 				done_tokens = 0
 			logger.info("early stop")
 			break
@@ -368,10 +365,11 @@ def load_fixing(module):
 
 if done_tokens > 0:
 	optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer)
+	lrsch.step()
 
 save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info)
-if save_optm_state:
-	h5save(optimizer.state_dict(), wkdir + "last.optm.h5")
+if statesf is not None:
+	save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info)
 logger.info("model saved")
 
 td.close()
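
The change throughout this file is mechanical: the separate optimizer dumps (`h5save(optm.state_dict(), ...)` into `*.optm.h5` files, controlled by `save_optm_state`) are replaced by a single training-state file (`train.states.t7`) written through `save_states` and a `Holder` that bundles the optimizer, the LR scheduler and the Python/Torch RNG states together with bookkeeping values (`remain_steps`, `checkpoint_id` and, for mid-epoch checkpoints, the remaining `training_list`). The actual `Holder`, `PyRandomState`, `THRandomState` and `save_states` implementations are not part of this diff (the first three come from `utils.state.*`; where the new `save_states` is imported from is not visible here). The sketch below is only an illustration of how such a container could behave, assuming `state_dict` merges component states with caller-supplied extras and `load_state_dict` hands the unrecognized keys back to the caller.

```python
# Hypothetical sketch, not the project's code: names mirror the diff, bodies are assumptions.

import random

import torch

class PyRandomState:
	# wraps Python's global RNG so it can be checkpointed like an optimizer
	def state_dict(self):
		return {"state": random.getstate()}
	def load_state_dict(self, d):
		random.setstate(d["state"])

class THRandomState:
	# wraps torch's CPU (and optionally CUDA) RNG state
	def __init__(self, use_cuda=False):
		self.use_cuda = use_cuda and torch.cuda.is_available()
	def state_dict(self):
		rs = {"cpu": torch.get_rng_state()}
		if self.use_cuda:
			rs["cuda"] = torch.cuda.get_rng_state_all()
		return rs
	def load_state_dict(self, d):
		torch.set_rng_state(d["cpu"])
		if self.use_cuda and ("cuda" in d):
			torch.cuda.set_rng_state_all(d["cuda"])

class Holder:
	# aggregates named components exposing state_dict()/load_state_dict()
	def __init__(self, **components):
		self.components = components
	def state_dict(self, update=True, **extra):
		# "update" is accepted only for parity with the calls in the diff;
		# this sketch always rebuilds the dict from the live components
		rs = {k: v.state_dict() for k, v in self.components.items()}
		rs.update(extra)
		return rs
	def load_state_dict(self, states):
		# restore known components in place, return everything else
		# (remain_steps, checkpoint_id, training_list) to the caller
		remains = {}
		for k, v in states.items():
			if k in self.components:
				self.components[k].load_state_dict(v)
			else:
				remains[k] = v
		return remains

def save_states(states, fname, print_func=None):
	# one state file (train.states.t7) replaces the per-checkpoint *.optm.h5 dumps
	torch.save(states, fname)
	if print_func is not None:
		print_func("training states saved to " + fname)
```

Under these assumptions the resume path in the diff reads naturally: `state_holder.load_state_dict(torch.load(cnt_states))` restores the optimizer, scheduler and RNGs in place and returns `remain_steps`, `checkpoint_id` and, if present, `training_list` to the script, which then resumes from the saved position instead of a freshly shuffled task list.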
