Commit fa7c19c

Merge pull request #1 from HernandoR/dev/optim_model
Dev/optim model
2 parents (4f9cfb8 + eb457cd) · commit fa7c19c

47 files changed: +9772 −1407 lines changed

.gitignore

Lines changed: 14 additions & 8 deletions

@@ -100,18 +100,24 @@ ENV/
 # mypy
 .mypy_cache/
 
+
 # input data, saved log, checkpoints
-data/
-wandb/
-input/
-output/
+data/*
+wandb/*
+input/*
+output/*
 saved/
-datasets/
+datasets/*
+
+
+
+model/*/*
+
 
 # editor, os cache directory
-.vscode/
-.idea/
-__MACOSX/
+.vscode/*
+.idea/*
+__MACOSX/*
 
 # personal config files
 *.cfg

README.md

Lines changed: 6 additions & 4 deletions

@@ -22,6 +22,8 @@ The submitable (or newest) version is in the root folder / "[baseline.ipynb](dev
 
 following is the instruction of the template.
 
+
+Data URL https://gist.github.com/nat/e7266a5c765686b7976df10d3a85041b
 ------
 
 
@@ -79,7 +81,7 @@ Try `python train.py -c config.json` to run code.
 
 Config files are in `.json` format:
 
-```javascript
+```json5
 {
 "name": "Mnist_LeNet", // training session name
 "n_gpu": 1, // number of GPUs to use for training.
@@ -96,7 +98,7 @@ Config files are in `.json` format:
 "data_dir": "data/", // dataset path
 "batch_size": 64, // batch size
 "shuffle": true, // shuffle training data before splitting
-"validation_split": 0.1 // size of validation dataset. float(portion) or int(number of samples)
+"validation_split": 0.1, // size of validation dataset. float(portion) or int(number of samples)
 "num_workers": 2, // number of cpu processes to be used for data loading
 }
 },
@@ -125,8 +127,8 @@ Config files are in `.json` format:
 "save_freq": 1, // save checkpoints every save_freq epochs
 "verbosity": 2, // 0: quiet, 1: per epoch, 2: full
 
-"monitor": "min val_loss" // mode and metric for model performance monitoring. set 'off' to disable.
-"early_stop": 10 // number of epochs to wait before early stop. set 0 to disable.
+"monitor": "min val_loss", // mode and metric for model performance monitoring. set 'off' to disable.
+"early_stop": 10, // number of epochs to wait before early stop. set 0 to disable.
 
 "tensorboard": true, // enable tensorboard visualization
 }
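
Since the config keeps `//` comments and trailing commas (hence the fence change from `javascript` to `json5`), a strict `json.load` would reject it. Below is a minimal, hypothetical loader sketch for reading such a file; the template's actual config parsing may work differently.

```python
import json
import re
from pathlib import Path


def load_json5_config(path: str) -> dict:
    """Hypothetical helper: read a config file that uses //-comments and trailing commas."""
    text = Path(path).read_text()
    text = re.sub(r"//.*", "", text)            # strip //-comments (assumes no string value contains "//")
    text = re.sub(r",\s*([}\]])", r"\1", text)  # strip trailing commas before } or ]
    return json.loads(text)


# Illustrative usage (paths are assumptions, not from the commit):
# config = load_json5_config("config.json")
# print(config["trainer"]["monitor"])  # -> "min val_loss"
```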

base/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *

base/base_data_loader.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler

from logger import Loggers

Logger = Loggers.get_logger(__name__)


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders
    """

    def __init__(self, dataset, batch_size: int,
                 shuffle: bool, validation_split,
                 num_workers: int, collate_fn=default_collate, **kwargs):
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        self.sampler, self.valid_sampler = self._split_sampler(
            self.validation_split)

        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        self.init_kwargs.update(kwargs)
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert split < self.n_samples, \
                "validation set size is configured to be larger than entire dataset."
            if split < 1.0:
                split = int(split * self.n_samples)
                Logger.info(
                    f"got a fraction number for validation split, convert to {split} samples")
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
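
A minimal usage sketch (the dataset and subclass below are hypothetical, not part of the commit): a concrete loader only builds its dataset and forwards the arguments, after which `split_validation()` yields a loader over the held-out samples.

```python
import torch
from torch.utils.data import TensorDataset

from base import BaseDataLoader


class RandomTensorDataLoader(BaseDataLoader):
    """Hypothetical loader over a random tensor dataset, for illustration only."""

    def __init__(self, batch_size, shuffle=True, validation_split=0.1, num_workers=2):
        dataset = TensorDataset(torch.randn(1000, 8), torch.randint(0, 2, (1000,)))
        super().__init__(dataset, batch_size, shuffle, validation_split, num_workers)


loader = RandomTensorDataLoader(batch_size=64)
valid_loader = loader.split_validation()               # DataLoader over the 10% held-out samples
print(len(loader.sampler), len(valid_loader.sampler))  # 900 100
```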

base/base_model.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
    """
    Base class for all models
    """

    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super().__str__() + '\nTrainable parameters: {}'.format(params)
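
A minimal subclass sketch (the `TinyMLP` model below is hypothetical, not part of the commit): only `forward` needs to be implemented, and `print(model)` then appends the trainable-parameter count via the `__str__` above.

```python
import torch
import torch.nn as nn

from base import BaseModel


class TinyMLP(BaseModel):
    """Hypothetical two-layer MLP used only to illustrate the BaseModel contract."""

    def __init__(self, in_features=8, hidden=16, n_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, x):
        return self.net(x)


model = TinyMLP()
print(model)                    # nn.Module repr plus "Trainable parameters: 178"
out = model(torch.randn(4, 8))  # -> shape (4, 2)
```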

base/base_trainer.py

Lines changed: 168 additions & 0 deletions

@@ -0,0 +1,168 @@
from abc import abstractmethod
from pathlib import Path

import torch
import wandb
from numpy import inf

from logger import Loggers


class BaseTrainer:
    """
    Base class for all trainers
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = Loggers.get_logger('trainer')
        self.model = model
        self.model_id = config.model_id
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            self.mnt_mode = self.mnt_mode.lower()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = Path(config['PATHS']['CP_DIR'])

        if config.resume_path is not None:
            self._resume_checkpoint(config.resume_path)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def train(self):
        """
        Full training logic
        """

        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            # train epoch
            # return metrics that may or may not be logged
            # TODO find how to config the mnt_metric
            result = self._train_epoch(epoch)

            # save logged information into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged information to the screen
            for key, value in log.items():
                # self.logger.info(' {:15s}: {}'.format(str(key), value))
                self.logger.info(f" {key:15s}: {value}")

            # evaluate model performance according to configured metric, save the best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1
                    self.logger.info("Early stop count: {}".format(not_improved_count))

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn't improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    wandb.run.summary["early_stop"] = True

                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

        wandb.save(str(self.checkpoint_dir / f'{self.model_id}_best.pth'))
        wandb.finish()

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            # 'config': self.config
            'config': {
                k: v for k, v in self.config.items() if k in ['model', 'optimizer', 'trainer']
            }
        }
        filename = str(self.checkpoint_dir / f'{self.model_id}_checkpoints.pth')
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / f'{self.model_id}_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['model'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['model'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
