-
Notifications
You must be signed in to change notification settings - Fork 120
/
Copy pathtrain_srgan.py
264 lines (212 loc) · 11.3 KB
/
train_srgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import time
import torch.backends.cudnn as cudnn
from torch import nn
from models import Generator, Discriminator, TruncatedVGG19
from datasets import SRDataset
from utils import *
# ---------------------------------------------------------------------------
# Data parameters
# ---------------------------------------------------------------------------
# Folder with the JSON data files
data_folder = './'
# Crop size of the target HR images
crop_size = 96
# Scaling factor for the generator; the input LR images are downsampled from
# the target HR images by this factor
scaling_factor = 4

# ---------------------------------------------------------------------------
# Generator parameters
# ---------------------------------------------------------------------------
# Kernel size of the first and last convolutions, which transform the inputs
# and outputs
large_kernel_size_g = 9
# Kernel size of all convolutions in-between, i.e. those in the residual and
# subpixel convolutional blocks
small_kernel_size_g = 3
# Number of channels in-between, i.e. the input and output channels of the
# residual and subpixel convolutional blocks
n_channels_g = 64
# Number of residual blocks
n_blocks_g = 16
# Filepath of the trained SRResNet checkpoint used to initialize the generator
srresnet_checkpoint = "./checkpoint_srresnet.pth.tar"

# ---------------------------------------------------------------------------
# Discriminator parameters
# ---------------------------------------------------------------------------
# Kernel size in all convolutional blocks
kernel_size_d = 3
# Number of output channels in the first convolutional block; it is doubled
# in every 2nd block thereafter
n_channels_d = 64
# Number of convolutional blocks
n_blocks_d = 8
# Size of the first fully connected layer
fc_size_d = 1024

# ---------------------------------------------------------------------------
# Learning parameters
# ---------------------------------------------------------------------------
# Path to a model (SRGAN) checkpoint to resume from; None to start fresh
checkpoint = None
# Batch size
batch_size = 16
# Epoch to start at
start_epoch = 0
# Number of training iterations
iterations = 2e5
# Number of workers for loading data in the DataLoader
workers = 4
# The indices i and j in the definition of the VGG loss; see paper or models.py
vgg19_i = 5
vgg19_j = 4
# Coefficient that weights the adversarial loss inside the perceptual loss
beta = 1e-3
# Print training status once every `print_freq` batches
print_freq = 500
# Learning rate
lr = 1e-4
# Gradient clipping value; set it if gradients are exploding, None disables it
grad_clip = None

