Skip to content

Commit 4b61544

Browse files
committed
separation parameter
1 parent 3b15f91 commit 4b61544

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

paddleseg/core/export.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,15 @@ def export(args, model=None, save_dir=None, use_ema=False):
5454
input_spec = [paddle.static.InputSpec(shape=shape, dtype='float32')]
5555
model.eval()
5656
model = paddle.jit.to_static(model, input_spec=input_spec)
57-
uniform_output_enabled = cfg.dic.get('uniform_output_enabled', False)
58-
if args.for_fd or uniform_output_enabled:
57+
export_during_train = cfg.dic.get('export_during_train', False)
58+
if args.for_fd or export_during_train:
5959
save_name = 'inference'
6060
yaml_name = 'inference.yml'
6161
else:
6262
save_name = 'model'
6363
yaml_name = 'deploy.yaml'
6464

65-
if uniform_output_enabled:
65+
if export_during_train:
6666
inference_model_path = os.path.join(save_dir, "inference", save_name)
6767
yml_file = os.path.join(save_dir, "inference", yaml_name)
6868
if use_ema:

paddleseg/core/train.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16+
import gc
1617
import time
1718
import yaml
1819
import json
@@ -124,6 +125,7 @@ def train(model,
124125
param.stop_gradient = True
125126

126127
uniform_output_enabled = kwargs.pop("uniform_output_enabled", False)
128+
export_during_train = kwargs.pop("export_during_train", False)
127129
cli_args = kwargs.pop("cli_args", None)
128130
model.train()
129131
nranks = paddle.distributed.ParallelEnv().nranks
@@ -365,15 +367,17 @@ def train(model,
365367
os.path.join(current_save_dir, 'model.pdparams'))
366368
paddle.save(optimizer.state_dict(),
367369
os.path.join(current_save_dir, 'model.pdopt'))
368-
if uniform_output_enabled:
370+
if export_during_train:
369371
export(cli_args, model, current_save_dir)
372+
gc.collect()
370373

371374
if use_ema:
372375
paddle.save(
373376
ema_model.state_dict(),
374377
os.path.join(current_save_dir, 'ema_model.pdparams'))
375-
if uniform_output_enabled:
378+
if export_during_train:
376379
export(cli_args, ema_model, current_save_dir, use_ema)
380+
gc.collect()
377381

378382
save_models.append(current_save_dir)
379383
if len(save_models) > keep_checkpoint_max > 0:
@@ -403,8 +407,10 @@ def train(model,
403407
paddle.save(
404408
states_dict,
405409
os.path.join(best_model_dir, 'model.pdstates'))
406-
if uniform_output_enabled:
410+
if export_during_train:
407411
export(cli_args, model, best_model_dir)
412+
gc.collect()
413+
if uniform_output_enabled:
408414
save_model_info(states_dict, best_model_dir)
409415
update_train_results(cli_args,
410416
"best_model",
@@ -447,9 +453,11 @@ def train(model,
447453
ema_states_dict,
448454
os.path.join(best_ema_model_dir,
449455
'ema_model.pdstates'))
450-
if uniform_output_enabled:
456+
if export_during_train:
451457
export(cli_args, ema_model, best_ema_model_dir,
452458
use_ema)
459+
gc.collect()
460+
if uniform_output_enabled:
453461
save_model_info(ema_states_dict,
454462
best_ema_model_dir)
455463
update_train_results(cli_args,

tools/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def main(args):
180180
utils.set_device(args.device)
181181
utils.set_cv2_num_threads(args.num_workers)
182182
uniform_output_enabled = cfg.dic.get("uniform_output_enabled", False)
183+
export_during_train = cfg.dic.get("export_during_train", False)
183184
if uniform_output_enabled:
184185
if not os.path.exists(args.save_dir):
185186
os.makedirs(args.save_dir)
@@ -244,7 +245,8 @@ def main(args):
244245
print_mem_info=print_mem_info,
245246
shuffle=shuffle,
246247
uniform_output_enabled=uniform_output_enabled,
247-
cli_args=None if not uniform_output_enabled else args)
248+
export_during_train=export_during_train,
249+
cli_args=None if not export_during_train else args)
248250

249251

250252
if __name__ == '__main__':

0 commit comments

Comments
 (0)