Skip to content

Commit b3fff7e

Browse files
bivf pre-filtering utils
1 parent bfb5436 commit b3fff7e

File tree

3 files changed

+119
-2
lines changed

3 files changed

+119
-2
lines changed

index.go

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package faiss
77
#include <faiss/c_api/IndexIVF_c.h>
88
#include <faiss/c_api/IndexIVF_c_ex.h>
99
#include <faiss/c_api/IndexBinary_c.h>
10+
#include <faiss/c_api/IndexBinaryIVF_c.h>
1011
#include <faiss/c_api/index_factory_c.h>
1112
#include <faiss/c_api/MetaIndexes_c.h>
1213
#include <faiss/c_api/impl/AuxIndexStructures_c.h>
@@ -54,6 +55,14 @@ type BinaryIndex interface {
5455
SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) ([]int32, []int64, error)
5556
SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32,
5657
labels []int64, err error)
58+
59+
ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error)
60+
ObtainClustersWithDistancesFromIVFIndex(x []uint8, centroidIDs []int64) (
61+
[]int64, []int32, error)
62+
// Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs
63+
SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64,
64+
minEligibleCentroids int, k int64, x []uint8, centroidDis []int32,
65+
params json.RawMessage) ([]int32, []int64, error)
5766
}
5867

5968
// FloatIndex defines methods specific to float-based FAISS indexes
@@ -156,6 +165,61 @@ func (idx *BinaryIndexImpl) Close() {
156165
}
157166
}
158167

168+
func (idx *BinaryIndexImpl) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) {
169+
if !idx.IsIVFIndex() {
170+
return nil, fmt.Errorf("index is not an IVF index")
171+
}
172+
clusterIDs := make([]int64, len(vecIDs))
173+
if c := C.faiss_get_lists_for_keys_binary(
174+
idx.indexPtr,
175+
(*C.idx_t)(unsafe.Pointer(&vecIDs[0])),
176+
(C.size_t)(len(vecIDs)),
177+
(*C.idx_t)(unsafe.Pointer(&clusterIDs[0])),
178+
); c != 0 {
179+
return nil, getLastError()
180+
}
181+
rv := make(map[int64]int64, len(vecIDs))
182+
for _, v := range clusterIDs {
183+
rv[v]++
184+
}
185+
return rv, nil
186+
}
187+
188+
func (idx *BinaryIndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []uint8, centroidIDs []int64) (
189+
[]int64, []int32, error) {
190+
// Selector to include only the centroids whose IDs are part of 'centroidIDs'.
191+
includeSelector, err := NewIDSelectorBatch(centroidIDs)
192+
if err != nil {
193+
return nil, nil, err
194+
}
195+
defer includeSelector.Delete()
196+
197+
params, err := NewSearchParams(idx, json.RawMessage{}, includeSelector.Get(), nil)
198+
if err != nil {
199+
return nil, nil, err
200+
}
201+
defer params.Delete()
202+
203+
// Populate these with the centroids and their distances.
204+
centroidDistances := make([]int32, len(centroidIDs))
205+
206+
n := len(x) / idx.D()
207+
208+
c := C.faiss_Search_closest_eligible_centroids_binary(
209+
idx.indexPtr,
210+
(C.idx_t)(n),
211+
(*C.uint8_t)(&x[0]),
212+
(C.idx_t)(len(centroidIDs)),
213+
(*C.int32_t)(&centroidDistances[0]),
214+
(*C.idx_t)(&centroidIDs[0]),
215+
params.sp)
216+
if c != 0 {
217+
return nil, nil, getLastError()
218+
}
219+
220+
return centroidIDs, centroidDistances, nil
221+
}
222+
159223
func (idx *BinaryIndexImpl) Size() uint64 {
160224
return 0
161225
}
@@ -263,7 +327,7 @@ func (idx *BinaryIndexImpl) Train(vectors []uint8) error {
263327
}
264328

