experiment.py
# pylint: disable=missing-docstring,invalid-name,line-too-long
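"""Train and evaluate CIFAR-10 image classifiers.

Trains the model named on the command line (reusing a saved .h5 file when
one exists), evaluates it on the test set, and writes the metrics and
predictions to output/<model>_<augment>.json.
"""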
import argparse
import json
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.models import load_model
from preprocessing import get_datasets, datagen
from models import simple, simple_reg, deep, lenet, lenet_reg, deeper, mininception
from inception_v3 import inception_pretrained
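
# Registry mapping a model name to a zero-argument factory. Each factory is
# expected to return a ready-to-fit (compiled) Keras model, since train()
# calls .fit() on the result directly.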
model_map = {
    'simple': simple,
    'simple_reg': simple_reg,
    'deep': deep,
    'lenet': lenet,
    'lenet_reg': lenet_reg,
    'deeper': deeper,
    'mininception': mininception,
    'inception': inception_pretrained,
}


def train(args):
    # Train the requested model from scratch
    mm = model_map[args.model]()
    # Callbacks for early stopping, TensorBoard logging and checkpointing
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=args.patience, verbose=1),
        TensorBoard(f'output/{get_model_name(args)}_logs', write_graph=True),
        ModelCheckpoint(filepath=get_model_path(args), save_best_only=True),
    ]
    if args.augment <= 120000:
        # Datasets up to this size fit in a single pickle (bigger ones
        # could not be saved), so load the augmented data whole
        X_train, y_train, X_valid, y_valid, _, _ = get_datasets(args.augment)
        mm.fit(
            x=X_train,
            y=y_train,
            batch_size=args.batch,
            epochs=args.epochs,
            validation_data=(X_valid, y_valid),
            callbacks=callbacks)
    else:
        # Generate the augmented dataset on the fly while the model trains
        X_train, y_train, X_valid, y_valid, _, _ = get_datasets()
        mm.fit_generator(
            generator=datagen.flow(X_train, y_train, batch_size=args.batch),
            steps_per_epoch=args.augment // args.batch,
            epochs=args.epochs,
            validation_data=(X_valid, y_valid),
            callbacks=callbacks)
    # Reload the best checkpoint saved by ModelCheckpoint
    return load_model(get_model_path(args))


def evaluate(model, args):
    """Evaluates the model on the test dataset."""
    X_test, y_test = get_datasets(args.augment)[4:]
    test_loss, test_accuracy = model.evaluate(X_test, y_test)
    prediction = model.predict(X_test)
    return {
        'test_accuracy': test_accuracy,
        'test_loss': test_loss,
        'y_true': y_test.tolist(),
        'y_pred': prediction.tolist(),
    }


def get_model_name(args):
    return f'{args.model}_{args.augment}'


def get_model_path(args):
    return f'output/{get_model_name(args)}.h5'


def get_model(args):
    # Load a previously trained model if one was saved, otherwise train it
    try:
        return load_model(get_model_path(args))
    except (IOError, OSError):
        # No saved model (or an unreadable file): train from scratch
        return train(args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('CIFAR-10 experiments')
    parser.add_argument('model', help='Which model to run (a key of model_map).')
    parser.add_argument('-a', '--augment', type=int, default=40000, help='Size of the augmented training set.')
    parser.add_argument('-b', '--batch', type=int, default=128, help='Batch size.')
    parser.add_argument('-e', '--epochs', type=int, default=1000, help='Number of epochs.')
    parser.add_argument('-p', '--patience', type=int, default=5, help='Early stopping patience.')
    arguments = parser.parse_args()
    model = get_model(arguments)
    results = evaluate(model, arguments)
    # Save evaluation results next to the model weights
    with open(f'output/{get_model_name(arguments)}.json', 'w', encoding='utf-8') as f:
        json.dump(results, f)
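
# Example invocation (a sketch; assumes the preprocessed CIFAR-10 data is
# available to preprocessing.get_datasets and that output/ exists):
#
#   python experiment.py lenet --augment 40000 --batch 128 --epochs 1000
#
# This trains (or reloads) output/lenet_40000.h5 and writes the metrics and
# predictions to output/lenet_40000.json.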