-
Notifications
You must be signed in to change notification settings - Fork 1
/
inference.py
31 lines (26 loc) · 1.03 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import torch
from load_data import BertClassifier, GenerateData
from torch.utils.data import DataLoader
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Directory that is expected to contain the checkpoint file ``best.pt``.
save_path = './bert_based_checkpoint'

model = BertClassifier()
# map_location remaps the stored tensors onto ``device`` while loading, so a
# checkpoint written during a GPU run can still be restored on a CPU-only host.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(
    torch.load(os.path.join(save_path, 'best.pt'), map_location=device)
)
model = model.to(device)
# Put the module in evaluation mode before it is used for inference.
model.eval()
def evaluate(model: torch.nn.Module, dataset, batch_size: int = 128) -> float:
    """Print and return the classification accuracy of ``model`` on ``dataset``.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)``; the
            predicted class is the argmax of its output along dim 1.
        dataset: Sized collection accepted by ``DataLoader`` whose items are
            ``(inputs, label)`` pairs, where ``inputs`` is a mapping with
            ``'input_ids'`` and ``'attention_mask'`` entries.
        batch_size: Number of samples per batch. Defaults to 128, the value
            that used to be hard-coded.

    Returns:
        The fraction of samples whose predicted class equals the label.

    Raises:
        ValueError: If ``dataset`` is empty, since accuracy is undefined.
    """
    sample_count = len(dataset)
    if sample_count == 0:
        raise ValueError("cannot evaluate on an empty dataset")

    # Move batches to the device that actually holds the model's weights
    # rather than assuming the module-level ``device``; that global is only
    # used as a fallback for a model that has no parameters at all.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    model.eval()
    loader = DataLoader(dataset, batch_size=batch_size)
    correct = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            # assumes each sample's input_ids has a leading singleton dim that
            # ends up at dim 1 after batching -- TODO confirm in GenerateData
            input_ids = batch_inputs['input_ids'].squeeze(1).to(target_device)
            attention_mask = batch_inputs['attention_mask'].to(target_device)
            batch_labels = batch_labels.to(target_device)
            logits = model(input_ids, attention_mask)
            correct += (logits.argmax(dim=1) == batch_labels).sum().item()

    accuracy = correct / sample_count
    print(f'Test Accuracy: {accuracy: .3f}')
    return accuracy
if __name__ == "__main__":
    # Script entry point: score the loaded model on the "test" split.
    evaluate(model, GenerateData(mode="test"))