This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy path eval.py
92 lines (80 loc) · 3.09 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import division
try:
import faiss
except:
pass
import numpy as np
import torch
import argparse
import os
import time
from lib.metrics import evaluate
from lib.net import Normalize
join = os.path.join
import torch.nn as nn
from lib.data import load_dataset
if __name__ == "__main__":
global args
parser = argparse.ArgumentParser()
parser.add_argument("--candidates", type=int, default=10)
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--database", choices=["bigann", "deep1b"])
parser.add_argument("--device", choices=["cpu", "cuda", "auto"],
default="auto")
parser.add_argument("--gpu", action='store_true', default=False)
parser.add_argument("--quantizer", required=True)
parser.add_argument("--size-base", type=int, default=int(1e6))
parser.add_argument("--val", action='store_false', dest='test')
parser.set_defaults(gpu=False, test=True)
args = parser.parse_args()
if args.device == "auto":
args.device = "cuda" if torch.cuda.is_available() else "cpu"
start = time.time()
if os.path.exists(args.ckpt_path):
print("Loading net")
ckpt = torch.load(args.ckpt_path)
d = vars(args)
for k, v in vars(ckpt['args']).items():
d[k] = v
(xt, xb, xq, gt) = load_dataset(args.database, args.device, size=args.size_base, test=args.test)
dim = xb.shape[1]
dint, dout = args.dint, args.dout
net = nn.Sequential(
nn.Linear(in_features=dim, out_features=dint, bias=True),
nn.BatchNorm1d(dint),
nn.ReLU(),
nn.Linear(in_features=dint, out_features=dint, bias=True),
nn.BatchNorm1d(dint),
nn.ReLU(),
nn.Linear(in_features=dint, out_features=dout, bias=True),
Normalize()
)
net.load_state_dict(ckpt['state_dict'])
net = net.to(args.device)
net = net.eval()
elif args.ckpt_path.startswith("pca-"):
assert args.database is not None
(xt, xb, xq, gt) = load_dataset(args.database, args.device, size=args.size_base, test=args.test)
args.dim = int(args.ckpt_path[4:])
mu = np.mean(xb, axis=0, keepdims=True)
xb -= mu
xq -= mu
cov = np.dot(xb.T, xb) / xb.shape[0]
eigvals, eigvecs = np.linalg.eig(cov)
o = eigvals.argsort()[::-1]
PCA = eigvecs[:, o[:args.dim]].astype(np.float32)
xb = np.dot(xb, PCA)
xb /= np.linalg.norm(xb, axis=1, keepdims=True)
xq = np.dot(xq, PCA)
xq /= np.linalg.norm(xq, axis=1, keepdims=True)
net = nn.Sequential()
else:
print("Main argument not understood: should be the path to a net checkpoint")
import sys;sys.exit(1)
evaluate(net, xq, xb, gt, [args.quantizer], '%s,rank=%d' % (args.quantizer, 10), device=args.device, trainset=xt[:10000])