-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdownstream.py
108 lines (103 loc) · 4.62 KB
/
downstream.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import gc
import argparse
import MinkowskiEngine as ME
import pytorch_lightning as pl
from downstream.evaluate import evaluate
from utils.read_config import generate_config
from downstream.model_builder import make_model
from pytorch_lightning.plugins import DDPPlugin
from downstream.lightning_trainer import LightningDownstream
from downstream.lightning_datamodule import DownstreamDataModule
from downstream.dataloader_kitti import make_data_loader as make_data_loader_kitti
from downstream.dataloader_nuscenes import make_data_loader as make_data_loader_nuscenes
from downstream.dataloader_scribble_kitti import make_data_loader as make_data_loader_scribble_kitti
from downstream.dataloader_rellis3d import make_data_loader as make_data_loader_rellis3d
# from downstream.dataloader_semanticposs import make_data_loader as make_data_loader_semanticposs
from downstream.dataloader_semanticstf import make_data_loader as make_data_loader_semanticstf
# from downstream.dataloader_synlidar import make_data_loader as make_data_loader_synlidar
from downstream.dataloader_daps3d import make_data_loader as make_data_loader_daps3d
def _is_local_main_process():
    """Return True when this process should print the run configuration.

    Environment variables are always strings, so ``LOCAL_RANK`` is compared
    against ``"0"`` rather than the integer ``0``. An unset variable is
    treated as rank 0.
    """
    return os.environ.get("LOCAL_RANK", "0").strip() == "0"


def _make_val_dataloader(config):
    """Build the dataloader used for the final evaluation.

    Args:
        config: configuration mapping; reads ``dataset``, ``num_threads`` and,
            for nuScenes only, ``training``.

    Returns:
        Whatever the dataset-specific ``make_data_loader`` function returns.

    Raises:
        ValueError: if ``config["dataset"]`` names a dataset with no
            evaluation dataloader wired up here.
    """
    factories = {
        "nuscenes": make_data_loader_nuscenes,
        "kitti": make_data_loader_kitti,
        "scribble_kitti": make_data_loader_scribble_kitti,
        "rellis3d": make_data_loader_rellis3d,
        "semantic_stf": make_data_loader_semanticstf,
        "semanticstf": make_data_loader_semanticstf,
        "daps3d": make_data_loader_daps3d,
        # SemanticPOSS ("semantic_poss" / "semanticposs") and SynLiDAR
        # ("synlidar") are disabled: their dataloader imports are commented
        # out at the top of the file.
    }
    dataset = config["dataset"].lower()
    try:
        factory = factories[dataset]
    except KeyError:
        supported = ", ".join(sorted(factories))
        raise ValueError(
            f"Unsupported dataset '{config['dataset']}' for evaluation; "
            f"expected one of: {supported}"
        ) from None

    phase = "val"
    # nuScenes is evaluated on the "verifying" phase while parametrizing.
    if dataset == "nuscenes" and config["training"] in ("parametrize", "parametrizing"):
        phase = "verifying"
    return factory(config, phase, num_threads=config["num_threads"])


def main():
    """
    Code for launching the downstream training

    Parses the command line, loads the configuration, trains the model with
    PyTorch Lightning and finally evaluates it on the validation dataloader
    of the configured dataset.

    Command-line arguments:
        --cfg_file: path of the config file (default
            "config/semseg_nuscenes.yaml").
        --resume_path: checkpoint path used to resume an incomplete training;
            overrides ``config["resume_path"]`` when given.
        --pretraining_path: path to pre-trained weights; overrides
            ``config["pretraining_path"]`` when given.

    Raises:
        ValueError: if ``config["dataset"]`` has no evaluation dataloader.
    """
    parser = argparse.ArgumentParser(description="arg parser")
    parser.add_argument(
        "--cfg_file",
        type=str,
        default="config/semseg_nuscenes.yaml",
        help="specify the config for training",
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default=None,
        help="provide a path to resume an incomplete training",
    )
    parser.add_argument(
        "--pretraining_path",
        type=str,
        default=None,
        help="provide a path to pre-trained weights",
    )
    args = parser.parse_args()

    config = generate_config(args.cfg_file)
    # Command-line values take precedence over the ones in the config file.
    if args.resume_path:
        config["resume_path"] = args.resume_path
    if args.pretraining_path:
        config["pretraining_path"] = args.pretraining_path

    # Dump the configuration once, from the local main process only.
    if _is_local_main_process():
        print(
            "\n" + "\n".join(f"{key:20}: {value}" for key, value in config.items())
        )

    dm = DownstreamDataModule(config)
    model = make_model(config, config["pretraining_path"])
    if config["num_gpus"] > 1:
        # Multi-GPU run: switch batch-norm layers to their synchronized variant.
        model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model)
    module = LightningDownstream(model, config)

    path = os.path.join(config["working_dir"], config["datetime"])
    trainer = pl.Trainer(
        gpus=config["num_gpus"],
        accelerator="ddp",
        default_root_dir=path,
        checkpoint_callback=True,
        max_epochs=config["num_epochs"],
        plugins=DDPPlugin(find_unused_parameters=False),
        num_sanity_val_steps=0,
        resume_from_checkpoint=config["resume_path"],
        check_val_every_n_epoch=1,
    )

    print("Starting the training")
    trainer.fit(module, dm)
    print("Training finished, now evaluating the results")

    # Drop the training objects before evaluation (presumably to release
    # memory); `model` itself is kept because it is evaluated below.
    del trainer
    del dm
    del module
    gc.collect()

    val_dataloader = _make_val_dataloader(config)
    # Evaluation runs on device index 0.
    evaluate(model.to(0), val_dataloader, config)


# Launch the training only when the file is executed as a script.
if __name__ == "__main__":
    main()