-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathutils.py
66 lines (57 loc) · 1.95 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# coding=utf-8
# Copyleft 2019 Project LXRT
import base64
import csv
import itertools
import sys
import time

import numpy as np
def _raise_csv_field_size_limit():
    """Raise the csv module's per-field size limit as high as the platform allows.

    The "boxes" and "features" columns hold base64-encoded arrays that may
    exceed the csv module's default per-field limit (128 KiB). ``sys.maxsize``
    is tried first. On platforms where it does not fit in a C ``long`` (e.g.
    64-bit Windows), ``csv.field_size_limit`` raises OverflowError, so the
    value is halved until it is accepted.

    :return: The limit that was finally installed.
    """
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
        except OverflowError:
            limit //= 2
        else:
            return limit


_raise_csv_field_size_limit()

# Column names, in order, of the tab-separated object-feature files read by
# load_obj_tsv.
FIELDNAMES = [
    "img_id",
    "img_h",
    "img_w",
    "objects_id",
    "objects_conf",
    "attrs_id",
    "attrs_conf",
    "num_boxes",
    "boxes",
    "features",
]
def _decode_item(item):
    """Convert one raw TSV row (a dict of strings) into typed values, in place.

    "img_h", "img_w" and "num_boxes" are parsed with ``int``. The array
    columns are base64-decoded, reinterpreted with the dtype listed below,
    reshaped using ``num_boxes``, and marked read-only.

    :param item: A row dict produced by ``csv.DictReader`` over FIELDNAMES.
    :return: The same dict, now holding ints and numpy arrays.
    """
    for key in ("img_h", "img_w", "num_boxes"):
        item[key] = int(item[key])
    num_boxes = item["num_boxes"]
    decode_config = (
        ("objects_id", (num_boxes,), np.int64),
        ("objects_conf", (num_boxes,), np.float32),
        ("attrs_id", (num_boxes,), np.int64),
        ("attrs_conf", (num_boxes,), np.float32),
        ("boxes", (num_boxes, 4), np.float32),
        # The per-box feature width is not stored in the row, so infer it.
        ("features", (num_boxes, -1), np.float32),
    )
    for key, shape, dtype in decode_config:
        array = np.frombuffer(base64.b64decode(item[key]), dtype=dtype)
        array = array.reshape(shape)
        # Explicitly read-only so callers cannot mutate decoded data in place.
        array.setflags(write=False)
        item[key] = array
    return item


def load_obj_tsv(fname, topk=None):
    """Load object features from a tsv file.

    :param fname: The path to the tsv file. Each line describes one image,
        with tab-separated columns in the order given by FIELDNAMES (no
        header row).
    :param topk: Only load features for the first K images (lines) in the
        tsv file. ``topk=0`` loads nothing. All features are loaded if topk
        is None or negative (e.g. -1).
    :return: A list of image object features where each feature is a dict.
        See FIELDNAMES above for the keys in the feature dict.
    """
    # islice accepts only None (no limit) or a non-negative int; a negative
    # topk (conventionally -1) means "no limit".
    limit = topk if topk is not None and topk >= 0 else None
    data = []
    start_time = time.time()
    print("Start to load Faster-RCNN detected objects from %s" % fname)
    # The csv module expects file objects to be opened with newline="".
    with open(fname, newline="") as f:
        reader = csv.DictReader(f, FIELDNAMES, delimiter="\t")
        # islice stops without pulling another (large) row once `limit` rows
        # have been read, and reads nothing at all when limit == 0.
        for item in itertools.islice(reader, limit):
            data.append(_decode_item(item))
    elapsed_time = time.time() - start_time
    print(
        "Loaded %d images in file %s in %d seconds." % (len(data), fname, elapsed_time)
    )
    return data