Commit a363bf7

Authored by WuZhe

Add files via upload
1 parent 0d58898 commit a363bf7

File tree

2 files changed (+92 lines, -0 lines)


train.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
import torch
import torch.nn.functional as F
from torch.autograd import Variable  # no-op wrapper in PyTorch >= 0.4

import numpy as np
import pdb, os, argparse
from datetime import datetime

from model.CPD_models import CPD_VGG
from model.CPD_ResNet_models import CPD_ResNet
from data import get_loader
from utils import clip_gradient, adjust_lr


parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=100, help='epoch number')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--batchsize', type=int, default=10, help='training batch size')
parser.add_argument('--trainsize', type=int, default=352, help='training dataset size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
# note: argparse bool options treat any non-empty string as True
parser.add_argument('--is_ResNet', type=bool, default=False, help='VGG or ResNet backbone')
parser.add_argument('--decay_rate', type=float, default=0.3, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate')
opt = parser.parse_args()

print('Learning Rate: {} ResNet: {} Trainsize: {}'.format(opt.lr, opt.is_ResNet, opt.trainsize))
# build models
if opt.is_ResNet:
    model = CPD_ResNet()
else:
    model = CPD_VGG()

model.cuda()
params = model.parameters()
optimizer = torch.optim.Adam(params, opt.lr)

image_root = 'path1'
gt_root = 'path2'
train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)

CE = torch.nn.BCEWithLogitsLoss()


def train(train_loader, model, optimizer, epoch):
    model.train()
    for i, pack in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        images, gts = pack
        images = Variable(images)
        gts = Variable(gts)
        images = images.cuda()
        gts = gts.cuda()

        # the model returns two saliency predictions (attention branch and
        # detection branch); both are supervised with the same ground truth
        atts, dets = model(images)
        loss1 = CE(atts, gts)
        loss2 = CE(dets, gts)
        loss = loss1 + loss2
        loss.backward()

        clip_gradient(optimizer, opt.clip)
        optimizer.step()

        if i % 400 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step, loss1.data, loss2.data))

    if opt.is_ResNet:
        save_path = 'models/CPD_Resnet/'
    else:
        save_path = 'models/CPD_VGG/'

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if (epoch+1) % 5 == 0:
        # checkpoint name fixed here: the uploaded code used opt.trainset,
        # which the parser does not define and would raise AttributeError
        torch.save(model.state_dict(), save_path + 'CPD_w.pth' + '.%d' % epoch)


print("Let's go!")
for epoch in range(1, opt.epoch):
    adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
    train(train_loader, model, optimizer, epoch)
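
train.py imports get_loader from data as well as the two CPD model classes, none of which are included in this commit. For reference, below is a minimal sketch of the loader interface the script assumes: a PyTorch DataLoader yielding (image, ground-truth) batches resized to trainsize. The class name, file matching, and transforms are illustrative assumptions, not the repository's actual data.py.

# Hypothetical sketch of the assumed data.get_loader interface (not the real data.py).
import os
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms

class SalObjDataset(data.Dataset):
    """Pairs each RGB training image with its grayscale ground-truth saliency map."""
    def __init__(self, image_root, gt_root, trainsize):
        self.images = sorted(os.path.join(image_root, f) for f in os.listdir(image_root))
        self.gts = sorted(os.path.join(gt_root, f) for f in os.listdir(gt_root))
        self.transform = transforms.Compose([
            transforms.Resize((trainsize, trainsize)),
            transforms.ToTensor()])

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert('RGB')
        gt = Image.open(self.gts[index]).convert('L')
        return self.transform(image), self.transform(gt)

    def __len__(self):
        return len(self.images)

def get_loader(image_root, gt_root, batchsize, trainsize):
    dataset = SalObjDataset(image_root, gt_root, trainsize)
    return data.DataLoader(dataset, batch_size=batchsize, shuffle=True)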

utils.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
def clip_gradient(optimizer, grad_clip):
    # Clamp every parameter gradient to [-grad_clip, grad_clip] in place.
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)


def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
    # Step schedule: lr = init_lr * decay_rate ** (epoch // decay_epoch).
    # Setting the LR from init_lr (rather than multiplying the current LR in
    # place) keeps the decay from compounding on every call.
    decay = decay_rate ** (epoch // decay_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay
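
For reference, a quick sanity check of the step schedule adjust_lr produces with the default settings train.py passes (lr=1e-4, decay_rate=0.3, decay_epoch=30); the single dummy parameter exists only so an optimizer can be built.

# Hypothetical usage sketch; assumes utils.py as written above.
import torch
from utils import adjust_lr

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
for epoch in (0, 29, 30, 59, 60):
    adjust_lr(optimizer, init_lr=1e-4, epoch=epoch, decay_rate=0.3, decay_epoch=30)
    print(epoch, optimizer.param_groups[0]['lr'])
# Prints 1e-4 for epochs 0-29, ~3e-5 for 30-59, ~9e-6 from epoch 60 on.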
