Skip to content

Commit 793dcfe

Browse files
utils to re-use centroids from ivf index
1 parent b3fff7e commit 793dcfe

File tree

1 file changed

+84
-12
lines changed

1 file changed

+84
-12
lines changed

index.go

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type Index interface {
3737
IsIVFIndex() bool
3838
SetNProbe(nprobe int32)
3939
GetNProbe() int32
40+
GetNlist() int32
4041
SetDirectMap(directMapType int) error
4142

4243
Close()
@@ -63,6 +64,9 @@ type BinaryIndex interface {
6364
SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64,
6465
minEligibleCentroids int, k int64, x []uint8, centroidDis []int32,
6566
params json.RawMessage) ([]int32, []int64, error)
67+
68+
BinaryQuantizer() BinaryIndex
69+
SetIsTrained(isTrained bool)
6670
}
6771

6872
// FloatIndex defines methods specific to float-based FAISS indexes
@@ -89,6 +93,8 @@ type FloatIndex interface {
8993
Reconstruct(key int64) (recons []float32, err error)
9094
ReconstructBatch(ids []int64, vectors []float32) ([]float32, error)
9195

96+
GetCentroids() ([]float32, error)
97+
9298
// Applicable only to IVF indexes: Returns a map where the keys
9399
// are cluster IDs and the values represent the count of input vectors that belong
94100
// to each cluster.
@@ -121,6 +127,8 @@ type FloatIndex interface {
121127
// RemoveIDs removes the vectors specified by sel from the index.
122128
// Returns the number of elements removed and error.
123129
RemoveIDs(sel *IDSelector) (int, error)
130+
131+
Quantizer() *C.FaissIndex
124132
}
125133

126134
// IndexImpl represents a float vector index
@@ -220,6 +228,50 @@ func (idx *BinaryIndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []uint8, c
220228
return centroidIDs, centroidDistances, nil
221229
}
222230

231+
func (idx *IndexImpl) GetNlist() int32 {
232+
if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil {
233+
return C.faiss_IndexIVF_nlist(ivfIdx)
234+
}
235+
return 0
236+
}
237+
238+
func (idx *IndexImpl) GetCentroids() ([]float32, error) {
239+
if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil {
240+
ivfCentroids := make([]float32, idx.D()*2000)
241+
C.faiss_IndexIVF_get_centroids(ivfIdx, (*C.float)(&ivfCentroids[0]))
242+
return ivfCentroids, nil
243+
}
244+
return nil, fmt.Errorf("index is not an IVF index")
245+
}
246+
247+
func (idx *IndexImpl) Quantizer() *C.FaissIndex {
248+
if ivfIdx := C.faiss_IndexIVF_cast(idx.cPtrFloat()); ivfIdx != nil {
249+
return C.faiss_IndexIVF_quantizer(ivfIdx)
250+
}
251+
return nil
252+
}
253+
254+
func (idx *BinaryIndexImpl) SetIsTrained(isTrained bool) {
255+
if isTrained {
256+
C.faiss_IndexBinaryIVF_set_is_trained((*C.FaissIndexBinaryIVF)(idx.cPtrBinary()),
257+
C.int(1))
258+
} else {
259+
C.faiss_IndexBinaryIVF_set_is_trained((*C.FaissIndexBinaryIVF)(idx.cPtrBinary()),
260+
C.int(0))
261+
}
262+
}
263+
264+
func (idx *BinaryIndexImpl) BinaryQuantizer() BinaryIndex {
265+
if bivfIdx := C.faiss_IndexBinaryIVF_cast(idx.cPtrBinary()); bivfIdx != nil {
266+
return &BinaryIndexImpl{
267+
indexPtr: C.faiss_IndexBinaryIVF_quantizer(bivfIdx),
268+
d: idx.d,
269+
metric: idx.metric,
270+
}
271+
}
272+
return nil
273+
}
274+
223275
func (idx *BinaryIndexImpl) Size() uint64 {
224276
return 0
225277
}
@@ -244,6 +296,13 @@ func (idx *BinaryIndexImpl) IsIVFIndex() bool {
244296
return C.faiss_IndexBinaryIVF_cast(idx.indexPtr) != nil
245297
}
246298

299+
func (idx *BinaryIndexImpl) GetNlist() int32 {
300+
if ivfIdx := C.faiss_IndexBinaryIVF_cast(idx.indexPtr); ivfIdx != nil {
301+
return C.faiss_IndexBinaryIVF_nlist(ivfIdx)
302+
}
303+
return 0
304+
}
305+
247306
// Binary-specific operations
248307
func (idx *BinaryIndexImpl) TrainBinary(vectors []uint8) error {
249308
n := (len(vectors) * 8) / idx.d
@@ -387,25 +446,38 @@ func (idx *BinaryIndexImpl) SearchClustersFromIVFIndex(selector Selector,
387446
distances := make([]int32, int64(n)*k)
388447
labels := make([]int64, int64(n)*k)
389448

390-
effectiveNprobe := getNProbeFromSearchParams(searchParams)
391-
392-
eligibleCentroidIDs = eligibleCentroidIDs[:effectiveNprobe]
393-
centroidDis = centroidDis[:effectiveNprobe]
394-
395-
if c := C.faiss_IndexBinaryIVF_search_preassigned_with_params(
449+
if c := C.faiss_IndexBinary_search(
396450
idx.indexPtr,
397-
(C.idx_t)(n),
451+
C.idx_t(n),
398452
(*C.uint8_t)(&x[0]),
399-
(C.idx_t)(k),
400-
(*C.idx_t)(&eligibleCentroidIDs[0]),
401-
(*C.int32_t)(&centroidDis[0]),
453+
C.idx_t(k),
402454
(*C.int32_t)(&distances[0]),
403455
(*C.idx_t)(&labels[0]),
404-
(C.int)(0),
405-
searchParams.sp); c != 0 {
456+
); c != 0 {
406457
return nil, nil, getLastError()
407458
}
408459

460+
/*
461+
effectiveNprobe := getNProbeFromSearchParams(searchParams)
462+
463+
eligibleCentroidIDs = eligibleCentroidIDs[:effectiveNprobe]
464+
centroidDis = centroidDis[:effectiveNprobe]
465+
466+
if c := C.faiss_IndexBinaryIVF_search_preassigned_with_params(
467+
idx.indexPtr,
468+
(C.idx_t)(n),
469+
(*C.uint8_t)(&x[0]),
470+
(C.idx_t)(k),
471+
(*C.idx_t)(&eligibleCentroidIDs[0]),
472+
(*C.int32_t)(&centroidDis[0]),
473+
(*C.int32_t)(&distances[0]),
474+
(*C.idx_t)(&labels[0]),
475+
(C.int)(0),
476+
searchParams.sp); c != 0 {
477+
return nil, nil, getLastError()
478+
}
479+
*/
480+
409481
return distances, labels, nil
410482
}
411483

0 commit comments

Comments
 (0)