-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluation.py
106 lines (73 loc) · 2.7 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import os

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import wandb
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from dataset import pytorch_dataset, augmentations
@torch.no_grad()
def testing(model, dataloader, criterion, device=None):
    """Run ``model`` over ``dataloader`` and collect loss, labels and scores.

    Args:
        model: Network called as ``model(x)``. Its output is squeezed on
            dim 1 after a sigmoid, so it presumably has shape
            ``(batch, 1)`` (one logit per sample) -- confirm against the
            model definition.
        dataloader: Iterable yielding ``(x, y)`` batches; ``y`` gets a
            trailing dimension added so it lines up with the model output.
        criterion: Loss callable invoked as ``criterion(outputs, y)`` on the
            raw (pre-sigmoid) outputs.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, falling back to
            CPU for a parameter-free model.

    Returns:
        Tuple ``(mean_loss, y_true, y_pred)``: the mean of the per-batch
        losses, the concatenated integer labels, and the concatenated
        sigmoid scores (both on CPU).

    Side effects:
        Logs the mean loss to Weights & Biases under the key ``'Loss'``.
    """
    if device is None:
        # The script-level ``args`` object is local to ``main`` and is not
        # visible here, so the target device is derived from the model.
        first_param = next(model.parameters(), None)
        device = (first_param.device if first_param is not None
                  else torch.device('cpu'))

    model.eval()
    running_loss, y_true, y_pred = [], [], []

    for x, y in tqdm(dataloader):
        x = x.to(device)
        y = y.to(device).unsqueeze(1)

        outputs = model(x)

        # Losses such as BCEWithLogitsLoss reject integer targets; float
        # labels are left untouched.
        if not torch.is_floating_point(y):
            y = y.to(outputs.dtype)

        loss = criterion(outputs, y)
        running_loss.append(loss.cpu().numpy())

        probabilities = torch.sigmoid(outputs)
        y_true.append(y.squeeze(1).cpu().int())
        y_pred.append(probabilities.squeeze(1).cpu())

    mean_loss = np.mean(running_loss)
    wandb.log({'Loss': mean_loss})

    return mean_loss, torch.cat(y_true, 0), torch.cat(y_pred, 0)
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or 0.0 when undefined."""
    return float(numerator) / float(denominator) if denominator else 0.0


def _binary_auroc(target, scores):
    """Area under the ROC curve via the rank-sum (Mann-Whitney U) statistic.

    ``target`` is a 1-D bool tensor and ``scores`` a 1-D float tensor of the
    same length. Tied scores receive their average rank, which matches the
    trapezoidal ROC area. Returns ``nan`` when only one class is present,
    because the curve is undefined in that case.
    """
    n_pos = int(target.sum())
    n_neg = target.numel() - n_pos
    if n_pos == 0 or n_neg == 0:
        return float('nan')

    sorted_scores, order = torch.sort(scores)
    # Group equal scores so that every member of a tie shares one mid-rank.
    _, group_of, group_sizes = torch.unique_consecutive(
        sorted_scores, return_inverse=True, return_counts=True)
    last_rank = group_sizes.cumsum(0)
    first_rank = last_rank - group_sizes + 1
    mid_rank = (first_rank + last_rank).double() / 2

    positive_rank_sum = mid_rank[group_of][target[order]].sum()
    u_statistic = positive_rank_sum - n_pos * (n_pos + 1) / 2
    return float(u_statistic) / (n_pos * n_neg)


def log_metrics(y_true, y_pred):
    """Compute binary classification metrics and log them to W&B.

    The metrics are computed directly with torch so the function does not
    rely on a ``tmf`` alias that the module never imports.

    Args:
        y_true: 1-D tensor of ground-truth labels; any non-zero value counts
            as the positive class.
        y_pred: 1-D tensor of scores; a sample is predicted positive when
            its score is ``>= 0.5``.

    Side effects:
        Logs ``'Accuracy'``, ``'F1'``, ``'Precision'``, ``'Recall'`` and
        ``'ROC-AUC score'`` to Weights & Biases as Python floats. Precision,
        recall and F1 refer to the positive class and are 0.0 when their
        denominator is zero.
    """
    target = torch.as_tensor(y_true).reshape(-1).bool()
    scores = torch.as_tensor(y_pred).reshape(-1).double()
    predicted = scores >= 0.5

    tp = int((predicted & target).sum())
    fp = int((predicted & ~target).sum())
    fn = int((~predicted & target).sum())
    tn = int((~predicted & ~target).sum())

    wandb.log({
        'Accuracy': _safe_ratio(tp + tn, tp + tn + fp + fn),
        'F1': _safe_ratio(2 * tp, 2 * tp + fp + fn),
        'Precision': _safe_ratio(tp, tp + fp),
        'Recall': _safe_ratio(tp, tp + fn),
        'ROC-AUC score': _binary_auroc(target, scores)})
def log_conf_matrix(y_true, y_pred):
    """Build the 2x2 confusion matrix and log it to W&B as a table.

    The matrix is computed with torch so the function does not rely on a
    ``tmf`` alias that the module never imports.

    Args:
        y_true: 1-D tensor of ground-truth labels; any non-zero value counts
            as class 1.
        y_pred: 1-D tensor of scores; a sample is assigned class 1 when its
            score is ``>= 0.5``.

    Side effects:
        Logs a ``wandb.Table`` under the key ``'conf_mat'``. Rows are the
        true class (0, then 1); columns ``'A'`` and ``'B'`` are the
        predicted class 0 and 1 respectively.
    """
    target = (torch.as_tensor(y_true).reshape(-1) != 0).long()
    predicted = (torch.as_tensor(y_pred).reshape(-1) >= 0.5).long()

    # Encode each (true, predicted) pair as a single index in [0, 3] and
    # count occurrences; minlength keeps the 2x2 shape when a cell is empty.
    counts = torch.bincount(target * 2 + predicted, minlength=4).reshape(2, 2)

    # Hand pandas a plain numpy array rather than a torch tensor.
    conf_matrix = pd.DataFrame(data=counts.numpy(), columns=['A', 'B'])
    cf_matrix = wandb.Table(dataframe=conf_matrix)
    wandb.log({'conf_mat': cf_matrix})
def _build_parser():
    """Create the command-line parser for the evaluation script.

    The attribute names match what ``main`` reads from the parsed
    namespace. Defaults are conservative choices, not values taken from a
    training configuration.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a trained binary classifier on a test set.')

    # Weights & Biases run identification; wandb.init accepts None for each.
    parser.add_argument('--project_name', type=str, default=None,
                        help='W&B project to log the run under')
    parser.add_argument('--name', type=str, default=None,
                        help='W&B run name')
    parser.add_argument('--group', type=str, default=None,
                        help='W&B run group')

    # Model and checkpoint.
    parser.add_argument('--model_name', type=str, required=True,
                        help='timm architecture name of the trained model')
    parser.add_argument('--weights_dir', type=str, required=True,
                        help='path to the saved state_dict file')
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='device to run the evaluation on')

    # Data; both paths are forwarded unchanged to pytorch_dataset.dataset2.
    parser.add_argument('--dataset_dir', type=str, required=True,
                        help='first positional argument of dataset2')
    parser.add_argument('--test_dir', type=str, required=True,
                        help='second positional argument of dataset2')
    parser.add_argument('--workers', type=int, default=0,
                        help='number of DataLoader worker processes')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='evaluation batch size')
    return parser