# ---------------------------------------------------------------------------
# Default device
# ---------------------------------------------------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
cudnn.benchmark = True
def main():
    """
    Training.

    Builds the generator, the discriminator and their Adam optimizers (or restores all four from a saved SRGAN
    checkpoint), then trains for as many epochs as are needed to cover `iterations` batches. The learning rate of
    both optimizers is reduced to a tenth at the halfway point, and a checkpoint is written after every epoch.
    """
    global start_epoch, epoch, checkpoint, srresnet_checkpoint

    # Initialize model or load checkpoint
    if checkpoint is None:
        # Generator
        generator = Generator(large_kernel_size=large_kernel_size_g,
                              small_kernel_size=small_kernel_size_g,
                              n_channels=n_channels_g,
                              n_blocks=n_blocks_g,
                              scaling_factor=scaling_factor)

        # Initialize generator network with pretrained SRResNet
        generator.initialize_with_srresnet(srresnet_checkpoint=srresnet_checkpoint)

        # Initialize generator's optimizer
        optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad, generator.parameters()),
                                       lr=lr)

        # Discriminator
        discriminator = Discriminator(kernel_size=kernel_size_d,
                                      n_channels=n_channels_d,
                                      n_blocks=n_blocks_d,
                                      fc_size=fc_size_d)

        # Initialize discriminator's optimizer
        optimizer_d = torch.optim.Adam(params=filter(lambda p: p.requires_grad, discriminator.parameters()),
                                       lr=lr)

    else:
        # map_location remaps the stored tensors onto the default device, so a checkpoint that was saved on a GPU
        # can also be resumed on a CPU-only machine
        # NOTE(review): the checkpoint holds whole pickled model/optimizer objects, so only load files you trust;
        # recent PyTorch releases appear to default torch.load to weights_only=True, which would reject such a
        # file - confirm against the installed version
        checkpoint = torch.load(checkpoint, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        generator = checkpoint['generator']
        discriminator = checkpoint['discriminator']
        optimizer_g = checkpoint['optimizer_g']
        optimizer_d = checkpoint['optimizer_d']
        print("\nLoaded checkpoint from epoch %d.\n" % (checkpoint['epoch'] + 1))

    # Truncated VGG19 network to be used in the loss calculation
    truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)
    truncated_vgg19.eval()

    # Loss functions
    content_loss_criterion = nn.MSELoss()
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()

    # Move to default device
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    truncated_vgg19 = truncated_vgg19.to(device)
    content_loss_criterion = content_loss_criterion.to(device)
    adversarial_loss_criterion = adversarial_loss_criterion.to(device)

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='imagenet-norm')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers,
                                               pin_memory=True)

    # Both quantities below depend only on the loader length, so compute them once rather than in every epoch
    batches_per_epoch = len(train_loader)

    # Total number of epochs to train for
    epochs = int(iterations // batches_per_epoch + 1)

    # Epoch at which half of the iterations have been done
    lr_decay_epoch = int((iterations / 2) // batches_per_epoch + 1)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # At the halfway point, reduce learning rate to a tenth
        if epoch == lr_decay_epoch:
            adjust_learning_rate(optimizer_g, 0.1)
            adjust_learning_rate(optimizer_d, 0.1)

        # One epoch's training
        train(train_loader=train_loader,
              generator=generator,
              discriminator=discriminator,
              truncated_vgg19=truncated_vgg19,
              content_loss_criterion=content_loss_criterion,
              adversarial_loss_criterion=adversarial_loss_criterion,
              optimizer_g=optimizer_g,
              optimizer_d=optimizer_d,
              epoch=epoch)

        # Save checkpoint (whole objects, which is what the resume branch above expects to find)
        torch.save({'epoch': epoch,
                    'generator': generator,
                    'discriminator': discriminator,
                    'optimizer_g': optimizer_g,
                    'optimizer_d': optimizer_d},
                   'checkpoint_srgan.pth.tar')
def train(train_loader, generator, discriminator, truncated_vgg19, content_loss_criterion, adversarial_loss_criterion,
          optimizer_g, optimizer_d, epoch):
    """
    One epoch's training.

    For every batch the generator is updated first, using the perceptual loss (content loss in VGG feature space
    plus `beta` times the adversarial loss), and the discriminator second, using the binary cross-entropy on
    high-resolution images versus detached super-resolved images.

    :param train_loader: train dataloader
    :param generator: generator
    :param discriminator: discriminator
    :param truncated_vgg19: truncated VGG19 network
    :param content_loss_criterion: content loss function (Mean Squared-Error loss)
    :param adversarial_loss_criterion: adversarial loss function (Binary Cross-Entropy loss)
    :param optimizer_g: optimizer for the generator
    :param optimizer_d: optimizer for the discriminator
    :param epoch: epoch number
    """
    # Set to train mode
    generator.train()
    discriminator.train()  # training mode enables batch normalization

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses_c = AverageMeter()  # content loss
    losses_a = AverageMeter()  # adversarial loss in the generator
    losses_d = AverageMeter()  # adversarial loss in the discriminator

    start = time.time()

    # Batches
    for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to default device
        lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24), imagenet-normed
        hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96), imagenet-normed

        # GENERATOR UPDATE

        # Generate
        sr_imgs = generator(lr_imgs)  # (N, 3, 96, 96), in [-1, 1]
        sr_imgs = convert_image(sr_imgs, source='[-1, 1]', target='imagenet-norm')  # (N, 3, 96, 96), imagenet-normed

        # Calculate VGG feature maps for the super-resolved (SR) and high resolution (HR) images
        sr_imgs_in_vgg_space = truncated_vgg19(sr_imgs)
        # The HR feature maps are constant targets, so compute them without tracking gradients at all;
        # this yields the same values as detaching afterwards but never builds the graph
        with torch.no_grad():
            hr_imgs_in_vgg_space = truncated_vgg19(hr_imgs)

        # Discriminate super-resolved (SR) images
        sr_discriminated = discriminator(sr_imgs)  # (N)

        # Calculate the Perceptual loss
        content_loss = content_loss_criterion(sr_imgs_in_vgg_space, hr_imgs_in_vgg_space)
        adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.ones_like(sr_discriminated))
        perceptual_loss = content_loss + beta * adversarial_loss

        # Back-prop.
        optimizer_g.zero_grad()
        perceptual_loss.backward()

        # Clip gradients, if necessary
        if grad_clip is not None:
            clip_gradient(optimizer_g, grad_clip)

        # Update generator
        optimizer_g.step()

        # Keep track of loss
        losses_c.update(content_loss.item(), lr_imgs.size(0))
        losses_a.update(adversarial_loss.item(), lr_imgs.size(0))

        # DISCRIMINATOR UPDATE

        # Discriminate super-resolution (SR) and high-resolution (HR) images
        hr_discriminated = discriminator(hr_imgs)
        sr_discriminated = discriminator(sr_imgs.detach())
        # But didn't we already discriminate the SR images earlier, before updating the generator (G)? Why not just
        # use that here?
        # Because, if we used that, we'd be back-propagating (finding gradients) over the G too when backward() is
        # called
        # It's actually faster to detach the SR images from the G and forward-prop again, than to back-prop. over the
        # G unnecessarily
        # See FAQ section in the tutorial

        # Binary Cross-Entropy loss
        adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.zeros_like(sr_discriminated)) + \
                           adversarial_loss_criterion(hr_discriminated, torch.ones_like(hr_discriminated))

        # Back-prop.
        optimizer_d.zero_grad()
        adversarial_loss.backward()

        # Clip gradients, if necessary
        if grad_clip is not None:
            clip_gradient(optimizer_d, grad_clip)

        # Update discriminator
        optimizer_d.step()

        # Keep track of loss
        losses_d.update(adversarial_loss.item(), hr_imgs.size(0))

        # Keep track of batch times
        batch_time.update(time.time() - start)

        # Reset start time
        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]----'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})----'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})----'
                  'Cont. Loss {loss_c.val:.4f} ({loss_c.avg:.4f})----'
                  'Adv. Loss {loss_a.val:.4f} ({loss_a.avg:.4f})----'
                  'Disc. Loss {loss_d.val:.4f} ({loss_d.avg:.4f})'.format(epoch,
                                                                          i,
                                                                          len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time,
                                                                          loss_c=losses_c,
                                                                          loss_a=losses_a,
                                                                          loss_d=losses_d))

        # Free some memory since their histories may be stored. Doing this at the end of each iteration releases the
        # tensors before the next batch is moved to the device, and every name is guaranteed to be bound here (a
        # function-level `del` would fail if the loader yielded no batches)
        del lr_imgs, hr_imgs, sr_imgs, hr_imgs_in_vgg_space, sr_imgs_in_vgg_space, hr_discriminated, sr_discriminated
# Run the training only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()