forked from Aaditya-Singh/SAFIN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
function.py
140 lines (115 loc) · 5.75 KB
/
function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import torch.nn as nn
def calc_mean_std(feat, eps=1e-5):
    """Compute per-(sample, channel) mean and std of a 4-D feature map.

    Statistics are taken over the spatial dimensions only.  *eps* is added
    to the variance before the square root to avoid divide-by-zero when a
    channel is constant.  Both outputs are shaped (N, C, 1, 1) so they
    broadcast/expand cleanly against the input.
    """
    dims = feat.size()
    assert (len(dims) == 4)
    batch, channels = dims[0], dims[1]
    flattened = feat.view(batch, channels, -1)
    std = (flattened.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    mean = flattened.mean(dim=2).view(batch, channels, 1, 1)
    return mean, std
def adaptive_instance_normalization(content_feat, style_feat):
    """AdaIN: re-scale *content_feat* to carry *style_feat*'s statistics.

    The content features are whitened per (sample, channel) and then given
    the style features' channel-wise mean and std.  Both inputs are 4-D and
    must agree on batch and channel dimensions.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    shape = content_feat.size()
    s_mean, s_std = calc_mean_std(style_feat)
    c_mean, c_std = calc_mean_std(content_feat)
    # Whiten content, then apply the style's affine statistics.
    whitened = (content_feat - c_mean.expand(shape)) / c_std.expand(shape)
    return whitened * s_std.expand(shape) + s_mean.expand(shape)
def mean_normalization(feat):
    """Whiten *feat* per (sample, channel): zero mean, unit std over space."""
    shape = feat.size()
    mu, sigma = calc_mean_std(feat)
    return (feat - mu.expand(shape)) / sigma.expand(shape)
class Attention(nn.Module):
    """Cross-attention from content positions (queries) to style positions
    (keys/values).

    Queries and keys are computed from mean-normalized inputs; values come
    from the raw style features.  Output has the content's spatial shape.
    """

    def __init__(self, num_features):
        super(Attention, self).__init__()
        # 1x1 projections keep the channel count unchanged.
        self.query_conv = nn.Conv2d(num_features, num_features, (1, 1))
        self.key_conv = nn.Conv2d(num_features, num_features, (1, 1))
        self.value_conv = nn.Conv2d(num_features, num_features, (1, 1))
        self.softmax = nn.Softmax(dim = -1)
        # Initialization order intentionally matches module declaration
        # order so seeded runs stay reproducible.
        for conv in (self.query_conv, self.key_conv, self.value_conv):
            nn.init.xavier_uniform_(conv.weight)
            nn.init.uniform_(conv.bias, 0.0, 1.0)

    def forward(self, content_feat, style_feat):
        queries = self.query_conv(mean_normalization(content_feat))
        keys = self.key_conv(mean_normalization(style_feat))
        values = self.value_conv(style_feat)

        batch, channels, h_c, w_c = queries.size()
        h_s, w_s = keys.size(2), keys.size(3)

        q_flat = queries.view(batch, -1, w_c * h_c).permute(0, 2, 1)  # (B, HcWc, C)
        k_flat = keys.view(batch, -1, w_s * h_s)                      # (B, C, HsWs)
        # Each content position gets a distribution over style positions.
        weights = self.softmax(torch.bmm(q_flat, k_flat))             # (B, HcWc, HsWs)

        v_flat = values.view(batch, -1, w_s * h_s)                    # (B, C, HsWs)
        attended = torch.bmm(v_flat, weights.permute(0, 2, 1))        # (B, C, HcWc)
        return attended.view(batch, channels, h_c, w_c)
class SAFIN(nn.Module):
    """Style-Adaptive Feature Instance Normalization.

    Instance-normalizes the content features, applies a learned shared
    (style-independent) affine transform, then modulates the result with
    per-position gamma/beta maps derived from attention-aligned style
    features.
    """

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        # Shared affine parameters applied right after instance norm.
        self.shared_weight = nn.Parameter(torch.Tensor(num_features), requires_grad=True)
        self.shared_bias = nn.Parameter(torch.Tensor(num_features), requires_grad=True)
        # NOTE(review): shared_pad is not referenced in this file's visible
        # code — presumably used by callers; confirm before removing.
        self.shared_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.gamma_conv = nn.Conv2d(num_features, num_features, (1, 1))
        self.beta_conv = nn.Conv2d(num_features, num_features, (1, 1))
        self.attention = Attention(num_features)
        self.relu = nn.ReLU()
        # Init order matches the original declaration order for seeded
        # reproducibility: shared params, then gamma, then beta.
        nn.init.ones_(self.shared_weight)
        nn.init.zeros_(self.shared_bias)
        for conv in (self.gamma_conv, self.beta_conv):
            nn.init.xavier_uniform_(conv.weight)
            nn.init.uniform_(conv.bias, 0.0, 1.0)

    def forward(self, content_feat, style_feat, output_shared=False):
        assert (content_feat.size()[:2] == style_feat.size()[:2])
        shape = content_feat.size()

        # Align style features to content positions, then derive modulation maps.
        aligned_style = self.attention(content_feat, style_feat)
        gamma = self.relu(self.gamma_conv(aligned_style))
        beta = self.relu(self.beta_conv(aligned_style))

        # Instance-normalize the content features.
        mu, sigma = calc_mean_std(content_feat)
        normalized = (content_feat - mu.expand(shape)) / sigma.expand(shape)

        # Shared (style-independent) affine transform.
        weight = self.shared_weight.view(1, self.num_features, 1, 1).expand(shape)
        bias = self.shared_bias.view(1, self.num_features, 1, 1).expand(shape)
        shared_affine_feat = normalized * weight + bias
        if output_shared:
            return shared_affine_feat

        # Style-dependent spatial modulation.
        return shared_affine_feat * gamma + beta
def _calc_feat_flatten_mean_std(feat):
# takes 3D feat (C, H, W), return mean and std of array within channels
assert (feat.size()[0] == 3)
assert (isinstance(feat, torch.FloatTensor))
feat_flatten = feat.view(3, -1)
mean = feat_flatten.mean(dim=-1, keepdim=True)
std = feat_flatten.std(dim=-1, keepdim=True)
return feat_flatten, mean, std
def _mat_sqrt(x):
U, D, V = torch.svd(x)
return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
def coral(source, target):
    """CORAL color transfer: give *source*'s pixel statistics *target*'s
    per-channel mean/std and channel correlations.

    Both inputs are 3-D (C, H, W) float tensors with C == 3.  Returns a
    tensor shaped like *source*.
    """
    src_flat, src_mean, src_std = _calc_feat_flatten_mean_std(source)
    src_norm = (src_flat - src_mean.expand_as(src_flat)) / src_std.expand_as(src_flat)
    # + I regularizes the 3x3 covariance so its matrix sqrt is invertible.
    src_cov = src_norm.mm(src_norm.t()) + torch.eye(3)

    tgt_flat, tgt_mean, tgt_std = _calc_feat_flatten_mean_std(target)
    tgt_norm = (tgt_flat - tgt_mean.expand_as(tgt_flat)) / tgt_std.expand_as(tgt_flat)
    tgt_cov = tgt_norm.mm(tgt_norm.t()) + torch.eye(3)

    # Whiten away source correlations, then impose target correlations.
    transferred = _mat_sqrt(tgt_cov).mm(
        torch.inverse(_mat_sqrt(src_cov)).mm(src_norm)
    )
    # Finally restore target per-channel scale and offset.
    result = transferred * tgt_std.expand_as(src_norm) + tgt_mean.expand_as(src_norm)
    return result.view(source.size())