@@ -7,6 +7,7 @@ package faiss
77#include <faiss/c_api/IndexIVF_c.h>
88#include <faiss/c_api/IndexIVF_c_ex.h>
99#include <faiss/c_api/IndexBinary_c.h>
10+ #include <faiss/c_api/IndexBinaryIVF_c.h>
1011#include <faiss/c_api/index_factory_c.h>
1112#include <faiss/c_api/MetaIndexes_c.h>
1213#include <faiss/c_api/impl/AuxIndexStructures_c.h>
@@ -54,6 +55,14 @@ type BinaryIndex interface {
5455 SearchBinaryWithIDs (x []uint8 , k int64 , include []int64 , params json.RawMessage ) ([]int32 , []int64 , error )
5556 SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 , params json.RawMessage ) (distances []int32 ,
5657 labels []int64 , err error )
58+
59+ ObtainClusterVectorCountsFromIVFIndex (vecIDs []int64 ) (map [int64 ]int64 , error )
60+ ObtainClustersWithDistancesFromIVFIndex (x []uint8 , centroidIDs []int64 ) (
61+ []int64 , []int32 , error )
62+ // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs
63+ SearchClustersFromIVFIndex (selector Selector , eligibleCentroidIDs []int64 ,
64+ minEligibleCentroids int , k int64 , x []uint8 , centroidDis []int32 ,
65+ params json.RawMessage ) ([]int32 , []int64 , error )
5766}
5867
5968// FloatIndex defines methods specific to float-based FAISS indexes
@@ -156,6 +165,61 @@ func (idx *BinaryIndexImpl) Close() {
156165 }
157166}
158167
168+ func (idx * BinaryIndexImpl ) ObtainClusterVectorCountsFromIVFIndex (vecIDs []int64 ) (map [int64 ]int64 , error ) {
169+ if ! idx .IsIVFIndex () {
170+ return nil , fmt .Errorf ("index is not an IVF index" )
171+ }
172+ clusterIDs := make ([]int64 , len (vecIDs ))
173+ if c := C .faiss_get_lists_for_keys_binary (
174+ idx .indexPtr ,
175+ (* C .idx_t )(unsafe .Pointer (& vecIDs [0 ])),
176+ (C .size_t )(len (vecIDs )),
177+ (* C .idx_t )(unsafe .Pointer (& clusterIDs [0 ])),
178+ ); c != 0 {
179+ return nil , getLastError ()
180+ }
181+ rv := make (map [int64 ]int64 , len (vecIDs ))
182+ for _ , v := range clusterIDs {
183+ rv [v ]++
184+ }
185+ return rv , nil
186+ }
187+
188+ func (idx * BinaryIndexImpl ) ObtainClustersWithDistancesFromIVFIndex (x []uint8 , centroidIDs []int64 ) (
189+ []int64 , []int32 , error ) {
190+ // Selector to include only the centroids whose IDs are part of 'centroidIDs'.
191+ includeSelector , err := NewIDSelectorBatch (centroidIDs )
192+ if err != nil {
193+ return nil , nil , err
194+ }
195+ defer includeSelector .Delete ()
196+
197+ params , err := NewSearchParams (idx , json.RawMessage {}, includeSelector .Get (), nil )
198+ if err != nil {
199+ return nil , nil , err
200+ }
201+ defer params .Delete ()
202+
203+ // Populate these with the centroids and their distances.
204+ centroidDistances := make ([]int32 , len (centroidIDs ))
205+
206+ n := len (x ) / idx .D ()
207+
208+ c := C .faiss_Search_closest_eligible_centroids_binary (
209+ idx .indexPtr ,
210+ (C .idx_t )(n ),
211+ (* C .uint8_t )(& x [0 ]),
212+ (C .idx_t )(len (centroidIDs )),
213+ (* C .int32_t )(& centroidDistances [0 ]),
214+ (* C .idx_t )(& centroidIDs [0 ]),
215+ params .sp )
216+ if c != 0 {
217+ return nil , nil , getLastError ()
218+ }
219+
220+ return centroidIDs , centroidDistances , nil
221+ }
222+
159223func (idx * BinaryIndexImpl ) Size () uint64 {
160224 return 0
161225}
@@ -263,7 +327,7 @@ func (idx *BinaryIndexImpl) Train(vectors []uint8) error {
263327}
264328
265329func (idx * BinaryIndexImpl ) SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 , params json.RawMessage ) (distances []int32 , labels []int64 , err error ) {
266- if len (exclude ) == 0 && params == nil {
330+ if len (exclude ) == 0 && len ( params ) == 0 {
267331 return idx .SearchBinary (x , k )
268332 }
269333
@@ -302,6 +366,49 @@ func (idx *BinaryIndexImpl) SearchBinaryWithoutIDs(x []uint8, k int64, exclude [
302366 return distances , labels , err
303367}
304368
369+ func (idx * BinaryIndexImpl ) SearchClustersFromIVFIndex (selector Selector ,
370+ eligibleCentroidIDs []int64 , minEligibleCentroids int , k int64 , x []uint8 ,
371+ centroidDis []int32 , params json.RawMessage ) ([]int32 , []int64 , error ) {
372+ tempParams := & defaultSearchParamsIVF {
373+ Nlist : len (eligibleCentroidIDs ),
374+ // Have to override nprobe so that more clusters will be searched for this
375+ // query, if required.
376+ Nprobe : minEligibleCentroids ,
377+ }
378+
379+ searchParams , err := NewSearchParams (idx , params , selector .Get (), tempParams )
380+ if err != nil {
381+ return nil , nil , err
382+ }
383+ defer searchParams .Delete ()
384+
385+ n := (len (x ) * 8 ) / idx .D ()
386+
387+ distances := make ([]int32 , int64 (n )* k )
388+ labels := make ([]int64 , int64 (n )* k )
389+
390+ effectiveNprobe := getNProbeFromSearchParams (searchParams )
391+
392+ eligibleCentroidIDs = eligibleCentroidIDs [:effectiveNprobe ]
393+ centroidDis = centroidDis [:effectiveNprobe ]
394+
395+ if c := C .faiss_IndexBinaryIVF_search_preassigned_with_params (
396+ idx .indexPtr ,
397+ (C .idx_t )(n ),
398+ (* C .uint8_t )(& x [0 ]),
399+ (C .idx_t )(k ),
400+ (* C .idx_t )(& eligibleCentroidIDs [0 ]),
401+ (* C .int32_t )(& centroidDis [0 ]),
402+ (* C .int32_t )(& distances [0 ]),
403+ (* C .idx_t )(& labels [0 ]),
404+ (C .int )(0 ),
405+ searchParams .sp ); c != 0 {
406+ return nil , nil , getLastError ()
407+ }
408+
409+ return distances , labels , nil
410+ }
411+
305412// Factory functions
306413func IndexBinaryFactory (d int , description string , metric int ) (BinaryIndex , error ) {
307414 return NewBinaryIndexImpl (d , description , metric )
@@ -469,7 +576,7 @@ func (idx *IndexImpl) SearchWithIDs(queries []float32, k int64, include []int64,
469576 }
470577 defer includeSelector .Delete ()
471578
472- searchParams , err := NewSearchParams (nil , params , includeSelector .Get (), nil )
579+ searchParams , err := NewSearchParams (idx , params , includeSelector .Get (), nil )
473580 if err != nil {
474581 return nil , nil , err
475582 }
0 commit comments