forked from PINTO0309/DeepLearningMugenKnock
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bin_dataset_tensorflow_slim.py
103 lines (77 loc) · 2.72 KB
/
bin_dataset_tensorflow_slim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.contrib import slim
import argparse
import cv2
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
# Number of label classes; matches the two entries of CLS below.
num_classes = 2

# Size that input images are resized to by data_load.
# (The original author noted 572, 572 as an alternative.)
img_height = 64
img_width = 64

# Size that ground-truth segmentation images are resized to by data_load.
# (The original author noted 388, 388 as an alternative.)
out_height = 64
out_width = 64

# Class name -> colour of that class in the ground-truth images.  The three
# values are compared against channels 0, 1 and 2 of the image exactly as
# cv2.imread returns it.
CLS = dict(
    akahara=[0, 0, 128],
    madara=[0, 128, 0],
)
# get train data
def data_load(path, hf=False, vf=False):
    """Load images and their binary segmentation masks.

    Every file matched by ``<path>/*/*`` is read with OpenCV, resized to
    ``(img_width, img_height)``, divided by 255 and has its channel axis
    reversed (OpenCV's BGR order becomes RGB).  The ground-truth image is
    located by replacing ``"images"`` with ``"seg_images"`` and ``".jpg"``
    with ``".png"`` in the image path, and is resized with nearest-neighbour
    interpolation so label colours are not blended.  A mask pixel is 1 when
    the ground-truth colour equals any entry of ``CLS`` and 0 otherwise.

    Args:
        path: Root directory; its sub-directories hold the image files.
        hf: If true, also append a horizontally flipped copy of each sample.
        vf: If true, also append a vertically flipped copy of each sample.
            When both ``hf`` and ``vf`` are true, a copy flipped along both
            axes is appended as well.

    Returns:
        A tuple ``(xs, ts, paths)``: ``xs`` is the array of float32 images,
        ``ts`` the array of float64 masks with a trailing axis of length 1,
        and ``paths`` the list of source image paths, one entry per sample
        (so a path repeats for every flipped copy).  When no files are found
        both arrays are empty and ``paths`` is ``[]``.

    Raises:
        OSError: If an image or its ground-truth file cannot be read.
    """
    xs = []
    ts = []
    paths = []

    for dir_path in glob(path + '/*'):
        # A distinct name keeps the ``path`` argument from being shadowed.
        for img_path in glob(dir_path + '/*'):
            x = cv2.imread(img_path)
            if x is None:
                # cv2.imread signals failure by returning None, not raising.
                raise OSError("could not read image: {}".format(img_path))
            x = cv2.resize(x, (img_width, img_height)).astype(np.float32)
            x /= 255.
            x = x[..., ::-1]

            gt_path = img_path.replace("images", "seg_images").replace(".jpg", ".png")
            gt = cv2.imread(gt_path)
            if gt is None:
                raise OSError("could not read ground truth: {}".format(gt_path))
            gt = cv2.resize(gt, (out_width, out_height), interpolation=cv2.INTER_NEAREST)

            # np.float64 is the dtype the removed ``np.float`` alias stood for.
            t = np.zeros((out_height, out_width, 1), dtype=np.float64)

            # Mark every pixel whose colour matches one of the class colours.
            for color in CLS.values():
                ind = (gt[..., 0] == color[0]) & (gt[..., 1] == color[1]) & (gt[..., 2] == color[2])
                t[ind] = 1

            xs.append(x)
            ts.append(t)
            paths.append(img_path)

            if hf:
                xs.append(x[:, ::-1])
                ts.append(t[:, ::-1])
                paths.append(img_path)

            if vf:
                xs.append(x[::-1])
                ts.append(t[::-1])
                paths.append(img_path)

            if hf and vf:
                xs.append(x[::-1, ::-1])
                ts.append(t[::-1, ::-1])
                paths.append(img_path)

    xs = np.array(xs)
    ts = np.array(ts)

    return xs, ts, paths
def arg_parse(argv=None):
    """Parse the command-line flags of this script.

    Args:
        argv: Optional list of argument strings to parse.  When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour of calling this function without arguments.

    Returns:
        The ``argparse.Namespace`` with two boolean attributes, ``train``
        and ``test``; each is ``True`` only when its flag was given.
    """
    # The file imports TensorFlow and slim, not Keras, so say so in --help.
    parser = argparse.ArgumentParser(description='CNN implemented with TensorFlow-Slim')
    parser.add_argument('--train', dest='train', action='store_true')
    parser.add_argument('--test', dest='test', action='store_true')
    return parser.parse_args(argv)
# main
if __name__ == '__main__':
    args = arg_parse()

    # The train()/test() dispatch is disabled in this script:
    # if args.train:
    #     train()
    # if args.test:
    #     test()

    # With neither flag given, print how the script is meant to be invoked.
    if not args.train and not args.test:
        usage_lines = (
            "please select train or test flag",
            "train: python main.py --train",
            "test:  python main.py --test",
            "both:  python main.py --train --test",
        )
        for usage_line in usage_lines:
            print(usage_line)