-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
245 lines (200 loc) · 9.05 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
"""
This is sample dataloader script for robust multimodal fusion GAN
This dataloader script is for nyu_v2 dataset where
The sparse depth and ground truth depths are stored as h5 file, and
The rgb image is stored as a png
"""
import glob
import random
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from utils import *
import cv2
import h5py
from PIL import Image
import torchvision.transforms as transforms
def read_gt_depth(path):
    """Read the ground-truth depth map stored in an HDF5 file.

    Args:
        path (str): Path to an ``.h5`` file containing a ``'depth_gt'`` dataset.

    Returns:
        numpy.ndarray: The whole ``'depth_gt'`` dataset loaded into memory.
    """
    # Close the HDF5 handle deterministically (the original leaked it).
    # np.array() copies the dataset into memory, so the result stays valid
    # after the file is closed.
    with h5py.File(path, "r") as h5_file:
        gt_depth = np.array(h5_file['depth_gt'])
    return gt_depth
def read_sparse_depth(path):
    """Read the sparse input depth map stored in an HDF5 file.

    Args:
        path (str): Path to an ``.h5`` file containing a ``'lidar'`` dataset.

    Returns:
        numpy.ndarray: The whole ``'lidar'`` dataset loaded into memory.
    """
    # Close the HDF5 handle deterministically (the original leaked it).
    # np.array() copies the dataset into memory, so the result stays valid
    # after the file is closed.
    with h5py.File(path, "r") as h5_file:
        sparse_depth = np.array(h5_file['lidar'])
    return sparse_depth
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the files of interest.

    Files whose name starts with '.' are skipped. When ``recursive`` is True,
    sub-directories are descended into; other non-file entries (broken
    symlinks, sockets, ...) are ignored.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix(es) to keep.
            Default: None (keep every file).
        recursive (bool, optional): If True, recursively scan
            sub-directories. Default: False.
        full_path (bool, optional): If True, yield paths that include
            ``dir_path``; otherwise yield paths relative to ``dir_path``.
            Default: False.

    Returns:
        A generator over the matching file paths.

    Raises:
        TypeError: If ``suffix`` is not None, a str, or a tuple. Raised
            immediately, not on first iteration.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # The context manager releases the OS directory handle even if the
        # caller abandons the generator early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if entry.is_file():
                    if entry.name.startswith('.'):
                        continue  # skip hidden files
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = os.path.relpath(entry.path, root)
                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only recurse into real directories: the original also
                    # recursed into hidden files, raising NotADirectoryError.
                    yield from _scandir(
                        entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
def paired_paths_from_meta_info_file(folders, keys, meta_info_file,
                                     filename_tmpl):
    """Generate paired paths from a meta information file.

    Each non-blank line of the meta information file starts with a sample
    name; anything after the first run of whitespace (usually the gt image
    shape) is ignored. The name is formatted with ``filename_tmpl`` and then
    ``.h5`` (input and gt depth) or ``.png`` (rgb) is appended, so names must
    be listed WITHOUT a file extension.
    Example of a meta information file:
    ```
    0001 (228,304,1)
    0002 (228,304,1)
    ```

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder, rgb_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt', 'rgb'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension.

    Returns:
        list[dict[str, str]]: One dict per sample, mapping
        ``f'{key}_path'`` to the corresponding file path.
    """
    assert len(folders) == 3, (
        'The len of folders should be 3 with [input_folder, gt_folder, rgb_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 3, (
        'The len of keys should be 3 with [input_key, gt_key, rgb_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder, rgb_folder = folders
    input_key, gt_key, rgb_key = keys

    with open(meta_info_file, 'r') as fin:
        # split() with no argument also drops the trailing newline (the
        # original split(' ') kept it on name-only lines) and tolerates tabs;
        # blank lines are skipped instead of yielding empty names.
        gt_names = [line.split()[0] for line in fin if line.strip()]

    rgb_ext = '.png'
    depth_ext = '.h5'
    paths = []
    for basename in gt_names:
        stem = filename_tmpl.format(basename)
        paths.append({
            f'{input_key}_path': os.path.join(input_folder, f'{stem}{depth_ext}'),
            f'{gt_key}_path': os.path.join(gt_folder, f'{stem}{depth_ext}'),
            f'{rgb_key}_path': os.path.join(rgb_folder, f'{stem}{rgb_ext}'),
        })
    return paths
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths by scanning the input, gt and rgb folders.

    The gt folder drives the pairing: for every file ``<name><ext>`` found in
    it, the input path is ``<input_folder>/<tmpl(name)><ext>`` and the rgb
    path is ``<rgb_folder>/<tmpl(name)>.png``.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder, rgb_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt', 'rgb'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension.

    Returns:
        list[dict[str, str]]: One dict per gt file, mapping
        ``f'{key}_path'`` to the corresponding file path, sorted by gt path.
    """
    assert len(folders) == 3, (
        'The len of folders should be 3 with [input_folder, gt_folder, rgb_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 3, (
        'The len of keys should be 3 with [input_key, gt_key, rgb_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder, rgb_folder = folders
    input_key, gt_key, rgb_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    rgb_paths = list(scandir(rgb_folder))
    assert len(input_paths) == len(gt_paths), (
        f'{input_key} and {gt_key} datasets have different number of images: '
        f'{len(input_paths)}, {len(gt_paths)}.')
    assert len(input_paths) == len(rgb_paths), (
        f'{input_key} and {rgb_key} datasets have different number of images: '
        f'{len(input_paths)}, {len(rgb_paths)}.')

    rgb_ext = '.png'
    paths = []
    # os.scandir yields entries in arbitrary order; sorting makes the
    # index -> sample mapping reproducible across machines and runs.
    for gt_name in sorted(gt_paths):
        basename, ext = os.path.splitext(os.path.basename(gt_name))
        stem = filename_tmpl.format(basename)
        paths.append({
            f'{input_key}_path': os.path.join(input_folder, f'{stem}{ext}'),
            f'{gt_key}_path': os.path.join(gt_folder, gt_name),
            # The original joined an unassigned `rgb_path` here
            # (UnboundLocalError); the rgb file name derived from the stem is
            # what was intended.
            f'{rgb_key}_path': os.path.join(rgb_folder, f'{stem}{rgb_ext}'),
        })
    return paths
class PairedImageDataset(Dataset):
    """Paired (sparse depth, ground-truth depth, RGB) dataset.

    Expected layout under ``root``:
        root/sparse_depth/  sparse input depth, HDF5 (read via 'lidar' key)
        root/depth_gt/      ground-truth depth, HDF5 (read via 'depth_gt' key)
        root/image_rgb/     RGB images, PNG

    Each item is a dict with keys 'sparse', 'gt' and 'rgb'. All three are
    zero-padded by 6 rows at the top and bottom, then normalized with the
    ``*_mean`` / ``*_std`` constants (presumably provided by ``utils`` via
    the star import — TODO confirm).
    """

    def __init__(self, root, opt, hr_shape):
        """
        Args:
            root (str): Dataset root directory.
            opt: Options object. ``opt.meta_info_file`` (a path relative to
                ``root``, or None) selects meta-file pairing instead of
                scanning the folders.
            hr_shape: Currently unused; kept for interface compatibility.
        """
        super().__init__()
        self.opt = opt
        # Sparse depth is in "sparse_depth", ground-truth depth in "depth_gt"
        # and the rgb image in "image_rgb".
        self.gt_folder = os.path.join(root, 'depth_gt')
        self.lq_folder = os.path.join(root, 'sparse_depth')
        self.rgb_folder = os.path.join(root, 'image_rgb')
        self.filename_tmpl = '{}'
        # Only the RGB image goes through torchvision transforms; the depth
        # maps are padded and normalized by hand in __getitem__ (the original
        # author notes ToTensor() assumes a 3-channel uint8 RGB image).
        # Pad((0, 6, 0, 6)) is (left, top, right, bottom): 6 rows at the top
        # and bottom, matching the F.pad applied to the depth maps.
        self.transform_rgb = transforms.Compose([
            transforms.Pad((0, 6, 0, 6), fill=0),
            transforms.ToTensor(),
            transforms.Normalize(mean=rgb_mean, std=rgb_std),
        ])
        folders = [self.lq_folder, self.gt_folder, self.rgb_folder]
        keys = ['lq', 'gt', 'rgb']
        if self.opt.meta_info_file is not None:
            self.meta_file = os.path.join(root, self.opt.meta_info_file)
            self.paths = paired_paths_from_meta_info_file(
                folders, keys, self.meta_file, self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder(
                folders, keys, self.filename_tmpl)

    def __getitem__(self, index):
        """Load, pad and normalize one (sparse, gt, rgb) sample."""
        sample_paths = self.paths[index]

        # Depth maps: presumably (H, W) arrays — the original comment says
        # grayscale, range [0, 9.999], float32; TODO confirm.
        # F.pad's (0, 0, 6, 6) leaves the last dim untouched and adds 6 zero
        # rows before and after the second-to-last dim.
        img_hi = F.pad(torch.from_numpy(read_gt_depth(sample_paths['gt_path'])),
                       (0, 0, 6, 6), 'constant', 0)
        img_lo = F.pad(torch.from_numpy(read_sparse_depth(sample_paths['lq_path'])),
                       (0, 0, 6, 6), 'constant', 0)

        # Depth normalization (applied after padding, so padded zeros become
        # -mean/std).
        gt = (img_hi - depth_mean) / depth_std
        sparse = (img_lo - sparse_mean) / sparse_std

        # Open the PNG in a context manager so its file handle is released
        # once the transform has produced the tensor (the original leaked it).
        with Image.open(sample_paths['rgb_path']) as img_color:
            img_rgb = self.transform_rgb(img_color)

        return {
            'sparse': sparse,
            'gt': gt,
            'rgb': img_rgb
        }

    def __len__(self):
        """Number of paired samples."""
        return len(self.paths)