|
| 1 | +import torch |
| 2 | +import torch.nn.functional as F |
| 3 | +from torch.autograd import Variable |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import pdb, os, argparse |
| 7 | +from datetime import datetime |
| 8 | + |
| 9 | +from model.CPD_models import CPD_VGG |
| 10 | +from model.CPD_ResNet_models import CPD_ResNet |
| 11 | +from data import get_loader |
| 12 | +from utils import clip_gradient, adjust_lr |
| 13 | + |
| 14 | + |
| 15 | +parser = argparse.ArgumentParser() |
| 16 | +parser.add_argument('--epoch', type=int, default=100, help='epoch number') |
| 17 | +parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') |
| 18 | +parser.add_argument('--batchsize', type=int, default=10, help='training batch size') |
| 19 | +parser.add_argument('--trainsize', type=int, default=352, help='training dataset size') |
| 20 | +parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') |
| 21 | +parser.add_argument('--is_ResNet', type=bool, default=False, help='VGG or ResNet backbone') |
| 22 | +parser.add_argument('--decay_rate', type=float, default=0.3, help='decay rate of learning rate') |
| 23 | +parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate') |
| 24 | +opt = parser.parse_args() |
| 25 | + |
| 26 | +print('Learning Rate: {} ResNet: {} Trainset: {}'.format(opt.lr, opt.is_ResNet, opt.trainset)) |
| 27 | +# build models |
| 28 | +if opt.is_ResNet: |
| 29 | + model = CPD_ResNet() |
| 30 | +else: |
| 31 | + model = CPD_VGG() |
| 32 | + |
| 33 | +model.cuda() |
| 34 | +params = model.parameters() |
| 35 | +optimizer = torch.optim.Adam(params, opt.lr) |
| 36 | + |
| 37 | +image_root = 'path1' |
| 38 | +gt_root = 'path2' |
| 39 | +train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize) |
| 40 | +total_step = len(train_loader) |
| 41 | + |
| 42 | +CE = torch.nn.BCEWithLogitsLoss() |
| 43 | + |
| 44 | + |
| 45 | +def train(train_loader, model, optimizer, epoch): |
| 46 | + model.train() |
| 47 | + for i, pack in enumerate(train_loader, start=1): |
| 48 | + optimizer.zero_grad() |
| 49 | + images, gts = pack |
| 50 | + images = Variable(images) |
| 51 | + gts = Variable(gts) |
| 52 | + images = images.cuda() |
| 53 | + gts = gts.cuda() |
| 54 | + |
| 55 | + atts, dets = model(images) |
| 56 | + loss1 = CE(atts, gts) |
| 57 | + loss2 = CE(dets, gts) |
| 58 | + loss = loss1 + loss2 |
| 59 | + loss.backward() |
| 60 | + |
| 61 | + clip_gradient(optimizer, opt.clip) |
| 62 | + optimizer.step() |
| 63 | + |
| 64 | + if i % 400 == 0 or i == total_step: |
| 65 | + print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'. |
| 66 | + format(datetime.now(), epoch, opt.epoch, i, total_step, loss1.data, loss2.data)) |
| 67 | + |
| 68 | + if opt.is_ResNet: |
| 69 | + save_path = 'models/CPD_Resnet/' |
| 70 | + else: |
| 71 | + save_path = 'models/CPD_VGG/' |
| 72 | + |
| 73 | + if not os.path.exists(save_path): |
| 74 | + os.makedirs(save_path) |
| 75 | + if (epoch+1) % 5 == 0: |
| 76 | + torch.save(model.state_dict(), save_path + opt.trainset + '_w.pth' + '.%d' % epoch) |
| 77 | + |
| 78 | +print("Let's go!") |
| 79 | +for epoch in range(1, opt.epoch): |
| 80 | + adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch) |
| 81 | + train(train_loader, model, optimizer, epoch) |
0 commit comments