-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
91 lines (61 loc) · 2.72 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import argparse
from torch.utils.data import DataLoader
from nets.PAAD import PAAD
from custom_dataset import InterventionDataset
from utils import model_evaluation, density_estimation, get_F1_measure
def test_all(args):
    """Evaluate the pretrained PAAD model on the whole test set.

    Builds a DataLoader over the test split, restores the weights stored in
    ``./nets/paad.pth`` and prints the average precision returned by
    ``model_evaluation`` and the F1 measure returned by ``get_F1_measure``
    (decision threshold 0.5).

    Args:
        args: parsed command-line namespace; reads ``test_image_path``,
            ``test_csv_path``, ``test_batch_size``, ``freeze_features``,
            ``pretrained_file`` and ``horizon``.
    """
    test_set = InterventionDataset(args.test_image_path, args.test_csv_path, 'test')
    test_loader = DataLoader(
        dataset=test_set, batch_size=args.test_batch_size, shuffle=False, num_workers=2)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    paad = PAAD(
        device=device,
        freeze_features=args.freeze_features,
        pretrained_file=args.pretrained_file,
        horizon=args.horizon).to(device)
    checkpoint_path = './nets/paad.pth'
    # map_location lets a checkpoint saved on a GPU be restored when only the
    # CPU is available (the device chosen above).
    paad.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Put dropout / batch-norm layers in inference mode, as test_datapoint does.
    # NOTE(review): harmless if model_evaluation also calls eval() — confirm.
    paad.eval()

    ap = model_evaluation(test_loader, paad, device)
    print("Average precision on the test set: {:.4f}".format(ap))

    # Optional (disabled): density estimation on the test set.
    # density_estimation(test_loader, paad, device, threshold=0.5)

    # compute F1 measure on test set
    f1_measure = get_F1_measure(test_loader, paad, device, threshold=0.5)
    print("F1 measure on the test set: {:.4f}".format(f1_measure))
def test_datapoint(args):
    """Run the pretrained PAAD model on a single test sample and print the result.

    Loads sample ``args.data_index`` from the test split and restores the
    weights in ``./nets/paad.pth`` on the CPU. It then prints the
    ground-truth label next to the predicted score, both rounded to two
    decimals.

    Args:
        args: parsed command-line namespace; reads ``test_image_path``,
            ``test_csv_path``, ``data_index``, ``freeze_features``,
            ``pretrained_file`` and ``horizon``.
    """
    test_set = InterventionDataset(args.test_image_path, args.test_csv_path, 'test')
    paad = PAAD(
        device='cpu',
        freeze_features=args.freeze_features,
        pretrained_file=args.pretrained_file,
        horizon=args.horizon)
    checkpoint_path = './nets/paad.pth'
    # The model lives on the CPU here, so remap any CUDA tensors in the checkpoint.
    paad.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    paad.eval()

    with torch.no_grad():
        img, pred_traj_img, lidar_scan, label = test_set[args.data_index]
        # Add a batch dimension without mutating the dataset's tensors in place.
        img = img.unsqueeze(0)
        pred_traj_img = pred_traj_img.unsqueeze(0)
        lidar_scan = lidar_scan.unsqueeze(0)
        _, _, _, pred_score = paad(img, pred_traj_img, lidar_scan)
        pred_score = pred_score.squeeze()
        # squeeze() yields a 0-d tensor when only one score is produced
        # (presumably horizon=1); list() cannot iterate a 0-d array.
        if pred_score.dim() == 0:
            pred_score = pred_score.unsqueeze(0)

    print("The ground truth label:", list(label.numpy().round(2)))
    print("The predicted score: ", list(pred_score.numpy().round(2)))
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``. This converter accepts the usual spellings
    instead.

    Args:
        value: the raw string from the command line (a bool is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # dataset parameters
    parser.add_argument("--test_batch_size", type=int, default=64)
    parser.add_argument("--test_image_path", type=str, default='test_set/images_test/')
    parser.add_argument("--test_csv_path", type=str, default='test_set/labeled_data_test.csv')
    parser.add_argument("--data_index", type=int, default=5313)
    # model parameters
    # str2bool instead of type=bool so that "--freeze_features False" really
    # disables it; a bare "--freeze_features" means True.
    parser.add_argument("--freeze_features", type=str2bool, nargs='?',
                        const=True, default=True)
    parser.add_argument("--pretrained_file", type=str,
                        default="nets/VisionNavNet_state_hd.pth.tar")
    parser.add_argument("--horizon", type=int, default=10)
    # run mode: evaluate the whole test set (default) or inspect the single
    # sample selected by --data_index
    parser.add_argument("--mode", choices=("all", "datapoint"), default="all")
    args = parser.parse_args()
    if args.mode == "datapoint":
        test_datapoint(args)
    else:
        test_all(args)