-
Notifications
You must be signed in to change notification settings - Fork 479
/
Copy pathvalidate_rnn.py
54 lines (44 loc) · 1.45 KB
/
validate_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
Validate our RNN. Basically just runs a validation generator on
about the same number of videos as we have in our test set.
"""
from keras.callbacks import TensorBoard, ModelCheckpoint, CSVLogger
from models import ResearchModels
from data import DataSet
def validate(data_type, model, seq_length=40, saved_model=None,
             class_limit=None, image_shape=None, batch_size=32,
             num_samples=3200):
    """Evaluate a saved model on the 'test' split and print the results.

    Args:
        data_type: Passed to ``DataSet.frame_generator``; ``main`` uses
            ``'images'`` or ``'features'``.
        model: Model name passed to ``ResearchModels`` (e.g. ``'lstm'``).
        seq_length: Sequence length given to both ``DataSet`` and
            ``ResearchModels``.
        saved_model: Saved model path handed to ``ResearchModels``.
        class_limit: Optional class limit passed to ``DataSet``.
        image_shape: Optional image shape for ``DataSet``. When ``None`` it is
            not passed at all, so ``DataSet`` keeps its own default.
        batch_size: Batch size requested from the validation generator.
        num_samples: Approximate number of samples to evaluate. It is
            converted to a number of generator steps (at least one).

    Returns:
        The value returned by ``evaluate_generator``. This is presumably the
        loss followed by compiled metrics, in ``metrics_names`` order, which
        is also printed.
    """
    # Only forward image_shape when it was given, so DataSet's default applies.
    data_kwargs = {'seq_length': seq_length, 'class_limit': class_limit}
    if image_shape is not None:
        data_kwargs['image_shape'] = image_shape
    data = DataSet(**data_kwargs)

    val_generator = data.frame_generator(batch_size, 'test', data_type)

    # Get the model.
    rm = ResearchModels(len(data.classes), model, seq_length, saved_model)

    # Keras 2 replaced the Keras 1 `val_samples` argument with `steps`, which
    # counts batches drawn from the generator. Convert the sample budget into
    # batches so roughly `num_samples` samples are evaluated.
    steps = max(1, num_samples // batch_size)
    results = rm.model.evaluate_generator(
        generator=val_generator,
        steps=steps)

    print(results)
    print(rm.model.metrics_names)
    return results
def main():
    """Validate a saved model using hard-coded settings.

    'conv_3d' and 'lrcn' use the 'images' data type with image shape
    (80, 80, 3). Every other model uses the 'features' data type and
    leaves the image shape unset.
    """
    model = 'lstm'
    saved_model = 'data/checkpoints/lstm-features.026-0.239.hdf5'

    uses_images = model in ('conv_3d', 'lrcn')
    data_type = 'images' if uses_images else 'features'
    image_shape = (80, 80, 3) if uses_images else None

    validate(data_type, model, saved_model=saved_model,
             image_shape=image_shape, class_limit=4)
if __name__ == '__main__':
    # Run validation only when executed as a script, not when imported.
    main()