-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtraining_with_K_fold.py
162 lines (125 loc) · 4.7 KB
/
training_with_K_fold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
# Train a new model starting from pre-trained weights
python3 training.py --dataset=/path/to/dataset --weight=/path/to/pretrained/weight.h5
# Resume training a model
python3 training.py --dataset=/path/to/dataset --continue_train=/path/to/latest/weights.h5
"""
import logging
import warnings
import os
logging.getLogger("tensorflow").setLevel(logging.ERROR)
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import sys
import json
import datetime
import numpy as np
import skimage.draw
import cv2
import matplotlib.pyplot as plt
import imgaug
# Root directory of the project
ROOT_DIR = os.getcwd()
from mrcnn.config import Config
from mrcnn import model as modellib, utils
from mrcnn import parse_args
import dataset
############################################################
#  Args Configurations
############################################################

args = parse_args.parse_args()

# Resolve user-supplied paths relative to the current working directory.
# os.path.join returns the second argument unchanged when it is absolute,
# so absolute paths given on the command line are kept as-is.
pretrained_weight = os.path.join(ROOT_DIR, args.weight)
dataset_path = os.path.join(ROOT_DIR, args.dataset)
logs = os.path.join(ROOT_DIR, "logs")

# The string "None" is the sentinel the rest of the script compares against
# to mean "do not resume; start from the pre-trained weights".
# Normalise a missing value or any capitalisation of "none" to that sentinel.
# Otherwise e.g. "none" would be joined onto ROOT_DIR and treated as a
# checkpoint path that does not exist.
if args.continue_train is None or str(args.continue_train).lower() == "none":
    continue_train = "None"
else:
    continue_train = os.path.join(ROOT_DIR, args.continue_train)

############################################################
#  Configurations
############################################################

class CustomConfig(Config):
    """Mask R-CNN settings for the custom dataset.

    Each attribute overrides the attribute of the same name on the base
    ``mrcnn.config.Config``; anything not listed here is inherited.
    """

    # Run identifier.
    NAME = "custom_dataset"

    # Input / batching.
    IMAGES_PER_GPU = 1
    IMAGE_MAX_DIM = 512

    # 1 (background) + 4 object classes, following mrcnn's convention that
    # NUM_CLASSES includes the background class.
    NUM_CLASSES = 1 + 4

    # Optimisation / training schedule.
    LEARNING_RATE = 0.001
    STEPS_PER_EPOCH = 750
    VALIDATION_STEPS = 250
    TRAIN_ROIS_PER_IMAGE = 200
    MAX_GT_INSTANCES = 50

    # Detection (inference-time) filtering.
    DETECTION_MIN_CONFIDENCE = 0.9
    DETECTION_NMS_THRESHOLD = 0.2
    DETECTION_MAX_INSTANCES = 50

############################################################
#  Training
############################################################

def _load_fold(subset, fold):
    """Load and prepare the ``subset`` ("train" or "val") split of fold ``fold``."""
    ds = dataset.CustomDataset()
    ds.load_custom_K_fold(dataset_path, subset, fold)
    ds.prepare()
    return ds


def train(model, n_folds=5, heads_epochs=20, finetune_epochs=10, config=None):
    """Train ``model`` over ``n_folds`` cross-validation folds, one after another.

    For every fold ``i`` the train/val splits are loaded via
    ``CustomDataset.load_custom_K_fold(dataset_path, subset, i)``. Each fold
    then runs two stages:

    1. network heads only, at twice ``config.LEARNING_RATE``, for
       ``heads_epochs`` epochs, without augmentation;
    2. ResNet stage 4 and up, at ``config.LEARNING_RATE``, for
       ``finetune_epochs`` epochs, with random horizontal/vertical flips.

    Both stages report mAP on the fold's validation split through
    ``modellib.MeanAveragePrecisionCallback``.

    The same ``model`` is reused for every fold, and ``epoch_count`` keeps
    growing across folds. mrcnn's ``MaskRCNN.train`` treats ``epochs`` as a
    cumulative target, so each fold continues from the weights left by the
    previous one.
    NOTE(review): this is sequential training over folds, not independent
    K-fold cross-validation (the model is never re-initialised) — confirm
    this is intended.

    Args:
        model: ``modellib.MaskRCNN`` built in ``"training"`` mode.
        n_folds: number of folds to iterate (default 5).
        heads_epochs: epochs added per fold for the heads stage (default 20).
        finetune_epochs: epochs added per fold for the "4+" stage (default 10).
        config: training configuration; defaults to ``model.config``.
    """
    if config is None:
        # The script's ``config`` global only exists when run as __main__.
        # Reading it from the model makes train() usable when imported.
        config = model.config

    # Fold-invariant objects: build them once instead of once per fold.
    # Building an extra Mask R-CNN inference graph per fold is costly.
    augmentation = imgaug.augmenters.Sometimes(0.5, [
        imgaug.augmenters.Fliplr(0.5),
        imgaug.augmenters.Flipud(0.5)])
    model_inference = modellib.MaskRCNN(mode="inference", config=config,
                                        model_dir=logs)

    epoch_count = 0
    for i in range(n_folds):
        print("Training fold", i)
        dataset_train = _load_fold("train", i)
        dataset_val = _load_fold("val", i)

        mAP_callback = modellib.MeanAveragePrecisionCallback(
            model, model_inference, dataset_val,
            calculate_at_every_X_epoch=25, dataset_limit=500, verbose=1)

        # Stage 1: heads only, at double the base learning rate.
        # No augmentation here; pass augmentation=augmentation to enable it.
        epoch_count += heads_epochs
        print("Training network heads")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE * 2,
                    epochs=epoch_count,
                    layers='heads',
                    custom_callbacks=[mAP_callback])

        # Stage 2: fine-tune ResNet stage 4 and up, with flip augmentation.
        epoch_count += finetune_epochs
        print("Fine tune Resnet stage 4 and up")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE,
                    epochs=epoch_count,
                    layers='4+',
                    custom_callbacks=[mAP_callback],
                    augmentation=augmentation)

############################################################
#  Command-line entry point
############################################################

if __name__ == '__main__':
    print("Pre-trained weight: ", pretrained_weight)
    print("Dataset: ", dataset_path)
    print("Logs: ", logs)
    print("Continue Train: ", continue_train)

    # Configurations
    config = CustomConfig()
    config.display()

    # Create model
    model = modellib.MaskRCNN(mode="training", config=config,
                              model_dir=logs)

    # One case-insensitive decision drives both which file is loaded and how
    # it is loaded. Previously these were two separate checks ("none" vs
    # "None") that could disagree.
    resume = continue_train.lower() != "none"
    weights_path = continue_train if resume else pretrained_weight

    # Load weights
    print("Loading weights ", weights_path)
    if resume:
        # Resuming: all layers are loaded. Assumes the checkpoint was trained
        # with this same NUM_CLASSES — TODO confirm for external checkpoints.
        model.load_weights(weights_path, by_name=True)
    else:
        # Exclude the last layers because they require a matching
        # number of classes
        model.load_weights(weights_path, by_name=True, exclude=[
            "mrcnn_class_logits", "mrcnn_bbox_fc",
            "mrcnn_bbox", "mrcnn_mask"])

    train(model)