-
Notifications
You must be signed in to change notification settings - Fork 29
/
main.py
101 lines (92 loc) · 4.95 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
from tqdm import tqdm
from model import Transformer
from config import get_config
from loss_func import CELoss, SupConLoss, DualLoss
from data_utils import load_data
from transformers import logging, AutoTokenizer, AutoModel
class Instructor:
    """Build a pretrained-encoder classifier and drive its train/evaluate loop.

    ``args`` is the configuration object returned by ``get_config``. The
    attributes read by this class are ``model_name``, ``num_classes``,
    ``method``, ``device``, ``backend``, ``dataset``, ``data_dir``,
    ``train_batch_size``, ``test_batch_size``, ``alpha``, ``temp``, ``lr``,
    ``decay``, ``num_epoch`` and ``log_name``.
    """

    def __init__(self, args, logger):
        """Load the tokenizer/encoder chosen by ``args.model_name`` and wrap it.

        Args:
            args: configuration namespace (see the class docstring).
                ``args.device`` is presumably a ``torch.device``; its
                ``.type`` and ``.index`` are read below.
            logger: object with an ``info`` method used for all output.

        Raises:
            ValueError: if ``args.model_name`` is not ``'bert'`` or ``'roberta'``.
        """
        self.args = args
        self.logger = logger
        self.logger.info('> creating model {}'.format(args.model_name))
        if args.model_name == 'bert':
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            base_model = AutoModel.from_pretrained('bert-base-uncased')
        elif args.model_name == 'roberta':
            # NOTE(review): add_prefix_space=True is presumably needed because
            # load_data tokenizes pre-split words — TODO confirm in data_utils.
            self.tokenizer = AutoTokenizer.from_pretrained('roberta-base', add_prefix_space=True)
            base_model = AutoModel.from_pretrained('roberta-base')
        else:
            raise ValueError('unknown model')
        self.model = Transformer(base_model, args.num_classes, args.method)
        self.model.to(args.device)
        if args.device.type == 'cuda':
            self.logger.info('> cuda memory allocated: {}'.format(
                torch.cuda.memory_allocated(args.device.index)))
        self._print_args()

    def _print_args(self):
        """Log every attribute of ``self.args`` as a ``>>> name: value`` line."""
        self.logger.info('> training arguments:')
        for arg in vars(self.args):
            self.logger.info(f">>> {arg}: {getattr(self.args, arg)}")

    def _to_device(self, inputs, targets):
        """Move one batch (a dict of tensors plus a target tensor) to ``args.device``."""
        device = self.args.device
        return {k: v.to(device) for k, v in inputs.items()}, targets.to(device)

    @staticmethod
    def _batch_stats(outputs, targets, loss):
        """Return ``(loss * batch_size, n_correct, batch_size)`` for one batch.

        Multiplying by the batch size assumes ``criterion`` returns a
        batch-mean loss (TODO confirm in loss_func), so that dividing the
        epoch total by the sample count gives a per-sample average even
        when the last batch is smaller.
        """
        batch_size = targets.size(0)
        n_correct = (torch.argmax(outputs['predicts'], -1) == targets).sum().item()
        return loss.item() * batch_size, n_correct, batch_size

    @staticmethod
    def _is_better(acc, loss, best_acc, best_loss):
        """Return True if ``(acc, loss)`` beats the best: higher acc, ties go to lower loss."""
        return acc > best_acc or (acc == best_acc and loss < best_loss)

    def _train(self, dataloader, criterion, optimizer):
        """Run one training epoch.

        Returns:
            ``(mean loss, accuracy)`` over every sample seen in ``dataloader``.
        """
        train_loss, n_correct, n_train = 0, 0, 0
        self.model.train()
        for inputs, targets in tqdm(dataloader, disable=self.args.backend, ascii=' >='):
            inputs, targets = self._to_device(inputs, targets)
            outputs = self.model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss, batch_correct, batch_size = self._batch_stats(outputs, targets, loss)
            train_loss += batch_loss
            n_correct += batch_correct
            n_train += batch_size
        return train_loss / n_train, n_correct / n_train

    def _test(self, dataloader, criterion):
        """Evaluate on ``dataloader`` with gradients disabled.

        Returns:
            ``(mean loss, accuracy)`` over every sample seen in ``dataloader``.
        """
        test_loss, n_correct, n_test = 0, 0, 0
        self.model.eval()
        with torch.no_grad():
            for inputs, targets in tqdm(dataloader, disable=self.args.backend, ascii=' >='):
                inputs, targets = self._to_device(inputs, targets)
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                batch_loss, batch_correct, batch_size = self._batch_stats(outputs, targets, loss)
                test_loss += batch_loss
                n_correct += batch_correct
                n_test += batch_size
        return test_loss / n_test, n_correct / n_test

    def _build_criterion(self):
        """Return the loss object selected by ``args.method``.

        Raises:
            ValueError: if ``args.method`` is not ``'ce'``, ``'scl'`` or ``'dualcl'``.
        """
        method = self.args.method
        if method == 'ce':
            return CELoss()
        if method == 'scl':
            return SupConLoss(self.args.alpha, self.args.temp)
        if method == 'dualcl':
            return DualLoss(self.args.alpha, self.args.temp)
        raise ValueError('unknown method')

    def run(self):
        """Train for ``args.num_epoch`` epochs, evaluating on the test split after each.

        Logs per-epoch train/test loss and accuracy, plus the best test result
        so far. The best result is chosen by highest test accuracy, with ties
        broken by lower test loss.
        """
        train_dataloader, test_dataloader = load_data(dataset=self.args.dataset,
                                                      data_dir=self.args.data_dir,
                                                      tokenizer=self.tokenizer,
                                                      train_batch_size=self.args.train_batch_size,
                                                      test_batch_size=self.args.test_batch_size,
                                                      model_name=self.args.model_name,
                                                      method=self.args.method,
                                                      workers=0)
        criterion = self._build_criterion()
        # Only parameters that still require gradients are optimized.
        _params = filter(lambda p: p.requires_grad, self.model.parameters())
        optimizer = torch.optim.AdamW(_params, lr=self.args.lr, weight_decay=self.args.decay)
        # best_loss starts at +inf so the lower-loss tie-break can fire even
        # while best_acc is still 0. The previous value of 0 made it unreachable
        # and could report a bogus best loss of 0.
        best_loss, best_acc = float('inf'), 0
        for epoch in range(self.args.num_epoch):
            train_loss, train_acc = self._train(train_dataloader, criterion, optimizer)
            test_loss, test_acc = self._test(test_dataloader, criterion)
            if self._is_better(test_acc, test_loss, best_acc, best_loss):
                best_acc, best_loss = test_acc, test_loss
            self.logger.info('{}/{} - {:.2f}%'.format(epoch+1, self.args.num_epoch, 100*(epoch+1)/self.args.num_epoch))
            self.logger.info('[train] loss: {:.4f}, acc: {:.2f}'.format(train_loss, train_acc*100))
            self.logger.info('[test] loss: {:.4f}, acc: {:.2f}'.format(test_loss, test_acc*100))
            self.logger.info('best loss: {:.4f}, best acc: {:.2f}'.format(best_loss, best_acc*100))
        self.logger.info('log saved: {}'.format(self.args.log_name))
if __name__ == '__main__':
    # Limit the transformers library's own log output to errors only.
    logging.set_verbosity_error()
    # get_config supplies both the parsed arguments and the logger.
    args, logger = get_config()
    Instructor(args, logger).run()