diff --git a/bin/train.py b/bin/train.py index 96eb9b84..69c69b2e 100755 --- a/bin/train.py +++ b/bin/train.py @@ -18,7 +18,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.plugins import DDPPlugin -from saicinpainting.training.trainers import make_training_model +from saicinpainting.training.trainers import make_training_model, load_checkpoint from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \ handle_deterministic_config @@ -48,7 +48,15 @@ def main(config: OmegaConf): metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd())) metrics_logger.log_hyperparams(config) - training_model = make_training_model(config) + if "load_checkpoint_path" in config.location: + print("Loading model checkpoint from path:", config.location.load_checkpoint_path) + training_model = load_checkpoint( + train_config=config, + path=config.location.load_checkpoint_path, + strict=False + ) + else: + training_model = make_training_model(config) trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True) if need_set_deterministic: diff --git a/configs/training/location/celeba_example.yaml b/configs/training/location/celeba_example.yaml index 117fe8a9..77d72813 100644 --- a/configs/training/location/celeba_example.yaml +++ b/configs/training/location/celeba_example.yaml @@ -3,3 +3,5 @@ data_root_dir: /home/user/lama/celeba-hq-dataset/ out_root_dir: /home/user/lama/experiments/ tb_dir: /home/user/lama/tb_logs/ pretrained_models: /home/user/lama/ +# path to model checkpoint that will be loaded before start of training +# load_checkpoint_path: /home/user/lama/big-lama/models/best.ckpt diff --git a/configs/training/location/docker.yaml b/configs/training/location/docker.yaml index 5da6a4a4..30a14404 100644 --- a/configs/training/location/docker.yaml +++ b/configs/training/location/docker.yaml @@ -3,3 +3,5 @@ data_root_dir: /data/data out_root_dir: 
/data/experiments tb_dir: /data/tb_logs pretrained_models: /some_path +# path to model checkpoint that will be loaded before start of training +# load_checkpoint_path: /home/user/lama/big-lama/models/best.ckpt diff --git a/configs/training/location/places_example.yaml b/configs/training/location/places_example.yaml index 97a9f9b5..1d01f03f 100644 --- a/configs/training/location/places_example.yaml +++ b/configs/training/location/places_example.yaml @@ -3,3 +3,5 @@ data_root_dir: /home/user/inpainting-lama/places_standard_dataset/ out_root_dir: /home/user/inpainting-lama/experiments tb_dir: /home/user/inpainting-lama/tb_logs pretrained_models: /home/user/inpainting-lama/ +# path to model checkpoint that will be loaded before start of training +# load_checkpoint_path: /home/user/inpainting-lama/big-lama/models/best.ckpt