-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
101 lines (81 loc) · 3.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
from torchvision.models import vgg16, vgg19
######################################
######## Adversarial BCE loss ########
######################################
class AdvLoss(nn.Module):
    """Weighted binary cross-entropy computed on raw logits.

    Wraps ``nn.BCEWithLogitsLoss`` (default mean reduction) and scales its
    value by ``alpha``.

    Args:
        alpha: Multiplicative weight applied to the BCE value.
    """

    def __init__(self, alpha=1):
        super().__init__()
        # BCEWithLogitsLoss applies the sigmoid itself, so ``pred`` must be
        # unnormalized logits rather than probabilities.
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.alpha = alpha

    def forward(self, pred, target):
        """Return ``alpha * BCEWithLogits(pred, target)``."""
        bce_value = self.loss_fn(pred, target)
        return self.alpha * bce_value
######################################
######### Pixel-wise MSE loss ########
######################################
class PixLoss(nn.Module):
    """Weighted pixel-wise mean squared error between two tensors.

    Args:
        alpha: Multiplicative weight applied to the MSE (default 20).
    """

    def __init__(self, alpha=20):
        super().__init__()
        self.alpha = alpha

    def forward(self, fake, real):
        """Return ``alpha * mean((fake - real) ** 2)`` over all elements."""
        squared_error = (fake - real).pow(2)
        return self.alpha * squared_error.mean()
######################################
######## Model-based loss ########
######################################
class ModelBasedLoss(nn.Module):
    """Weighted feature-space MSE computed with a frozen pretrained VGG.

    Both images are passed through ``features[:-1]`` of a torchvision VGG
    network, and the mean squared difference of the resulting activations
    is scaled by ``alpha``.

    Args:
        alpha: Multiplicative weight applied to the feature MSE.
        name: Backbone to use, ``'vgg16'`` or ``'vgg19'``.
        device: Device the frozen backbone is moved to.

    Raises:
        ValueError: If ``name`` is not a supported backbone.
    """

    def __init__(self, alpha=2, name='vgg19', device='cuda:0'):
        super().__init__()
        model = self.__loadModel__(name)
        self.model = self.__freeze__(model).to(device)
        self.alpha = alpha

    @staticmethod
    def __loadModel__(name='vgg19'):
        """Build the truncated pretrained VGG feature extractor in eval mode.

        Args:
            name: ``'vgg16'`` or ``'vgg19'``.

        Returns:
            ``features[:-1]`` of the chosen pretrained network, in eval mode.

        Raises:
            ValueError: If ``name`` is neither ``'vgg16'`` nor ``'vgg19'``.
        """
        if name == 'vgg16':
            builder = vgg16
        elif name == 'vgg19':
            builder = vgg19
        else:
            # Reject unknown names explicitly instead of falling through to
            # an unbound-local error further down.
            raise ValueError(
                f"Unsupported backbone {name!r}; expected 'vgg16' or 'vgg19'")
        # ``[:-1]`` drops the last layer of ``features`` (in torchvision's VGG
        # this is the final max-pool), so activations are compared unpooled.
        # NOTE(review): ``pretrained=True`` is deprecated in torchvision
        # >= 0.13 in favour of ``weights=...``; kept for older versions.
        model = builder(pretrained=True).features[:-1]
        return model.eval()

    @staticmethod
    def __freeze__(model):
        """Disable gradients for every parameter of ``model`` and return it.

        The model is modified in place; the same object is returned so the
        call can be chained.
        """
        return model.requires_grad_(False)

    def forward(self, fake, real):
        """Return ``alpha * mean((model(fake) - model(real)) ** 2)``.

        NOTE(review): inputs are fed to the backbone as-is. No ImageNet
        mean/std normalisation happens here; confirm callers do it upstream.
        """
        pred = self.model(fake)
        target = self.model(real)
        return self.alpha * torch.mean((pred - target) ** 2)
######################################
######## Generator loss ##########
######################################
class GeneratorLoss(nn.Module):
    """Composite generator objective: adversarial + feature-space + pixel MSE.

    Args:
        alpha: Weight of the adversarial BCE term.
        beta: Weight of the VGG feature-space MSE term.
        gamma: Weight of the pixel-wise MSE term.
        model: Backbone name forwarded to ``ModelBasedLoss``.
        device: Device forwarded to ``ModelBasedLoss``.
    """

    def __init__(self, alpha=0.001, beta=0.006,
                 gamma=1, model='vgg19', device='cuda:0'):
        super().__init__()
        self.bce = AdvLoss(alpha=alpha)
        self.fb_mse = ModelBasedLoss(alpha=beta, name=model, device=device)
        self.mse = PixLoss(alpha=gamma)

    def forward(self, fake_pred, fake, real):
        """Return the unweighted sum of the three (already weighted) terms.

        The adversarial term pushes ``fake_pred`` (logits for generated
        samples) towards the "real" label 1.
        """
        adversarial = self.bce(fake_pred, torch.ones_like(fake_pred))
        perceptual = self.fb_mse(fake, real)
        pixel = self.mse(fake, real)
        return adversarial + perceptual + pixel
######################################
####### Discriminator loss ########
######################################
class DiscriminatorLoss(nn.Module):
    """Discriminator objective: mean of the BCE on fake and real logits.

    Fake logits are scored against label 0 and real logits against label 1.

    Args:
        alpha: Weight forwarded to the underlying ``AdvLoss``.
    """

    def __init__(self, alpha=1):
        super().__init__()
        self.bce = AdvLoss(alpha=alpha)

    def forward(self, fake_pred, real_pred):
        """Return ``(BCE(fake_pred, 0) + BCE(real_pred, 1)) / 2``."""
        loss_on_fake = self.bce(fake_pred, torch.zeros_like(fake_pred))
        loss_on_real = self.bce(real_pred, torch.ones_like(real_pred))
        return (loss_on_fake + loss_on_real) / 2