forked from rykov8/ssd_keras
-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy path: data_td500.py
79 lines (63 loc) · 2.55 KB
/
data_td500.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy as np
import os
import glob
from thirdparty.get_image_size import get_image_size
from ssd_data import BaseGTUtility
def rot_matrix(theta):
    """Build the standard 2x2 rotation matrix R(theta).

    R(theta) = [[cos(theta), -sin(theta)],
                [sin(theta),  cos(theta)]]

    # Arguments
        theta: Rotation angle, passed straight to ``np.cos``/``np.sin``
            (i.e. interpreted as radians).

    # Returns
        A (2, 2) numpy array.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    first_row = [cos_t, -sin_t]
    second_row = [sin_t, cos_t]
    return np.array([first_row, second_row])
class GTUtility(BaseGTUtility):
    """Utility for MSRA-TD500 dataset.

    Every ``*.JPG`` image in the ``train`` (or ``test``) subdirectory of
    ``data_path`` is paired with the ``.gt`` file of the same base name in
    that directory. Each annotated rotated rectangle becomes a 4-point
    polygon.

    # Arguments
        data_path: Path to ground truth and image data.
        test: Boolean for using training or test set.
        skip_difficult: If True, boxes whose difficulty flag is 1 are
            dropped. Defaults to False, which keeps every box (the
            previous behavior).

    # Attributes filled here
        image_names: File names of the images that have at least one box;
            images without boxes are skipped entirely.
        data: One float array per image, shape (num_boxes, 9): polygon
            corners x1, y1, x2, y2, x3, y3, x4, y4 divided by the image
            width (x) and height (y), followed by the class index 1 ('Text').
        text: One list per image holding an empty transcription per box.

    The base class ``init()`` is called once all images are loaded.
    """

    def __init__(self, data_path, test=False, skip_difficult=False):
        self.data_path = data_path
        gt_path = os.path.join(data_path, 'test' if test else 'train')
        self.gt_path = gt_path
        # images and ground-truth files live in the same directory
        self.image_path = image_path = gt_path
        self.classes = ['Background', 'Text']

        self.image_names = []
        self.data = []
        self.text = []
        for image_file_name in sorted(glob.glob(image_path + '/*.JPG')):
            image_name = os.path.basename(image_file_name)
            img_width, img_height = get_image_size(image_file_name)
            gt_file_name = os.path.splitext(image_name)[0] + '.gt'
            boxes = self._read_gt_file(os.path.join(gt_path, gt_file_name),
                                       skip_difficult)
            # only images with boxes
            if not boxes:
                continue
            boxes = np.asarray(boxes)
            # normalize x (even columns) and y (odd columns) of the 4 corners
            boxes[:, 0:8:2] /= img_width
            boxes[:, 1:8:2] /= img_height

            self.image_names.append(image_name)
            self.data.append(boxes)
            self.text.append([''] * len(boxes))

        self.init()

    @staticmethod
    def _read_gt_file(file_name, skip_difficult=False):
        """Parse one MSRA-TD500 ``.gt`` file into pixel-space polygons.

        Each non-empty line is ``index difficult x y w h theta``: (x, y) is
        the top-left corner of the unrotated w x h rectangle and theta is the
        rotation about the rectangle center (fed to ``np.cos``/``np.sin``).
        Blank lines are ignored.

        # Arguments
            file_name: Path of the ``.gt`` file.
            skip_difficult: Drop lines whose difficulty flag equals 1.

        # Returns
            List of boxes ``[x1, y1, x2, y2, x3, y3, x4, y4, 1]`` in pixels.
        """
        boxes = []
        with open(file_name, 'r') as f:
            for line in f:
                # split() on any whitespace tolerates repeated spaces / tabs
                line_split = line.split()
                if not line_split:
                    # e.g. a trailing empty line; int('') would raise here
                    continue
                # line_split = [index, difficult, x, y, w, h, theta]
                if skip_difficult and int(line_split[1]) == 1:
                    continue
                x, y, w, h, theta = [float(v) for v in line_split[2:]]
                # corners relative to the rectangle center; assuming the
                # image y axis points down, this order is bottom-left,
                # bottom-right, top-right, top-left.
                # NOTE(review): confirm downstream code expects this order.
                corners = np.array([[-w, h], [w, h], [w, -h], [-w, -h]]) / 2.
                # row vectors times R(-theta) == R(theta) applied to columns
                corners = np.dot(corners, rot_matrix(-theta))
                # shift from the center back to image coordinates
                corners += [x + w / 2., y + h / 2.]
                boxes.append(list(corners.flatten()) + [1])
        return boxes
if __name__ == '__main__':
    # Smoke run: load the MSRA-TD500 test split and dump the parsed boxes.
    td500_util = GTUtility('data/MSRA-TD500/', test=True)
    print(td500_util.data)