|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import os
|
| 16 | +import gc |
16 | 17 | import time
|
17 | 18 | import yaml
|
18 | 19 | import json
|
@@ -124,6 +125,7 @@ def train(model,
|
124 | 125 | param.stop_gradient = True
|
125 | 126 |
|
126 | 127 | uniform_output_enabled = kwargs.pop("uniform_output_enabled", False)
|
| 128 | + export_during_train = kwargs.pop("export_during_train", False) |
127 | 129 | cli_args = kwargs.pop("cli_args", None)
|
128 | 130 | model.train()
|
129 | 131 | nranks = paddle.distributed.ParallelEnv().nranks
|
@@ -365,15 +367,17 @@ def train(model,
|
365 | 367 | os.path.join(current_save_dir, 'model.pdparams'))
|
366 | 368 | paddle.save(optimizer.state_dict(),
|
367 | 369 | os.path.join(current_save_dir, 'model.pdopt'))
|
368 |
| - if uniform_output_enabled: |
| 370 | + if export_during_train: |
369 | 371 | export(cli_args, model, current_save_dir)
|
| 372 | + gc.collect() |
370 | 373 |
|
371 | 374 | if use_ema:
|
372 | 375 | paddle.save(
|
373 | 376 | ema_model.state_dict(),
|
374 | 377 | os.path.join(current_save_dir, 'ema_model.pdparams'))
|
375 |
| - if uniform_output_enabled: |
| 378 | + if export_during_train: |
376 | 379 | export(cli_args, ema_model, current_save_dir, use_ema)
|
| 380 | + gc.collect() |
377 | 381 |
|
378 | 382 | save_models.append(current_save_dir)
|
379 | 383 | if len(save_models) > keep_checkpoint_max > 0:
|
@@ -403,8 +407,10 @@ def train(model,
|
403 | 407 | paddle.save(
|
404 | 408 | states_dict,
|
405 | 409 | os.path.join(best_model_dir, 'model.pdstates'))
|
406 |
| - if uniform_output_enabled: |
| 410 | + if export_during_train: |
407 | 411 | export(cli_args, model, best_model_dir)
|
| 412 | + gc.collect() |
| 413 | + if uniform_output_enabled: |
408 | 414 | save_model_info(states_dict, best_model_dir)
|
409 | 415 | update_train_results(cli_args,
|
410 | 416 | "best_model",
|
@@ -447,9 +453,11 @@ def train(model,
|
447 | 453 | ema_states_dict,
|
448 | 454 | os.path.join(best_ema_model_dir,
|
449 | 455 | 'ema_model.pdstates'))
|
450 |
| - if uniform_output_enabled: |
| 456 | + if export_during_train: |
451 | 457 | export(cli_args, ema_model, best_ema_model_dir,
|
452 | 458 | use_ema)
|
| 459 | + gc.collect() |
| 460 | + if uniform_output_enabled: |
453 | 461 | save_model_info(ema_states_dict,
|
454 | 462 | best_ema_model_dir)
|
455 | 463 | update_train_results(cli_args,
|
|
0 commit comments