forked from PINTO0309/DeepLearningMugenKnock
-
Notifications
You must be signed in to change notification settings - Fork 0
/
semaseg_dataset_pytorch.py
96 lines (75 loc) · 2.52 KB
/
semaseg_dataset_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn.functional as F
import argparse
import cv2
import numpy as np
from glob import glob
# Number of classes listed in CLS (label 0 is used for unmatched/background pixels).
num_classes = 2
# Network input and label resolutions (alternative values previously noted:
# 572x572 input, 388x388 output).
img_height = img_width = 64
out_height = out_width = 64
GPU = False
# Fixed seed so torch-side randomness is reproducible across runs.
torch.manual_seed(0)
# Mask colour for each class, in the channel order cv2.imread returns (BGR);
# the i-th entry becomes label i + 1.
CLS = dict(
    akahara=[0, 0, 128],
    madara=[0, 128, 0],
)
def _read_image(img_path):
    """Load one image as a float32 RGB array of shape (img_height, img_width, 3) in [0, 1].

    Raises:
        FileNotFoundError: if cv2 cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:  # cv2.imread reports failure by returning None, not by raising
        raise FileNotFoundError(f"could not read image: {img_path}")
    img = cv2.resize(img, (img_width, img_height)).astype(np.float32)
    img /= 255.
    return img[..., ::-1]  # BGR (cv2 channel order) -> RGB


def _read_label(gt_path):
    """Load a colour-coded mask as an int64 class-index map of shape (out_height, out_width).

    Pixels whose colour equals the i-th value of CLS get label i + 1; all others get 0.

    Raises:
        FileNotFoundError: if cv2 cannot read ``gt_path``.
    """
    gt = cv2.imread(gt_path)
    if gt is None:
        raise FileNotFoundError(f"could not read label image: {gt_path}")
    # Nearest-neighbour keeps mask colours exact (no blended colours at region edges).
    gt = cv2.resize(gt, (out_width, out_height), interpolation=cv2.INTER_NEAREST)
    label = np.zeros((out_height, out_width), dtype=np.int64)
    for i, color in enumerate(CLS.values()):
        # CLS colours are in the same (BGR) channel order that cv2.imread returned.
        label[np.all(gt == color, axis=-1)] = i + 1
    return label


# get train data
def data_load(path, hf=False, vf=False):
    """Load a segmentation dataset laid out as ``<path>/<subdir>/<image file>``.

    Each image's mask is read from the image path with every occurrence of
    ``"images"`` replaced by ``"seg_images"`` and ``".jpg"`` replaced by ``".png"``.

    Args:
        path: dataset root; every entry directly under it is scanned for files.
        hf: also append a horizontally flipped copy of each sample.
        vf: also append a vertically flipped copy of each sample; when both
            ``hf`` and ``vf`` are set, a copy flipped on both axes is added as well.

    Returns:
        xs: float32 array of shape (N, 3, img_height, img_width), RGB in [0, 1].
        ts: int64 array of shape (N, out_height, out_width) of class indices
            (0 = unmatched pixel, i + 1 = i-th entry of CLS).
        paths: list of N source image paths aligned with ``xs``/``ts``
            (augmented copies repeat their source path).

    Raises:
        FileNotFoundError: if an image or its mask cannot be read.
    """
    xs = []
    ts = []
    paths = []
    # Sorted so the sample order does not depend on filesystem listing order.
    for dir_path in sorted(glob(path + '/*')):
        for img_path in sorted(glob(dir_path + '/*')):
            x = _read_image(img_path)
            gt_path = img_path.replace("images", "seg_images").replace(".jpg", ".png")
            t = _read_label(gt_path)

            samples = [(x, t)]
            if hf:
                samples.append((x[:, ::-1], t[:, ::-1]))
            if vf:
                samples.append((x[::-1], t[::-1]))
            if hf and vf:
                samples.append((x[::-1, ::-1], t[::-1, ::-1]))
            for sample_x, sample_t in samples:
                xs.append(sample_x)
                ts.append(sample_t)
                paths.append(img_path)

    if not xs:
        # np.array([]) has no channel axis, so the NCHW transpose below would fail.
        return (np.zeros((0, 3, img_height, img_width), dtype=np.float32),
                np.zeros((0, out_height, out_width), dtype=np.int64),
                paths)

    xs = np.array(xs).transpose(0, 3, 1, 2)  # NHWC -> NCHW
    ts = np.array(ts)
    return xs, ts, paths
def arg_parse(argv=None):
    """Parse the command-line flags that select the train and/or test stage.

    Args:
        argv: argument list to parse; ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous zero-argument behaviour.

    Returns:
        argparse.Namespace with boolean attributes ``train`` and ``test``
        (each ``False`` unless its flag is given).
    """
    parser = argparse.ArgumentParser(
        description='Semantic segmentation dataset loader implemented with PyTorch')
    parser.add_argument('--train', dest='train', action='store_true',
                        help='select the train stage')
    parser.add_argument('--test', dest='test', action='store_true',
                        help='select the test stage')
    return parser.parse_args(argv)
# main
if __name__ == '__main__':
    cli_args = arg_parse()
    # NOTE: train()/test() are not defined in this module, so the flags are only
    # parsed here; when neither flag is given, print usage hints.
    if not cli_args.train and not cli_args.test:
        usage_lines = (
            "please select train or test flag",
            "train: python main.py --train",
            "test: python main.py --test",
            "both: python main.py --train --test",
        )
        for line in usage_lines:
            print(line)