-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmeasure-query-complexity.jl
92 lines (76 loc) · 3.22 KB
/
measure-query-complexity.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
include("io.jl")
include("laion-gold-standards.jl")
"""
    build_hsp(dbname, qname, gold)

Load the query set from `qname`, the database from `dbname` (both through
`loadf32`, wrapped as `MatrixDatabase`) and the `"knns"` / `"dists"` entries of the
`gold` file, then return the result of `hsp_queries` computed under
`NormalizedCosineDistance`.
"""
function build_hsp(dbname::AbstractString, qname::AbstractString, gold::AbstractString)
    # Queries are loaded before the database, as in the original call order.
    queries = MatrixDatabase(loadf32(qname))
    database = MatrixDatabase(loadf32(dbname))
    knns, dists = load(gold, "knns", "dists")
    metric = NormalizedCosineDistance()
    return hsp_queries(metric, database, queries, knns, dists)
end
"""
    loadf32(filename; normalize=true)

Load the `"emb"` entry stored in `filename` and return it.

# Keywords
- `normalize`: when `true` (the default) every column of the loaded array is
  normalized in place with `normalize!` before it is returned.

# Throws
- `ArgumentError` if the element type of the loaded array is not `Float32`.
"""
function loadf32(filename; normalize=true)
    @info "loading $filename"
    Q = load(filename, "emb")
    # Validate with an explicit check rather than `@assert`: assertions are meant for
    # internal invariants and may be disabled at some optimization levels.
    T = eltype(Q)
    if T !== Float32
        throw(ArgumentError("expected Float32 data under \"emb\" in $filename, got eltype $T"))
    end
    if normalize
        for c in eachcol(Q)
            normalize!(c)
        end
    end
    return Q
end
"""
    getresname(size, qname, dbname, k)

Build the output path for a result file:
`results/<size>/<basename of qname>/searchgraph--k=<k>--<basename of dbname>`.
Every occurrence of `.h5` is stripped from that path and a single `.h5` suffix is
appended at the end.
"""
function getresname(size, qname, dbname, k)
    qbase = basename(qname)
    dbbase = basename(dbname)
    path = string("results/", size, "/", qbase, "/searchgraph--k=", k, "--") * dbbase
    # Drops every ".h5" in the whole path (query directory included), not only the last one.
    stem = replace(path, ".h5" => "")
    return string(stem, ".h5")
end
# solve_queries_with_cost(size, qname; k, outname, logbase, minrecall, disjointbase, ...)
#
# Build a SearchGraph index over the database file derived from `size`, solve every
# query of `qname` one at a time with `search`, and store the neighbors, distances,
# per-query elapsed time and per-query `cost` in a file written with `jldsave`.
# Returns the path of the written file (`outname`).
#
# NOTE(review): the signature line below appears cut off after `disjointbase=1.01e`
# (no closing parenthesis is visible here) — confirm against the original file.
function solve_queries_with_cost(size::AbstractString, qname::AbstractString; k=30, outname=nothing, logbase=1.5, minrecall=0.9, disjointbase=1.01e
dbname = "laion2B-en-clip768v2-n=$size.h5"
# Fall back to the conventional results path when the caller gives no output name.
outname = outname === nothing ? getresname(size, qname, dbname, k) : outname
# Make sure the output directory exists before any expensive work starts.
mkpath(dirname(outname))
Q = loadf32(qname) |> MatrixDatabase
db = loadf32(dbname) |> MatrixDatabase
dist = NormalizedCosineDistance()
G = SearchGraph(; dist, db)
ctx = SearchGraphContext(
neighborhood = Neighborhood(SatNeighborhood(); logbase),
hyperparameters_callback = OptimizeParameters(MinRecall(minrecall)),
hints_callback = DisjointHints(disjointbase)
)
# Index construction time, in seconds.
buildtime = @elapsed index!(G, ctx)
# Size in bytes of everything reachable from the index object.
memory = Base.summarysize(G)
n = length(Q)
# Outputs are zero-filled: column i / entry i belongs to query i.
knns = zeros(Int32, k, n)
dists = zeros(Float32, k, n)
searchtime = zeros(Float64, n)
cost = zeros(Int32, n)
# Each iteration writes only to index/column i of the output arrays.
# NOTE(review): `:static` scheduling is presumably required because
# `getknnresult(k, ctx)` hands out per-thread result buffers — confirm.
Threads.@threads :static for i in 1:n
res = getknnresult(k, ctx)
# Time a single query; `p.cost` is the cost reported by `search` for it.
t = @elapsed p = search(G, Q[i], res)
searchtime[i] = t
cost[i] = p.cost
# Only the first `length(res)` rows are written; any remaining rows keep the zero fill.
k_ = length(res)
knns[1:k_, i] .= IdView(res)
dists[1:k_, i] .= DistView(res)
end
# NOTE(review): unlike `solve_queries`, `params` here does not record `disjointbase`.
jldsave(outname; knns, dists, cost, buildtime, searchtime, memory, name="SearchGraph", params="b=$logbase r=$minrecall")
outname
end
"""
    solve_queries(size, qname; k=30, outname=nothing, logbase=1.5, minrecall=0.95, disjointbase=1.01)

Build a `SearchGraph` index over the database file `laion2B-en-clip768v2-n=<size>.h5`,
solve all queries stored in `qname` in a single `searchbatch` call, and save the
results with `jldsave`. Returns the path of the written file.

# Keywords
- `k`: number of neighbors requested from `searchbatch`.
- `outname`: output path; when `nothing`, it is derived with `getresname`.
- `logbase`: forwarded to `Neighborhood`.
- `minrecall`: forwarded to `MinRecall` inside `OptimizeParameters`.
- `disjointbase`: forwarded to `DisjointHints`.

The saved file holds `knns`, `dists`, `buildtime`, `searchtime`, `memory`, `name`
and `params`.
"""
function solve_queries(size::AbstractString, qname::AbstractString; k=30, outname=nothing, logbase=1.5, minrecall=0.95, disjointbase=1.01)
    dbname = "laion2B-en-clip768v2-n=$size.h5"
    if outname === nothing
        outname = getresname(size, qname, dbname, k)
    end
    mkpath(dirname(outname))

    queries = MatrixDatabase(loadf32(qname))
    database = MatrixDatabase(loadf32(dbname))
    metric = NormalizedCosineDistance()
    graph = SearchGraph(; dist=metric, db=database)
    ctx = SearchGraphContext(;
        neighborhood=Neighborhood(SatNeighborhood(); logbase=logbase),
        hyperparameters_callback=OptimizeParameters(MinRecall(minrecall)),
        hints_callback=DisjointHints(disjointbase),
    )

    # Both timings are in seconds; the batch search is timed as a whole.
    buildtime = @elapsed index!(graph, ctx)
    searchtime = @elapsed knns, dists = searchbatch(graph, queries, k)
    memory = Base.summarysize(graph)

    jldsave(
        outname;
        knns=knns,
        dists=dists,
        buildtime=buildtime,
        searchtime=searchtime,
        memory=memory,
        name="SearchGraph",
        params="b=$logbase r=$minrecall disjoint=$disjointbase",
    )
    return outname
end