-
Notifications
You must be signed in to change notification settings - Fork 3
/
helper.py
82 lines (72 loc) · 2.68 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
from torch.utils import data
from torchvision import transforms
import numpy as np
import os
import cv2
class UnNormalize(object):
    """Callable that reverses a per-channel ``(x - mean) / std`` normalization."""

    def __init__(self, mean, std):
        # One entry per channel; paired positionally with the tensor's
        # leading dimension when the object is called.
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Undo the normalization of ``tensor`` and return it.

        Args:
            tensor (Tensor): Normalized image tensor, expected as (C, H, W).
        Returns:
            Tensor: The same tensor object, now holding ``x * std + mean``
            per channel. The input is modified in place.
        """
        per_channel = zip(tensor, self.mean, self.std)
        for channel, channel_mean, channel_std in per_channel:
            # Normalization is t.sub_(m).div_(s); apply the inverse steps in
            # reverse order using in-place ops on the channel view.
            channel.mul_(channel_std)
            channel.add_(channel_mean)
        return tensor
class semantic_dataset(data.Dataset):
    """Dataset of images and, for the training split, segmentation masks.

    Files are read from ``<root>/training/image_2/`` and
    ``<root>/training/semantic/`` when ``split == 'train'``, otherwise from
    ``<root>/testing/image_2/``. Images and masks are paired by position in
    the name-sorted directory listings, so corresponding files are assumed to
    share the same file name (NOTE(review): confirm against the data on disk).

    Raw mask label ids listed in ``valid_labels`` are remapped to contiguous
    class ids ``0..18``; ids listed in ``void_labels`` become ``ignore_index``.
    """

    def __init__(self, split='test', transform=None, root='',
                 img_size=(1242, 376)):
        """Collect the file lists for the requested split.

        Args:
            split: ``'train'`` loads images and masks; any other value loads
                the test images only.
            transform: Optional callable applied to the resized image only
                (never to the mask).
            root: Directory containing the ``training/`` and ``testing/``
                folders. The default ``''`` resolves them relative to the
                current working directory.
            img_size: ``(width, height)`` every image and mask is resized to.
        """
        self.void_labels = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        self.valid_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.ignore_index = 250
        self.class_map = dict(zip(self.valid_labels, range(19)))
        self.split = split
        self.transform = transform
        self.img_size = tuple(img_size)

        is_train = self.split == 'train'
        # os.path.join('', p) == p, so the default root keeps the historical
        # relative paths unchanged.
        self.img_path = os.path.join(
            root, 'training/image_2/' if is_train else 'testing/image_2/')
        self.mask_path = os.path.join(root, 'training/semantic/') if is_train else None

        self.img_list = self.get_filenames(self.img_path)
        self.mask_list = self.get_filenames(self.mask_path) if is_train else None

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(img, mask)`` for the training split, otherwise ``img`` alone.
            ``img`` is the output of ``transform`` when one is set, else the
            resized array as returned by OpenCV (H, W, 3; OpenCV loads colour
            images in BGR order and no conversion is done here). ``mask`` is a
            2-D array of encoded class ids.

        Raises:
            IOError: If the image or mask file cannot be decoded.
        """
        width, height = self.img_size

        img_file = self.img_list[idx]
        img = cv2.imread(img_file)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError('could not read image: {}'.format(img_file))
        img = cv2.resize(img, (width, height))

        mask = None
        if self.split == 'train':
            mask_file = self.mask_list[idx]
            mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise IOError('could not read mask: {}'.format(mask_file))
            # Nearest-neighbour keeps every pixel a real label id; the default
            # bilinear filter would blend neighbouring ids into bogus classes.
            mask = cv2.resize(mask, (width, height),
                              interpolation=cv2.INTER_NEAREST)
            mask = self.encode_segmap(mask)
            assert mask.shape == (height, width)

        if self.transform:
            img = self.transform(img)
            assert img.shape == (3, height, width)
        else:
            assert img.shape == (height, width, 3)

        if self.split == 'train':
            return img, mask
        return img

    def encode_segmap(self, mask):
        """Remap raw label ids in ``mask`` to training ids, in place.

        Void labels become ``ignore_index`` and each valid label becomes its
        class id from ``class_map``. Values in neither list are left as they
        are. The (modified) input array is returned.
        """
        # Compare against a snapshot so an id written by one step can never be
        # picked up and remapped again by a later step.
        original = mask.copy()
        for voidc in self.void_labels:
            mask[original == voidc] = self.ignore_index
        for validc in self.valid_labels:
            mask[original == validc] = self.class_map[validc]
        return mask

    def get_filenames(self, path):
        """Return the entries of ``path`` joined with ``path``, sorted by name.

        ``os.listdir`` yields entries in arbitrary order; sorting makes the
        listing deterministic so image and mask lists line up index by index.
        """
        return [os.path.join(path, filename)
                for filename in sorted(os.listdir(path))]