From bcf9c27155dfcf3db2f2644ca4eb04500257cdd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20=C5=BBak?= Date: Thu, 1 Feb 2024 11:54:47 +0100 Subject: [PATCH 1/4] Update train.py Add train from checkpoint capability --- bin/train.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/bin/train.py b/bin/train.py index 96eb9b84..69c69b2e 100755 --- a/bin/train.py +++ b/bin/train.py @@ -18,7 +18,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.plugins import DDPPlugin -from saicinpainting.training.trainers import make_training_model +from saicinpainting.training.trainers import make_training_model, load_checkpoint from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \ handle_deterministic_config @@ -48,7 +48,15 @@ def main(config: OmegaConf): metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd())) metrics_logger.log_hyperparams(config) - training_model = make_training_model(config) + if "load_checkpoint_path" in config.location: + print("Loading model checkpoint from path:", config.location.load_checkpoint_path) + training_model = load_checkpoint( + train_config=config, + path=config.location.load_checkpoint_path, + strict=False + ) + else: + training_model = make_training_model(config) trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True) if need_set_deterministic: From 630948e8ea25d7f3417befa94ba9b78df861747e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20=C5=BBak?= Date: Thu, 1 Feb 2024 12:00:43 +0100 Subject: [PATCH 2/4] Update places_example.yaml --- configs/training/location/places_example.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/configs/training/location/places_example.yaml b/configs/training/location/places_example.yaml index 97a9f9b5..1d01f03f 100644 --- a/configs/training/location/places_example.yaml +++ b/configs/training/location/places_example.yaml @@ 
-3,3 +3,5 @@ data_root_dir: /home/user/inpainting-lama/places_standard_dataset/ out_root_dir: /home/user/inpainting-lama/experiments tb_dir: /home/user/inpainting-lama/tb_logs pretrained_models: /home/user/inpainting-lama/ +# Optional: path to a model checkpoint to load before training starts (uncomment to enable) +# load_checkpoint_path: /home/user/lama/big-lama/models/best.ckpt From 448ea11c42adc7cca4ab2a5848ea7d9fd99bbc9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20=C5=BBak?= Date: Thu, 1 Feb 2024 12:01:00 +0100 Subject: [PATCH 3/4] Update docker.yaml --- configs/training/location/docker.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/configs/training/location/docker.yaml b/configs/training/location/docker.yaml index 5da6a4a4..30a14404 100644 --- a/configs/training/location/docker.yaml +++ b/configs/training/location/docker.yaml @@ -3,3 +3,5 @@ data_root_dir: /data/data out_root_dir: /data/experiments tb_dir: /data/tb_logs pretrained_models: /some_path +# Optional: path to a model checkpoint to load before training starts (uncomment to enable) +# load_checkpoint_path: /home/user/lama/big-lama/models/best.ckpt From 167c3164d57f5c97416b8b64134c6aac78809c5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20=C5=BBak?= Date: Thu, 1 Feb 2024 12:01:08 +0100 Subject: [PATCH 4/4] Update celeba_example.yaml --- configs/training/location/celeba_example.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/configs/training/location/celeba_example.yaml b/configs/training/location/celeba_example.yaml index 117fe8a9..77d72813 100644 --- a/configs/training/location/celeba_example.yaml +++ b/configs/training/location/celeba_example.yaml @@ -3,3 +3,5 @@ data_root_dir: /home/user/lama/celeba-hq-dataset/ out_root_dir: /home/user/lama/experiments/ tb_dir: /home/user/lama/tb_logs/ pretrained_models: /home/user/lama/ +# Optional: path to a model checkpoint to load before training starts (uncomment to enable) +# load_checkpoint_path: /home/user/lama/big-lama/models/best.ckpt