forked from watsonyanghx/CNN_LSTM_CTC_Tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 2
/
eval_model.py
93 lines (63 loc) · 2.8 KB
/
eval_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import datetime
import logging
import os
import time
import cv2
import numpy as np
import tensorflow as tf
import cnn_lstm_otc_ocr
import utils
import helper
from preparedata import PrepareData
import argparse
import math

# Project-wide configuration flags (e.g. batch_size) re-exported from utils.
FLAGS = utils.FLAGS
# Root directory for evaluation TensorBoard event files; the evaluator writes
# into "<log_dir>/<split_name>".
log_dir = './log/evals'
class EvaluateModel(PrepareData):
    """Evaluate a saved LSTMOCR checkpoint on one dataset split.

    ``parse_param`` reads the split name and checkpoint path from the command
    line; ``eval_model`` builds the model in ``'eval'`` mode, restores the
    checkpoint, runs ``model.names_to_updates`` over every batch of the split
    (presumably streaming-metric update ops -- confirm in LSTMOCR) and writes
    the merged summary to a TensorBoard event file under
    ``"<log_dir>/<split_name>"``.
    """

    def __init__(self):
        PrepareData.__init__(self)

    def parse_param(self, argv=None):
        """Parse options into ``self.split_name`` and ``self.checkpoint_path``.

        Args:
            argv: Argument list to parse, without the program name. ``None``
                (the default) makes argparse read ``sys.argv[1:]``, exactly as
                before; passing a list allows use from tests or other scripts.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('-s', '--split_name', help='which split of dataset to use', default="eval")
        parser.add_argument('-c', '--checkpoint_path', help='which checkpoint to use', default="./checkpoint/")
        args = parser.parse_args(argv)
        self.checkpoint_path = args.checkpoint_path
        self.split_name = args.split_name

    def _resolve_checkpoint_file(self):
        """Return the checkpoint prefix to restore, derived from ``self.checkpoint_path``.

        A directory is resolved to its newest checkpoint with
        ``tf.train.latest_checkpoint``; any other value is used verbatim.

        Raises:
            ValueError: if ``self.checkpoint_path`` is a directory containing no
                checkpoint (``latest_checkpoint`` returned ``None``). Without
                this check the ``None`` would only fail later inside
                ``saver.restore`` with a much less helpful message.
        """
        if not tf.gfile.IsDirectory(self.checkpoint_path):
            return self.checkpoint_path
        checkpoint_file = tf.train.latest_checkpoint(self.checkpoint_path)
        if checkpoint_file is None:
            raise ValueError('No checkpoint found in directory {}'.format(self.checkpoint_path))
        return checkpoint_file

    def eval_model(self):
        """Restore the selected checkpoint and evaluate it on ``self.split_name``.

        Requires ``parse_param`` to have set ``self.split_name`` and
        ``self.checkpoint_path``. Every batch of the split is fed through
        ``model.names_to_updates``; afterwards ``model.merged_summay`` is
        written at the restored ``model.global_step``.

        Raises:
            ValueError: if the checkpoint directory holds no checkpoint.
        """
        model = cnn_lstm_otc_ocr.LSTMOCR('eval')
        model.build_graph()

        val_feeder, num_samples = self.input_batch_generator(self.split_name, is_training=False, batch_size=FLAGS.batch_size)
        # float() keeps this a true division even under Python 2.
        num_batches_per_epoch = int(math.ceil(num_samples / float(FLAGS.batch_size)))

        # Resolve before opening a session so a missing checkpoint fails fast,
        # without creating an empty event file.
        checkpoint_file = self._resolve_checkpoint_file()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # The Saver below restores only global variables, so local
            # variables need their own initializer.
            sess.run(tf.local_variables_initializer())
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            eval_writer = tf.summary.FileWriter("{}/{}".format(log_dir, self.split_name), sess.graph)
            try:
                print('Evaluating checkpoint_path={}, split={}, num_samples={}'.format(checkpoint_file, self.split_name, num_samples))
                saver.restore(sess, checkpoint_file)

                for _batch_index in range(num_batches_per_epoch):
                    inputs, labels, _ = next(val_feeder)
                    feed = {model.inputs: inputs,
                            model.labels: labels}
                    sess.run(model.names_to_updates, feed)

                # `merged_summay` (sic) is the attribute name defined on the model.
                # NOTE(review): fetched without a feed_dict, so this assumes the
                # merged summaries do not depend on model.inputs/model.labels
                # (e.g. they only read metric values) -- confirm in LSTMOCR.
                summary_str, step = sess.run([model.merged_summay, model.global_step])
                eval_writer.add_summary(summary_str, step)
            finally:
                # Flush pending events to disk and release the writer even if
                # restoring or evaluating raised.
                eval_writer.close()

    def run(self):
        """Entry point: parse command-line options, then run the evaluation."""
        self.parse_param()
        self.eval_model()
if __name__ == "__main__":
    # Script entry: parse CLI options and evaluate the chosen checkpoint.
    evaluator = EvaluateModel()
    evaluator.run()