import torch
from tqdm import tqdm
from utils import update_lr, Meter, cal_score


def train(params, model, optimizer, epoch, train_loader, writer=None):
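    """Train for one epoch, covering len(train_loader) // params['train_parts'] batches;
    return (mean loss, WordRate, ExpRate)."""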
model.train()
device = params['device']
loss_meter = Meter()
word_right, exp_right, length, cal_num = 0, 0, 0, 0
with tqdm(train_loader, total=len(train_loader)//params['train_parts']) as pbar:
for batch_idx, (images, image_masks, labels, label_masks) in enumerate(pbar):
            images, image_masks, labels, label_masks = (
                images.to(device), image_masks.to(device),
                labels.to(device), label_masks.to(device))
batch, time = labels.shape[:2]
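            # Default to cosine learning-rate decay when no schedule is configured in params.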
            if 'lr_decay' not in params or params['lr_decay'] == 'cosine':
update_lr(optimizer, epoch, batch_idx, len(train_loader), params['epochs'], params['lr'])
optimizer.zero_grad()
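            # Forward pass: symbol probabilities, counting predictions, and the word/counting loss terms.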
probs, counting_preds, word_loss, counting_loss = model(images, image_masks, labels, label_masks)
loss = word_loss + counting_loss
loss.backward()
if params['gradient_clip']:
torch.nn.utils.clip_grad_norm_(model.parameters(), params['gradient'])
optimizer.step()
loss_meter.add(loss.item())
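            # cal_score computes the batch word-level accuracy (WordRate) and whole-expression accuracy (ExpRate).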
wordRate, ExpRate = cal_score(probs, labels, label_masks)
            word_right += wordRate * time
            exp_right += ExpRate * batch
            length += time
            cal_num += batch
if writer:
                current_step = epoch * (len(train_loader) // params['train_parts']) + batch_idx + 1
writer.add_scalar('train/word_loss', word_loss.item(), current_step)
writer.add_scalar('train/counting_loss', counting_loss.item(), current_step)
writer.add_scalar('train/loss', loss.item(), current_step)
writer.add_scalar('train/WordRate', wordRate, current_step)
writer.add_scalar('train/ExpRate', ExpRate, current_step)
writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], current_step)
pbar.set_description(f'{epoch+1} word_loss:{word_loss.item():.4f} counting_loss:{counting_loss.item():.4f} WRate:{word_right / length:.4f} '
f'ERate:{exp_right / cal_num:.4f}')
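            # With train_parts > 1, each epoch consumes only a fraction of the loader.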
if batch_idx >= len(train_loader) // params['train_parts']:
break
if writer:
        writer.add_scalar('epoch/train_loss', loss_meter.mean, epoch + 1)
        writer.add_scalar('epoch/train_WordRate', word_right / length, epoch + 1)
        writer.add_scalar('epoch/train_ExpRate', exp_right / cal_num, epoch + 1)
return loss_meter.mean, word_right / length, exp_right / cal_num


def eval(params, model, epoch, eval_loader, writer=None):
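    """Evaluate for one epoch, covering len(eval_loader) // params['valid_parts'] batches;
    return (mean loss, WordRate, ExpRate)."""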
model.eval()
device = params['device']
loss_meter = Meter()
word_right, exp_right, length, cal_num = 0, 0, 0, 0
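    # torch.no_grad() disables gradient tracking for the whole evaluation loop.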
with tqdm(eval_loader, total=len(eval_loader)//params['valid_parts']) as pbar, torch.no_grad():
for batch_idx, (images, image_masks, labels, label_masks) in enumerate(pbar):
            images, image_masks, labels, label_masks = (
                images.to(device), image_masks.to(device),
                labels.to(device), label_masks.to(device))
batch, time = labels.shape[:2]
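            # is_train=False selects the model's evaluation path; the losses are still returned for logging.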
probs, counting_preds, word_loss, counting_loss = model(images, image_masks, labels, label_masks, is_train=False)
loss = word_loss + counting_loss
loss_meter.add(loss.item())
wordRate, ExpRate = cal_score(probs, labels, label_masks)
            word_right += wordRate * time
            exp_right += ExpRate * batch
            length += time
            cal_num += batch
if writer:
                current_step = epoch * (len(eval_loader) // params['valid_parts']) + batch_idx + 1
writer.add_scalar('eval/word_loss', word_loss.item(), current_step)
writer.add_scalar('eval/counting_loss', counting_loss.item(), current_step)
writer.add_scalar('eval/loss', loss.item(), current_step)
writer.add_scalar('eval/WordRate', wordRate, current_step)
writer.add_scalar('eval/ExpRate', ExpRate, current_step)
pbar.set_description(f'{epoch+1} word_loss:{word_loss.item():.4f} counting_loss:{counting_loss.item():.4f} WRate:{word_right / length:.4f} '
f'ERate:{exp_right / cal_num:.4f}')
if batch_idx >= len(eval_loader) // params['valid_parts']:
break
if writer:
writer.add_scalar('epoch/eval_loss', loss_meter.mean, epoch + 1)
writer.add_scalar('epoch/eval_WordRate', word_right / length, epoch + 1)
        writer.add_scalar('epoch/eval_ExpRate', exp_right / cal_num, epoch + 1)
return loss_meter.mean, word_right / length, exp_right / cal_num
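

# Minimal usage sketch (assumed names: the params dict, model, optimizer, writer,
# and loaders are built elsewhere, e.g. from the project's config and dataset code):
#
#     for epoch in range(params['epochs']):
#         train_loss, train_word_rate, train_exp_rate = train(
#             params, model, optimizer, epoch, train_loader, writer)
#         eval_loss, eval_word_rate, eval_exp_rate = eval(
#             params, model, epoch, eval_loader, writer)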