# main def:
def main():
    """Evaluate a saved model on the test set and log the results to W&B.

    Parses the command line, starts a W&B run, restores the model weights,
    runs ``testing`` over the test dataloader, logs the metrics and the
    confusion matrix, and prints the final test loss.
    """
    # initialize parser
    parser = _build_parser()
    args = parser.parse_args()

    # initialize w&b
    wandb.init(project=args.project_name, name=args.name,
               config=vars(args), group=args.group)

    # initialize model:
    # NOTE(review): assumes the checkpoint was trained with a timm
    # architecture and a single-logit head (the evaluation loop applies a
    # sigmoid to one output column) -- confirm against the training script.
    model = timm.create_model(
        args.model_name, pretrained=False, num_classes=1)

    # load weights:
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    model.load_state_dict(torch.load(args.weights_dir, map_location='cpu'))
    model = model.eval().to(args.device)

    # defining transforms:
    transforms = augmentations.get_validation_augmentations()

    # define test dataset:
    test_dataset = pytorch_dataset.dataset2(
        args.dataset_dir, args.test_dir, transforms)

    # define data loaders:
    test_dataloader = DataLoader(
        test_dataset, num_workers=args.workers,
        batch_size=args.batch_size, shuffle=False)

    # set the criterion:
    criterion = nn.BCEWithLogitsLoss()

    # testing
    test_loss, y_true, y_pred = testing(
        model=model, dataloader=test_dataloader, criterion=criterion)

    # calculating and logging results:
    log_metrics(y_true=y_true, y_pred=y_pred)
    log_conf_matrix(y_true=y_true, y_pred=y_pred)

    print(f'Finished Testing with test loss = {test_loss}')


if __name__ == '__main__':
    main()