run_experiment.py

# general
import os
import random
import sys
import copy
import time

# torch
import numpy as np
import torch

# project
from path_handler import model_path
from model_constructor import construct_model
from dataset_constructor import construct_dataloaders
import trainer
import tester
import ckconv

# Loggers and config
import wandb
import hydra
from omegaconf import OmegaConf
from hydra import utils
import torchinfo


def setup(cfg):
    """Seed all RNGs and initialize the wandb run."""
    # With seed == -1, draw a random seed at run time instead of using a fixed one
    if cfg.seed == -1:
        cfg.seed = np.random.randint(0, 100)
    # Set the seed
    set_manual_seed(cfg.seed, cfg.deterministic)
    # Initialize wandb; when not training, or when debugging, log offline only
    if not cfg.train.do or cfg.debug:
        os.environ["WANDB_MODE"] = "dryrun"
        os.environ["HYDRA_FULL_ERROR"] = "1"
    wandb.init(
        project=cfg.wandb.project,
        config=ckconv.utils.flatten_configdict(cfg),
        entity=cfg.wandb.entity,
        save_code=True,
        dir=cfg.wandb.dir,
    )
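
# wandb.init expects a flat mapping, so the nested Hydra config is flattened
# into dotted keys first. The real implementation lives in ckconv.utils; the
# hypothetical sketch below only illustrates the assumed behavior.
#
# def flatten_configdict_sketch(d: dict, prefix: str = "") -> dict:
#     flat = {}
#     for key, value in d.items():
#         name = f"{prefix}.{key}" if prefix else key
#         if isinstance(value, dict):
#             # Recurse into nested config groups
#             flat.update(flatten_configdict_sketch(value, name))
#         else:
#             flat[name] = value
#     return flat
#
# flatten_configdict_sketch({"wandb": {"project": "x"}})  # {"wandb.project": "x"}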


def model_and_datasets(cfg):
    """Build the model and dataloaders, and optionally load pretrained weights."""
    # Construct the model
    model = construct_model(cfg)

    # Send the model to the GPU if one is available, otherwise to the CPU
    print("GPUs available:", torch.cuda.device_count())
    if cfg.device == "cuda" and torch.cuda.is_available():
        print(f"Let's use {torch.cuda.device_count()} GPUs!")
        # Set device and send model to device
        cfg.device = "cuda"
    else:
        # Fall back to CPU when CUDA is requested but unavailable
        cfg.device = "cpu"
    model.to(cfg.device)

    # Construct dataloaders
    dataloaders = construct_dataloaders(cfg)

    # # wandb.watch() automatically fetches all layer dimensions, gradients and
    # # model parameters, and logs them to your dashboard. With log="all",
    # # histograms of parameter values are logged in addition to gradients.
    # wandb.watch(model, log="all", log_freq=200)

    # Create model directory and instantiate config.path
    # model_path(cfg)  # TODO

    if cfg.pretrained:
        # Load model state dict from a local checkpoint
        missing, unexpected = model.module.load_state_dict(
            torch.load(cfg.pretrained_params.filepath, map_location=cfg.device)[
                "model"
            ],
            strict=cfg.pretrained_strict,
        )
        print("Loaded model.")
    elif cfg.pretrained_wandb:
        # Load model state dict from a wandb run
        weights_file = wandb.restore(
            cfg.pretrained_wandb_params.filename,
            run_path=cfg.pretrained_wandb_params.run_path,
        )
        missing, unexpected = model.module.load_state_dict(
            torch.load(weights_file.name, map_location=cfg.device)["model"],
            strict=cfg.pretrained_strict,
        )
        print("Loaded model from W&B.")

    if cfg.pretrained or cfg.pretrained_wandb:
        # Report any mismatches between the checkpoint and the model
        if len(missing) > 0:
            print("Missing keys:\n" + "\n".join(missing))
        if len(unexpected) > 0:
            print("Unexpected keys:\n" + "\n".join(unexpected))
        # Clear the stored train lengths of all CKConv layers
        for m in model.modules():
            if isinstance(m, ckconv.nn.CKConv):
                m.train_length[0] = 0

    # Print a torchinfo summary when an input size is provided via cfg.summary
    if len(cfg.summary) > 1:
        torchinfo.summary(model, tuple(cfg.summary), depth=cfg.summary_depth)

    return model, dataloaders
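
# Both loading paths above index the checkpoint with ["model"], so checkpoints
# are assumed to be dicts with a "model" entry holding the state dict. A
# hypothetical save call compatible with these loaders (the actual saving
# presumably happens in trainer.py) would be:
#
# torch.save(
#     {"model": model.module.state_dict()},  # "model" key matches the loaders
#     "checkpoint.pt",
# )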


@hydra.main(config_path="cfg", config_name="config.yaml")
def main(
    cfg: OmegaConf,
):
    # We may want to add fields to the config file. Thus, we set struct to False.
    OmegaConf.set_struct(cfg, False)

    # Print input args
    print(f"Input arguments \n {OmegaConf.to_yaml(cfg)}")

    setup(cfg)
    model, dataloaders = model_and_datasets(cfg)

    # Optionally evaluate before training, e.g. to verify pretrained weights
    if cfg.test.before_train:
        tester.test(model, dataloaders["test"], cfg, log=True, epoch=0)

    # Train the model
    if cfg.train.do:
        # Print the (possibly modified) arguments as a sanity check
        print(f"Modified arguments: \n {OmegaConf.to_yaml(cfg)}")
        trainer.train(model, dataloaders, cfg)

    # Test the model
    tester.test(model, dataloaders["test"], cfg)
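
# Since main() is decorated with @hydra.main, config values can be overridden
# from the command line. A hypothetical invocation (field names follow the
# accesses above, e.g. cfg.train.do, cfg.device, cfg.seed):
#
#   python run_experiment.py train.do=True device=cuda seed=42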


def set_manual_seed(
    seed: int,
    deterministic: bool,
):
    """Seed Python, NumPy and PyTorch RNGs; optionally make cuDNN deterministic."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # benchmark mode speeds up convolutions at the cost of reproducibility
        torch.backends.cudnn.benchmark = True
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
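
# Note: cudnn.deterministic only constrains cuDNN kernels. Newer PyTorch
# versions also offer torch.use_deterministic_algorithms(True) for stricter
# guarantees; this script does not call it, so enabling it as below is an
# untested assumption.
#
# if deterministic:
#     torch.use_deterministic_algorithms(True)
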
if __name__ == "__main__":
main()