-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcnn2.py
85 lines (73 loc) · 2.37 KB
/
cnn2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
import cv2
import cnn_model
# 入力と出力を指定
im_rows = 32 # 画像の縦ピクセルサイズ
im_cols = 32 # 画像の横ピクセルサイズ
im_color = 3 # 画像の色空間
in_shape = (im_rows, im_cols, im_color)
nb_classes = 8
# 写真データを読み込み
photos = np.load('image/photos.npz')
x = photos['x']
y = photos['y']
# 読み込んだデータをの三次元配列に変換
x = x.reshape(-1, im_rows, im_cols, im_color)
x = x.astype('float32') / 255
# ラベルデータをone-hotベクトルに直す
y = tf.keras.utils.to_categorical(y.astype('int32'), nb_classes)
# 学習用とテスト用に分ける
x_train, x_test, y_train, y_test = train_test_split(
x, y, train_size=0.8)
# 学習用データを水増しする --- (*1)
x_new = []
y_new = []
for i, xi in enumerate(x_train):
yi = y_train[i]
for ang in range(-30, 30, 5):
# 回転させる --- (*2)
center = (16, 16) # 回転の中心点
mtx = cv2.getRotationMatrix2D(center, ang, 1.0)
xi2 = cv2.warpAffine(xi, mtx, (32, 32))
x_new.append(xi2)
y_new.append(yi)
# さらに左右反転させる --- (*3)
xi3 = cv2.flip(xi2, 1)
x_new.append(xi3)
y_new.append(yi)
# 水増しした画像を学習用に置き換える
print('水増し前=', len(y_train))
x_train = np.array(x_new)
y_train = np.array(y_new)
print('水増し後=', len(y_train))
# CNNモデルを取得 --- (*6)
model = cnn_model.get_model(in_shape, nb_classes)
# 学習を実行 --- (*8)
hist = model.fit(x_train, y_train,
batch_size=64,
epochs=20,
verbose=1,
validation_data=(x_test, y_test))
# モデルを評価 --- (*9)
score = model.evaluate(x_test, y_test, verbose=1)
print('正解率=', score[1], 'loss=', score[0])
# 学習の様子をグラフへ描画 --- (*10)
# 正解率の推移をプロット
"""
plt.plot(hist.history['accuracy'])
plt.plot(hist.history['val_accuracy'])
plt.title('Accuracy')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# ロスの推移をプロット
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.title('Loss')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
"""
model.save_weights('./image/photos-model.hdf5')