265329
func (idx *BinaryIndexImpl) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, labels []int64, err error) {
266-
if len(exclude) == 0 && params == nil {
330+
if len(exclude) == 0 && len(params) == 0 {
267331
return idx.SearchBinary(x, k)
268332
}
269333

@@ -302,6 +366,49 @@ func (idx *BinaryIndexImpl) SearchBinaryWithoutIDs(x []uint8, k int64, exclude [
302366
return distances, labels, err
303367
}
304368

369+
func (idx *BinaryIndexImpl) SearchClustersFromIVFIndex(selector Selector,
370+
eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x []uint8,
371+
centroidDis []int32, params json.RawMessage) ([]int32, []int64, error) {
372+
tempParams := &defaultSearchParamsIVF{
373+
Nlist: len(eligibleCentroidIDs),
374+
// Have to override nprobe so that more clusters will be searched for this
375+
// query, if required.
376+
Nprobe: minEligibleCentroids,
377+
}
378+
379+
searchParams, err := NewSearchParams(idx, params, selector.Get(), tempParams)
380+
if err != nil {
381+
return nil, nil, err
382+
}
383+
defer searchParams.Delete()
384+
385+
n := (len(x) * 8) / idx.D()
386+
387+
distances := make([]int32, int64(n)*k)
388+
labels := make([]int64, int64(n)*k)
389+
390+
effectiveNprobe := getNProbeFromSearchParams(searchParams)
391+
392+
eligibleCentroidIDs = eligibleCentroidIDs[:effectiveNprobe]
393+
centroidDis = centroidDis[:effectiveNprobe]
394+
395+
if c := C.faiss_IndexBinaryIVF_search_preassigned_with_params(
396+
idx.indexPtr,
397+
(C.idx_t)(n),
398+
(*C.uint8_t)(&x[0]),
399+
(C.idx_t)(k),
400+
(*C.idx_t)(&eligibleCentroidIDs[0]),
401+
(*C.int32_t)(&centroidDis[0]),
402+
(*C.int32_t)(&distances[0]),
403+
(*C.idx_t)(&labels[0]),
404+
(C.int)(0),
405+
searchParams.sp); c != 0 {
406+
return nil, nil, getLastError()
407+
}
408+
409+
return distances, labels, nil
410+
}
411+
305412
// Factory functions
306413
func IndexBinaryFactory(d int, description string, metric int) (BinaryIndex, error) {
307414
return NewBinaryIndexImpl(d, description, metric)
@@ -469,7 +576,7 @@ func (idx *IndexImpl) SearchWithIDs(queries []float32, k int64, include []int64,
469576
}
470577
defer includeSelector.Delete()
471578

472-
searchParams, err := NewSearchParams(nil, params, includeSelector.Get(), nil)
579+
searchParams, err := NewSearchParams(idx, params, includeSelector.Get(), nil)
473580
if err != nil {
474581
return nil, nil, err
475582
}

index_ivf.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package faiss
66
#include <faiss/c_api/Index_c.h>
77
#include <faiss/c_api/IndexIVF_c.h>
88
#include <faiss/c_api/IndexBinary_c.h>
9+
#include <faiss/c_api/IndexBinaryIVF_c.h>
910
#include <faiss/c_api/IndexIVF_c_ex.h>
1011
*/
1112
import "C"

search_params.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector,
7171
return rv, nil
7272
}
7373

74+
if !idx.IsIVFIndex() {
75+
c := C.faiss_SearchParameters_new_with_selector(&rv.sp, sel)
76+
if c != 0 {
77+
rv.Delete()
78+
return nil, fmt.Errorf("failed to create faiss search params")
79+
}
80+
return rv, nil
81+
}
82+
7483
var nlist, nprobe, nvecs, maxCodes int
7584
var ivfParams searchParamsIVF
7685

0 commit comments

Comments
 (0)