-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcifar10_train.py
57 lines (41 loc) · 1.74 KB
/
cifar10_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Train CIFAR-10 using a single CPU."""
import numpy as np
import matplotlib.pyplot as plt
import cifar10
import cifar10_eval
# Optimisation hyperparameters consumed by train().
learning_rate: float = 1e-2  # gradient-descent step size; also shown in the cost-plot title
num_iterations: int = 1  # iteration count forwarded to cifar10.inference()
def _count_correct(predictions, minibatches):
    """Count predictions that match the labels of each mini-batch.

    Args:
        predictions: per-mini-batch predictions as returned by
            ``cifar10_eval.predict``; ``predictions[i][j]`` is compared with
            the j-th label of mini-batch ``i``.
        minibatches: sequence of mini-batches whose labels live at
            ``minibatch[1][0]`` (presumably a ``(1, m)`` label array -- TODO
            confirm against cifar10.preprocessing_inputs).

    Returns:
        ``(correct, total)`` where ``total`` is the number of labels actually
        present, so a final mini-batch shorter than the others is counted
        correctly instead of being assumed to hold a fixed number of examples.
    """
    correct = 0
    total = 0
    for batch_predictions, minibatch in zip(predictions, minibatches):
        labels = minibatch[1][0]
        total += len(labels)
        correct += sum(
            1
            for predicted, actual in zip(batch_predictions, labels)
            if predicted == actual
        )
    return correct, total


def _accuracy_percent(predictions, minibatches):
    """Return the percentage of correctly predicted examples (0.0 if there are none)."""
    correct, total = _count_correct(predictions, minibatches)
    return 100.0 * correct / total if total else 0.0


def _plot_costs(costs):
    """Plot the cost values returned by training and display the figure."""
    costs = np.squeeze(costs)
    plt.plot(costs)
    plt.ylabel('cost')
    # NOTE(review): the label assumes costs were recorded every 100
    # iterations inside cifar10.inference -- confirm.
    plt.xlabel('iterations (per hundreds)')
    plt.title("Learning rate =" + str(learning_rate))
    plt.show()


def train():
    """Train CIFAR-10, report train/test accuracy and plot the cost curve.

    Trains with ``cifar10.inference`` using the module-level ``learning_rate``
    and ``num_iterations``, predicts the training and test mini-batches with
    ``cifar10_eval.predict``, prints both accuracies as percentages of the
    labelled examples and shows the recorded cost values.
    """
    # Get training set
    mini_batches = cifar10.preprocessing_inputs()
    # Training
    parameters, costs, bn_param = cifar10.inference(
        mini_batches,
        learning_rate=learning_rate,
        num_iterations=num_iterations,
    )
    # Get training and testing images and labels.
    # NOTE(review): the training batches are loaded a second time instead of
    # reusing ``mini_batches``; if inference() does not mutate its input,
    # reusing them would avoid the extra load.
    minibatch_test = cifar10.preprocessing_inputs_test()
    minibatch_train = cifar10.preprocessing_inputs()
    # Predict test/train set examples
    Y_prediction_test = cifar10_eval.predict(parameters, minibatch_test, bn_param)
    Y_prediction_train = cifar10_eval.predict(parameters, minibatch_train, bn_param)
    # Divide by the real number of labelled examples: the train and test sets
    # differ in size, and the last mini-batch may be shorter than the others.
    train_accuracy = _accuracy_percent(Y_prediction_train, minibatch_train)
    test_accuracy = _accuracy_percent(Y_prediction_test, minibatch_test)
    print("train accuracy: {} %".format(train_accuracy))
    print("test accuracy: {} %".format(test_accuracy))
    # Plot cost function iteration image
    _plot_costs(costs)
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    train()