-
Notifications
You must be signed in to change notification settings - Fork 47
/
utils.py
92 lines (75 loc) · 2.57 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import math
import random
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
class AverageMeter(object):
    """Track the latest value, weighted running mean and EMA of a metric.

    All statistics stay ``None`` until the first call to :meth:`update`.
    Values only need to support ``*``, ``+``, ``+=`` and ``/`` with numbers,
    so plain floats work; array-like values presumably work as well, but
    that depends on the caller's types.

    Args:
        ema_alpha: Weight given to the newest value in the exponential
            moving average; the previous average keeps ``1 - ema_alpha``.
            The default of 0.01 reproduces the historical fixed
            ``0.99 / 0.01`` split.

    Raises:
        ValueError: If ``ema_alpha`` is outside ``[0, 1]``.
    """

    def __init__(self, ema_alpha=0.01):
        if not 0.0 <= ema_alpha <= 1.0:
            raise ValueError(
                "ema_alpha must be within [0, 1], got {!r}".format(ema_alpha)
            )
        self.ema_alpha = ema_alpha
        # Computed once so every update uses the same pair of weights.
        self._ema_keep = 1.0 - ema_alpha
        self.val = None  # most recent value passed to update()
        self.sum = None  # running sum of val * n
        self.cnt = None  # running sum of n
        self.avg = None  # sum / cnt
        self.ema = None  # exponential moving average of val
        self.initialized = False

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``; the first call seeds all stats."""
        if self.initialized:
            self.add(val, n)
        else:
            self.initialize(val, n)

    def initialize(self, val, n):
        """Seed every statistic from the first observation."""
        self.val = val
        self.sum = val * n
        self.cnt = n
        self.avg = val
        self.ema = val
        self.initialized = True

    def add(self, val, n):
        """Fold a further observation into the running statistics."""
        self.val = val
        # In-place accumulation is kept on purpose: ``sum`` was created by
        # ``val * n`` in initialize(), so it never aliases a caller's object.
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt
        self.ema = self.ema * self._ema_keep + self.val * self.ema_alpha
def inter_and_union(pred, mask, num_class):
    """Return per-class intersection and union pixel counts.

    Both inputs are converted to ``uint8`` copies, so the caller's data is
    never modified. Labels are shifted up by one, which wraps 255 around to
    0; positions where the shifted ``mask`` is 0 are then left out of the
    prediction and intersection counts.

    Args:
        pred: Array-like of predicted class indices.
        mask: Array-like of reference class indices, same shape as ``pred``.
        num_class: Number of classes; also the number of histogram bins.

    Returns:
        Tuple ``(area_inter, area_union)`` of arrays with ``num_class``
        entries each.
    """

    def class_areas(labels):
        # Shifted labels 1..num_class fall into bins 0..num_class-1.
        counts, _edges = np.histogram(
            labels, bins=num_class, range=(1, num_class)
        )
        return counts

    shifted_pred = np.asarray(pred, dtype=np.uint8).copy()
    shifted_mask = np.asarray(mask, dtype=np.uint8).copy()

    # 255 -> 0 (uint8 wrap-around), every other label k -> k + 1
    shifted_pred += 1
    shifted_mask += 1

    # Zero the prediction wherever the reference label wrapped to 0.
    valid_pred = np.where(shifted_mask > 0, shifted_pred, 0)
    # Keep the label only where prediction and reference agree.
    agreement = np.where(valid_pred == shifted_mask, valid_pred, 0)

    area_inter = class_areas(agreement)
    area_pred = class_areas(valid_pred)
    area_mask = class_areas(shifted_mask)

    area_union = area_pred + area_mask - area_inter
    return (area_inter, area_union)
def preprocess(image, mask, flip=False, scale=None, crop=None):
    """Augment an image/label pair and convert both to tensors.

    Steps, in order: optional random horizontal flip, optional random
    rescale, tensor conversion with channel normalization, optional pad and
    random crop. The same geometric transform is applied to both inputs.

    Args:
        image: PIL image (it must provide ``transpose``, ``size`` and
            ``resize``). The three-value mean/std normalization assumes
            3 channels -- presumably RGB; confirm against the caller.
        mask: PIL image holding the label map for ``image``; assumed to be
            single-channel, since it is indexed as a 2-D tensor after
            conversion.
        flip: If truthy, mirror both inputs left-right with probability 0.5.
        scale: Optional ``(low, high)`` pair. The scale factor is drawn
            uniformly in log2 space between the two bounds.
        crop: Optional ``(height, width)`` of the output. Inputs smaller
            than this are padded on the bottom/right first.

    Returns:
        Tuple ``(image, mask)``: a normalized float tensor indexed as
        ``(C, H, W)`` and a ``LongTensor`` indexed as ``(H, W)``.
    """
    if flip and random.random() < 0.5:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

    if scale:
        w, h = image.size
        log_low = math.log(scale[0], 2)
        log_high = math.log(scale[1], 2)
        rand_log_scale = log_low + random.random() * (log_high - log_low)
        random_scale = math.pow(2, rand_log_scale)
        new_size = (int(round(w * random_scale)), int(round(h * random_scale)))
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the filter it
        # aliased, so the resampling result is unchanged.
        image = image.resize(new_size, Image.LANCZOS)
        # Nearest-neighbour keeps label values intact (no blended classes).
        mask = mask.resize(new_size, Image.NEAREST)

    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = data_transforms(image)
    mask = torch.LongTensor(np.array(mask).astype(np.int64))

    if crop:
        h, w = image.shape[1], image.shape[2]
        pad_tb = max(0, crop[0] - h)
        pad_lr = max(0, crop[1] - w)
        # Pad only on the right and bottom so existing pixels keep their
        # coordinates. The image gets zeros (the mean colour after
        # normalization); the mask gets 255 -- presumably the ignore label.
        image = torch.nn.ZeroPad2d((0, pad_lr, 0, pad_tb))(image)
        mask = torch.nn.ConstantPad2d((0, pad_lr, 0, pad_tb), 255)(mask)

        h, w = image.shape[1], image.shape[2]
        # randint is inclusive on both ends, so the window always fits.
        i = random.randint(0, h - crop[0])
        j = random.randint(0, w - crop[1])
        image = image[:, i:i + crop[0], j:j + crop[1]]
        mask = mask[i:i + crop[0], j:j + crop[1]]

    return image, mask