## @package db_file_reader
# Module caffe2.python.db_file_reader
from caffe2.python import core, scope, workspace, _import_c_extension as C
from caffe2.python.dataio import Reader
from caffe2.python.dataset import Dataset
from caffe2.python.schema import from_column_list
import os


class DBFileReader(Reader):
    """Reader that reads from a DB file.

    Example usage:
    db_file_reader = DBFileReader(db_path='/tmp/cache.db', db_type='LevelDB')

    Args:
        db_path: str. Path to the DB file.
        db_type: str. DB type of the file. A db_type is registered by
            `REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
        name: str or None. Name of the DBFileReader.
            Optional name to prepend to blobs that will store the data.
            Defaults to '<db_name>_<default_name_suffix>'.
        batch_size: int.
            How many examples are read each time the read_net is run.
        loop_over: bool.
            If True, goes through the examples in random order endlessly.
        field_names: List[str]. If the schema's field_names() are not in
            alphabetical order, they must be specified.
            Otherwise, the schema is restored automatically with
            schema.field_names() sorted in alphabetical order.
    """

    default_name_suffix = 'db_file_reader'
def __init__(
self,
db_path,
db_type,
name=None,
batch_size=100,
loop_over=False,
field_names=None,
):
assert db_path is not None, "db_path can't be None."
assert db_type in C.registered_dbs(), \
"db_type [{db_type}] is not available. \n" \
"Choose one of these: {registered_dbs}.".format(
db_type=db_type,
registered_dbs=C.registered_dbs(),
)
self.db_path = os.path.expanduser(db_path)
self.db_type = db_type
self.name = name or '{db_name}_{default_name_suffix}'.format(
db_name=self._extract_db_name_from_db_path(),
default_name_suffix=self.default_name_suffix,
)
self.batch_size = batch_size
self.loop_over = loop_over
        # self.db_path and self.db_type must be set before
        # self._init_reader_schema(...) is called.
        super(DBFileReader, self).__init__(self._init_reader_schema(field_names))
self.ds = Dataset(self._schema, self.name + '_dataset')
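        # The underlying _DatasetReader is created lazily in setup_ex(...).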
self.ds_reader = None
    def _init_name(self, name):
        return name or self._extract_db_name_from_db_path() + '_db_file_reader'
    def _init_reader_schema(self, field_names=None):
        """Restore the reader schema from the DB file.

        If `field_names` is given, restore the schema according to it.
        Otherwise, load the blobs from the DB file into the workspace,
        and restore the schema from these blob names.

        It is also assumed that:
        1). Each field of the schema has a corresponding blob
            stored in the DB file.
        2). Each blob loaded from the DB file corresponds to
            a field of the schema.
        3). field_names in the original schema are in alphabetical order,
            since blob names loaded into the workspace from the DB file
            will be in alphabetical order.

        Load a set of blobs from the DB file. From the names of these blobs,
        restore the DB file schema using `from_column_list(...)`.

        Returns:
            schema: schema.Struct. Used in Reader.__init__(...).
        """
if field_names:
return from_column_list(field_names)
if self.db_type == "log_file_db":
assert os.path.exists(self.db_path), \
'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
with core.NameScope(self.name):
# blob_prefix is for avoiding name conflict in workspace
blob_prefix = scope.CurrentNameScope()
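            # Load every blob from the DB into the workspace, prefixed with
            # the current name scope (e.g. a stored blob `label` would appear
            # as `<self.name>/label`; the blob name is illustrative).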
workspace.RunOperatorOnce(
core.CreateOperator(
'Load',
[],
[],
absolute_path=True,
db=self.db_path,
db_type=self.db_type,
load_all=True,
add_prefix=blob_prefix,
)
)
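            # Strip the name-scope prefix from each loaded blob name to
            # recover the original column names; sorted(...) keeps them in
            # alphabetical order, matching assumption 3) above.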
col_names = [
blob_name[len(blob_prefix):] for blob_name in sorted(workspace.Blobs())
if blob_name.startswith(blob_prefix)
]
schema = from_column_list(col_names)
return schema
    def setup_ex(self, init_net, finish_net):
        """From the Dataset, create a _DatasetReader and set up the init_net.

        Make sure _init_field_blobs_as_empty(...) is only called once:
        the underlying NewRecord(...) creates blobs by calling
        NextScopedBlob(...), so references to previously-initialized
        empty blobs would be lost, making them inaccessible.
        """
if self.ds_reader:
self.ds_reader.setup_ex(init_net, finish_net)
else:
self._init_field_blobs_as_empty(init_net)
self._feed_field_blobs_from_db_file(init_net)
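            # Build a random-order reader over the dataset: sort_and_shuffle(...)
            # prepares a shuffled pass over the examples and computeoffset(...)
            # sets up the read cursor, both on the init_net.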
self.ds_reader = self.ds.random_reader(
init_net,
batch_size=self.batch_size,
loop_over=self.loop_over,
)
self.ds_reader.sort_and_shuffle(init_net)
self.ds_reader.computeoffset(init_net)
def read(self, read_net):
assert self.ds_reader, 'setup_ex must be called first'
return self.ds_reader.read(read_net)
def _init_field_blobs_as_empty(self, init_net):
"""Initialize dataset field blobs by creating an empty record"""
with core.NameScope(self.name):
self.ds.init_empty(init_net)
def _feed_field_blobs_from_db_file(self, net):
"""Load from the DB file at db_path and feed dataset field blobs"""
if self.db_type == "log_file_db":
assert os.path.exists(self.db_path), \
'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
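        # Map each blob stored in the DB file (source_blob_names) onto the
        # corresponding dataset field blob created by _init_field_blobs_as_empty.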
net.Load(
[],
self.ds.get_blobs(),
db=self.db_path,
db_type=self.db_type,
absolute_path=True,
source_blob_names=self.ds.field_names(),
)
def _extract_db_name_from_db_path(self):
"""Extract DB name from DB path
E.g. given self.db_path=`/tmp/sample.db`, or
self.db_path = `dper_test_data/cached_reader/sample.db`
it returns `sample`.
Returns:
db_name: str.
"""
return os.path.basename(self.db_path).rsplit('.', 1)[0]