-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathfcn_model.py
139 lines (113 loc) · 5.56 KB
/
fcn_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from torch import nn
import numpy as np
import torch
import torch.nn.functional as F
import resnet
class FCN(nn.Module):
    """Fully convolutional network producing per-pixel maps for a set of rotations.

    A ``resnet.resnet18`` backbone (built with a single input channel) feeds a
    small decoder of 1x1 convolutions that upsamples the backbone features 4x
    (two 2x bilinear upsamplings). Two inference routes exist:

    * single pass (``fast=True`` or ``num_rotations == 1``): the decoder emits
      ``num_rotations`` channels (``fast``) or one channel, in one forward pass;
    * rotation route: the input is rotated ``R`` times in steps of ``360 / R``
      degrees, each rotated copy goes through backbone + decoder, and each output
      map is rotated back with the inverse transform.

    Args:
        num_rotations: Number of evenly spaced rotations.
        fast: If True, predict all rotations as decoder channels in one pass.
        dilation: Forwarded unchanged to ``resnet.resnet18``.
    """

    def __init__(self, num_rotations=16, fast=False, dilation=False):
        super().__init__()
        self.num_rotations = num_rotations
        # Only read by the rotation route of ``forward``: when True the input is
        # moved to the current CUDA device before being rotated. Defaults to CUDA
        # availability so the model also works on CPU-only hosts.
        self.use_cuda = torch.cuda.is_available()
        self.fast = fast
        self.backbone = resnet.resnet18(num_input_channels=1, dilation=dilation)

        def make_decoder(in_channels, out_channels):
            # 1x1 conv + BN + ReLU, 2x bilinear upsample, repeated, then a 1x1 projection.
            return nn.Sequential(
                nn.Conv2d(in_channels, 128, 1, 1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.UpsamplingBilinear2d(scale_factor=2),
                nn.Conv2d(128, 32, 1, 1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.UpsamplingBilinear2d(scale_factor=2),
                nn.Conv2d(32, out_channels, 1, 1),
            )

        # 512 presumably matches the channel count of ``backbone.features`` — TODO confirm.
        self.decoder = make_decoder(512, num_rotations if self.fast else 1)

    @staticmethod
    def _rotation_theta(angle, batch_size, device=None, dtype=torch.float32):
        """Return ``[[cos a, sin a, 0], [-sin a, cos a, 0]]`` repeated to ``(batch_size, 2, 3)``.

        ``angle`` is in radians. The result is the ``theta`` expected by
        ``F.affine_grid``.
        """
        cos_a = float(np.cos(angle))
        sin_a = float(np.sin(angle))
        theta = torch.tensor(
            [[cos_a, sin_a, 0.0], [-sin_a, cos_a, 0.0]], dtype=dtype, device=device
        )
        return theta.unsqueeze(0).repeat(batch_size, 1, 1)

    def cat_grid(self, input, affine_grid=None):
        """Append two coordinate channels to a ``(N, C, H, W)`` tensor.

        Channel ``C`` holds ``|linspace(-0.5, 0.5, H)|`` varying along the height
        axis ("side"); channel ``C + 1`` holds ``linspace(0, 1, W)`` varying along
        the width axis ("forward"). If ``affine_grid`` (an ``F.affine_grid`` theta
        of shape ``(N, 2, 3)``) is given, the coordinate channels are resampled
        through that transform with nearest-neighbour sampling.

        Returns:
            Tensor of shape ``(N, C + 2, H, W)``.
        """
        n, _, h, w = input.shape
        # Build the grid on the input's device so the final concat works anywhere
        # (previously hard-wired to ``.cuda()``).
        side = torch.linspace(-0.5, 0.5, steps=h, device=input.device).abs()
        fwd = torch.linspace(0, 1, steps=w, device=input.device)
        # Equivalent to meshgrid(side, fwd) with 'ij' indexing.
        grid = torch.stack([side.view(h, 1).expand(h, w), fwd.view(1, w).expand(h, w)])
        grid = grid.unsqueeze(0).repeat(n, 1, 1, 1)
        if affine_grid is not None:
            flow_grid = F.affine_grid(affine_grid, input.size(), align_corners=False)
            grid = F.grid_sample(grid, flow_grid, mode='nearest', align_corners=False)
        return torch.cat([input, grid], 1)

    def forward(self, x, force_rotations=-1):
        """Run the network on a batch.

        Args:
            x: Batched input tensor; assumed ``(N, 1, H, W)`` to match the
               single-channel backbone — TODO confirm against callers.
            force_rotations: ``-1`` (default) takes the single-pass route when
               ``num_rotations == 1`` or ``fast`` is set, and otherwise rotates
               ``num_rotations`` times. Any value ``> 0`` forces the rotation
               route with that many rotations. Other values fall back to
               ``num_rotations``.

        Returns:
            For the single-pass route, the decoder output unchanged. For the
            rotation route, the per-rotation outputs stacked on dim 1:
            ``(N, R, H', W')`` when the decoder emits one channel, otherwise
            ``(N, R, C, H', W')``.
        """
        if force_rotations == -1 and (self.num_rotations == 1 or self.fast):
            return self.decoder(self.backbone.features(x))

        rotations = force_rotations if force_rotations > 0 else self.num_rotations
        batch_size = x.shape[0]
        # The rotation route never back-propagates into the input.
        src = x.detach()
        if self.use_cuda:
            src = src.cuda()

        outputs = []
        for rotate_idx in range(rotations):
            angle = np.radians(rotate_idx * (360 / rotations))

            # Rotate the input (sampling matrix built from -angle).
            theta_before = self._rotation_theta(-angle, batch_size, src.device, src.dtype)
            flow_before = F.affine_grid(theta_before, src.size(), align_corners=False)
            rotated = F.grid_sample(src, flow_before, mode='nearest', align_corners=False)

            output_map = self.decoder(self.backbone.features(rotated))

            # Undo the rotation on the prediction (matrix built from +angle is the inverse).
            theta_after = self._rotation_theta(
                angle, batch_size, output_map.device, output_map.dtype
            )
            flow_after = F.affine_grid(theta_after, output_map.size(), align_corners=False)
            outputs.append(
                F.grid_sample(output_map, flow_after, mode='nearest', align_corners=False)
            )

        # N x R x C x H x W; squeeze(2) drops C only when the decoder emits one channel.
        # This also avoids the old 4-D permute that crashed for multi-channel decoders.
        return torch.stack(outputs, dim=1).squeeze(2)
if __name__ == '__main__':
    # Smoke test: a single forward pass on a random image. The backbone is built
    # with one input channel, so the input must be (N, 1, H, W).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = FCN().to(device)
    model.eval()
    with torch.no_grad():
        y = model(torch.rand((1, 1, 224, 224), device=device))
    # ``forward`` already returns a stacked tensor, so print its shape directly
    # (calling torch.stack on a Tensor raises TypeError).
    print(y.shape)