@@ -37,6 +37,7 @@ type Index interface {
3737 IsIVFIndex () bool
3838 SetNProbe (nprobe int32 )
3939 GetNProbe () int32
40+ GetNlist () int32
4041 SetDirectMap (directMapType int ) error
4142
4243 Close ()
@@ -63,6 +64,9 @@ type BinaryIndex interface {
6364 SearchClustersFromIVFIndex (selector Selector , eligibleCentroidIDs []int64 ,
6465 minEligibleCentroids int , k int64 , x []uint8 , centroidDis []int32 ,
6566 params json.RawMessage ) ([]int32 , []int64 , error )
67+
68+ BinaryQuantizer () BinaryIndex
69+ SetIsTrained (isTrained bool )
6670}
6771
6872// FloatIndex defines methods specific to float-based FAISS indexes
@@ -89,6 +93,8 @@ type FloatIndex interface {
8993 Reconstruct (key int64 ) (recons []float32 , err error )
9094 ReconstructBatch (ids []int64 , vectors []float32 ) ([]float32 , error )
9195
96+ GetCentroids () ([]float32 , error )
97+
9298 // Applicable only to IVF indexes: Returns a map where the keys
9399 // are cluster IDs and the values represent the count of input vectors that belong
94100 // to each cluster.
@@ -121,6 +127,8 @@ type FloatIndex interface {
121127 // RemoveIDs removes the vectors specified by sel from the index.
122128 // Returns the number of elements removed and error.
123129 RemoveIDs (sel * IDSelector ) (int , error )
130+
131+ Quantizer () * C.FaissIndex
124132}
125133
126134// IndexImpl represents a float vector index
@@ -220,6 +228,50 @@ func (idx *BinaryIndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []uint8, c
220228 return centroidIDs , centroidDistances , nil
221229}
222230
231+ func (idx * IndexImpl ) GetNlist () int32 {
232+ if ivfIdx := C .faiss_IndexIVF_cast (idx .cPtrFloat ()); ivfIdx != nil {
233+ return C .faiss_IndexIVF_nlist (ivfIdx )
234+ }
235+ return 0
236+ }
237+
238+ func (idx * IndexImpl ) GetCentroids () ([]float32 , error ) {
239+ if ivfIdx := C .faiss_IndexIVF_cast (idx .cPtrFloat ()); ivfIdx != nil {
240+ ivfCentroids := make ([]float32 , idx .D ()* 2000 )
241+ C .faiss_IndexIVF_get_centroids (ivfIdx , (* C .float )(& ivfCentroids [0 ]))
242+ return ivfCentroids , nil
243+ }
244+ return nil , fmt .Errorf ("index is not an IVF index" )
245+ }
246+
247+ func (idx * IndexImpl ) Quantizer () * C.FaissIndex {
248+ if ivfIdx := C .faiss_IndexIVF_cast (idx .cPtrFloat ()); ivfIdx != nil {
249+ return C .faiss_IndexIVF_quantizer (ivfIdx )
250+ }
251+ return nil
252+ }
253+
254+ func (idx * BinaryIndexImpl ) SetIsTrained (isTrained bool ) {
255+ if isTrained {
256+ C .faiss_IndexBinaryIVF_set_is_trained ((* C .FaissIndexBinaryIVF )(idx .cPtrBinary ()),
257+ C .int (1 ))
258+ } else {
259+ C .faiss_IndexBinaryIVF_set_is_trained ((* C .FaissIndexBinaryIVF )(idx .cPtrBinary ()),
260+ C .int (0 ))
261+ }
262+ }
263+
264+ func (idx * BinaryIndexImpl ) BinaryQuantizer () BinaryIndex {
265+ if bivfIdx := C .faiss_IndexBinaryIVF_cast (idx .cPtrBinary ()); bivfIdx != nil {
266+ return & BinaryIndexImpl {
267+ indexPtr : C .faiss_IndexBinaryIVF_quantizer (bivfIdx ),
268+ d : idx .d ,
269+ metric : idx .metric ,
270+ }
271+ }
272+ return nil
273+ }
274+
223275func (idx * BinaryIndexImpl ) Size () uint64 {
224276 return 0
225277}
@@ -244,6 +296,13 @@ func (idx *BinaryIndexImpl) IsIVFIndex() bool {
244296 return C .faiss_IndexBinaryIVF_cast (idx .indexPtr ) != nil
245297}
246298
299+ func (idx * BinaryIndexImpl ) GetNlist () int32 {
300+ if ivfIdx := C .faiss_IndexBinaryIVF_cast (idx .indexPtr ); ivfIdx != nil {
301+ return C .faiss_IndexBinaryIVF_nlist (ivfIdx )
302+ }
303+ return 0
304+ }
305+
247306// Binary-specific operations
248307func (idx * BinaryIndexImpl ) TrainBinary (vectors []uint8 ) error {
249308 n := (len (vectors ) * 8 ) / idx .d
@@ -387,25 +446,38 @@ func (idx *BinaryIndexImpl) SearchClustersFromIVFIndex(selector Selector,
387446 distances := make ([]int32 , int64 (n )* k )
388447 labels := make ([]int64 , int64 (n )* k )
389448
390- effectiveNprobe := getNProbeFromSearchParams (searchParams )
391-
392- eligibleCentroidIDs = eligibleCentroidIDs [:effectiveNprobe ]
393- centroidDis = centroidDis [:effectiveNprobe ]
394-
395- if c := C .faiss_IndexBinaryIVF_search_preassigned_with_params (
449+ if c := C .faiss_IndexBinary_search (
396450 idx .indexPtr ,
397- ( C .idx_t ) (n ),
451+ C .idx_t (n ),
398452 (* C .uint8_t )(& x [0 ]),
399- (C .idx_t )(k ),
400- (* C .idx_t )(& eligibleCentroidIDs [0 ]),
401- (* C .int32_t )(& centroidDis [0 ]),
453+ C .idx_t (k ),
402454 (* C .int32_t )(& distances [0 ]),
403455 (* C .idx_t )(& labels [0 ]),
404- (C .int )(0 ),
405- searchParams .sp ); c != 0 {
456+ ); c != 0 {
406457 return nil , nil , getLastError ()
407458 }
408459
460+ /*
461+ effectiveNprobe := getNProbeFromSearchParams(searchParams)
462+
463+ eligibleCentroidIDs = eligibleCentroidIDs[:effectiveNprobe]
464+ centroidDis = centroidDis[:effectiveNprobe]
465+
466+ if c := C.faiss_IndexBinaryIVF_search_preassigned_with_params(
467+ idx.indexPtr,
468+ (C.idx_t)(n),
469+ (*C.uint8_t)(&x[0]),
470+ (C.idx_t)(k),
471+ (*C.idx_t)(&eligibleCentroidIDs[0]),
472+ (*C.int32_t)(¢roidDis[0]),
473+ (*C.int32_t)(&distances[0]),
474+ (*C.idx_t)(&labels[0]),
475+ (C.int)(0),
476+ searchParams.sp); c != 0 {
477+ return nil, nil, getLastError()
478+ }
479+ */
480+
409481 return distances , labels , nil
410482}
411483
0 commit comments