diff --git a/index.go b/index.go index 3a399e5..01e2ab6 100644 --- a/index.go +++ b/index.go @@ -6,15 +6,14 @@ package faiss #include #include #include +#include #include #include #include */ import "C" import ( - "encoding/json" "fmt" - "sort" "unsafe" ) @@ -36,72 +35,15 @@ type Index interface { // MetricType returns the metric type of the index. MetricType() int - // Train trains the index on a representative set of vectors. - Train(x []float32) error - - // Add adds vectors to the index. - Add(x []float32) error - - // AddWithIDs is like Add, but stores xids instead of sequential IDs. - AddWithIDs(x []float32, xids []int64) error - // Returns true if the index is an IVF index. IsIVFIndex() bool - // Applicable only to IVF indexes: Returns a map where the keys - // are cluster IDs and the values represent the count of input vectors that belong - // to each cluster. - // This method only considers the given vecIDs and does not account for all - // vectors in the index. - // Example: - // If vecIDs = [1, 2, 3, 4, 5], and: - // - Vectors 1 and 2 belong to cluster 1 - // - Vectors 3, 4, and 5 belong to cluster 2 - // The output will be: map[1:2, 2:3] - ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) - - // Applicable only to IVF indexes: Returns the centroid IDs in decreasing order - // of proximity to query 'x' and their distance from 'x' - ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) ( - []int64, []float32, error) - - // Applicable only to IVF indexes: Returns the top k centroid cardinalities and - // their vectors in chosen order (descending or ascending) - ObtainKCentroidCardinalitiesFromIVFIndex(limit int, descending bool) ([]uint64, [][]float32, error) - - // Search queries the index with the vectors in x. - // Returns the IDs of the k nearest neighbors for each query vector and the - // corresponding distances. 
- Search(x []float32, k int64) (distances []float32, labels []int64, err error) - - SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) (distances []float32, - labels []int64, err error) - - SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32, - labels []int64, err error) - - // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs - SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, - minEligibleCentroids int, k int64, x, centroidDis []float32, - params json.RawMessage) ([]float32, []int64, error) - - Reconstruct(key int64) ([]float32, error) - - ReconstructBatch(keys []int64, recons []float32) ([]float32, error) - + // MergeFrom merges another index into this index. MergeFrom(other Index, add_id int64) error - // RangeSearch queries the index with the vectors in x. - // Returns all vectors with distance < radius. - RangeSearch(x []float32, radius float32) (*RangeSearchResult, error) - // Reset removes all vectors from the index. Reset() error - // RemoveIDs removes the vectors specified by sel from the index. - // Returns the number of elements removed and error. - RemoveIDs(sel *IDSelector) (int, error) - // Close frees the memory used by the index. 
Close() @@ -111,340 +53,42 @@ type Index interface { cPtr() *C.FaissIndex } -type faissIndex struct { - idx *C.FaissIndex -} +type IndexType int -func (idx *faissIndex) cPtr() *C.FaissIndex { - return idx.idx -} +const ( + FloatIndexType IndexType = iota + BinaryIndexType +) -func (idx *faissIndex) Size() uint64 { - size := C.faiss_Index_size(idx.idx) - return uint64(size) +type indexImpl struct { + idx *C.FaissIndex } -func (idx *faissIndex) D() int { +func (idx *indexImpl) D() int { return int(C.faiss_Index_d(idx.idx)) } -func (idx *faissIndex) IsTrained() bool { +func (idx *indexImpl) IsTrained() bool { return C.faiss_Index_is_trained(idx.idx) != 0 } -func (idx *faissIndex) Ntotal() int64 { +func (idx *indexImpl) Ntotal() int64 { return int64(C.faiss_Index_ntotal(idx.idx)) } -func (idx *faissIndex) MetricType() int { +func (idx *indexImpl) MetricType() int { return int(C.faiss_Index_metric_type(idx.idx)) } -func (idx *faissIndex) Train(x []float32) error { - n := len(x) / idx.D() - if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 { - return getLastError() - } - return nil -} - -func (idx *faissIndex) Add(x []float32) error { - n := len(x) / idx.D() - if c := C.faiss_Index_add(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 { - return getLastError() - } - return nil -} - -func (idx *faissIndex) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) { - if !idx.IsIVFIndex() { - return nil, fmt.Errorf("index is not an IVF index") - } - clusterIDs := make([]int64, len(vecIDs)) - if c := C.faiss_get_lists_for_keys( - idx.idx, - (*C.idx_t)(unsafe.Pointer(&vecIDs[0])), - (C.size_t)(len(vecIDs)), - (*C.idx_t)(unsafe.Pointer(&clusterIDs[0])), - ); c != 0 { - return nil, getLastError() - } - rv := make(map[int64]int64, len(vecIDs)) - for _, v := range clusterIDs { - rv[v]++ - } - return rv, nil -} - -func (idx *faissIndex) IsIVFIndex() bool { +func (idx *indexImpl) IsIVFIndex() bool { if ivfIdx := 
C.faiss_IndexIVF_cast(idx.cPtr()); ivfIdx == nil { return false } return true } -func (idx *faissIndex) ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) ( - []int64, []float32, error) { - // Selector to include only the centroids whose IDs are part of 'centroidIDs'. - includeSelector, err := NewIDSelectorBatch(centroidIDs) - if err != nil { - return nil, nil, err - } - defer includeSelector.Delete() - - params, err := NewSearchParams(idx, json.RawMessage{}, includeSelector.Get(), nil) - if err != nil { - return nil, nil, err - } - defer params.Delete() - - // Populate these with the centroids and their distances. - centroids := make([]int64, len(centroidIDs)) - centroidDistances := make([]float32, len(centroidIDs)) - - n := len(x) / idx.D() - - c := C.faiss_Search_closest_eligible_centroids( - idx.idx, - (C.idx_t)(n), - (*C.float)(&x[0]), - (C.idx_t)(len(centroidIDs)), - (*C.float)(&centroidDistances[0]), - (*C.idx_t)(&centroids[0]), - params.sp) - if c != 0 { - return nil, nil, getLastError() - } - - return centroids, centroidDistances, nil -} - -func (idx *faissIndex) ObtainKCentroidCardinalitiesFromIVFIndex(limit int, descending bool) ( - []uint64, [][]float32, error) { - if limit <= 0 { - return nil, nil, nil - } - - nlist := int(C.faiss_IndexIVF_nlist(idx.idx)) - if nlist == 0 { - return nil, nil, nil - } - - centroidCardinalities := make([]C.size_t, nlist) - - // Allocate a flat buffer for all centroids, then slice it per centroid - d := idx.D() - flatCentroids := make([]float32, nlist*d) - - // Call the C function to fill centroid vectors and cardinalities - c := C.faiss_IndexIVF_get_centroids_and_cardinality( - idx.idx, - (*C.float)(&flatCentroids[0]), - (*C.size_t)(&centroidCardinalities[0]), - nil, - ) - if c != 0 { - return nil, nil, getLastError() - } - - topIndices := getIndicesOfKCentroidCardinalities( - centroidCardinalities, - min(limit, nlist), - descending) - - rvCardinalities := make([]uint64, len(topIndices)) - rvCentroids := 
make([][]float32, len(topIndices)) - - for i, idx := range topIndices { - rvCardinalities[i] = uint64(centroidCardinalities[idx]) - rvCentroids[i] = flatCentroids[idx*d : (idx+1)*d] - } - - return rvCardinalities, rvCentroids, nil - -} - -func getIndicesOfKCentroidCardinalities(cardinalities []C.size_t, k int, descending bool) []int { - n := len(cardinalities) - indices := make([]int, n) - for i := range indices { - indices[i] = i - } - - // Sort only the indices based on cardinality values - sort.Slice(indices, func(i, j int) bool { - if descending { - return cardinalities[indices[i]] > cardinalities[indices[j]] - } - return cardinalities[indices[i]] < cardinalities[indices[j]] - }) - if k >= n { - return indices - } - - return indices[:k] -} - -func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, - eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x, - centroidDis []float32, params json.RawMessage) ([]float32, []int64, error) { - - tempParams := &defaultSearchParamsIVF{ - Nlist: len(eligibleCentroidIDs), - // Have to override nprobe so that more clusters will be searched for this - // query, if required. 
- Nprobe: minEligibleCentroids, - } - - searchParams, err := NewSearchParams(idx, params, selector.Get(), tempParams) - if err != nil { - return nil, nil, err - } - defer searchParams.Delete() - - n := len(x) / idx.D() - - distances := make([]float32, int64(n)*k) - labels := make([]int64, int64(n)*k) - - effectiveNprobe := getNProbeFromSearchParams(searchParams) - eligibleCentroidIDs = eligibleCentroidIDs[:effectiveNprobe] - centroidDis = centroidDis[:effectiveNprobe] - - if c := C.faiss_IndexIVF_search_preassigned_with_params( - idx.idx, - (C.idx_t)(n), - (*C.float)(&x[0]), - (C.idx_t)(k), - (*C.idx_t)(&eligibleCentroidIDs[0]), - (*C.float)(&centroidDis[0]), - (*C.float)(&distances[0]), - (*C.idx_t)(&labels[0]), - (C.int)(0), - searchParams.sp); c != 0 { - return nil, nil, getLastError() - } - - return distances, labels, nil -} - -func (idx *faissIndex) AddWithIDs(x []float32, xids []int64) error { - n := len(x) / idx.D() - if c := C.faiss_Index_add_with_ids( - idx.idx, - C.idx_t(n), - (*C.float)(&x[0]), - (*C.idx_t)(&xids[0]), - ); c != 0 { - return getLastError() - } - return nil -} - -func (idx *faissIndex) Search(x []float32, k int64) ( - distances []float32, labels []int64, err error, -) { - n := len(x) / idx.D() - distances = make([]float32, int64(n)*k) - labels = make([]int64, int64(n)*k) - if c := C.faiss_Index_search( - idx.idx, - C.idx_t(n), - (*C.float)(&x[0]), - C.idx_t(k), - (*C.float)(&distances[0]), - (*C.idx_t)(&labels[0]), - ); c != 0 { - err = getLastError() - } - - return -} - -func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) ( - distances []float32, labels []int64, err error, -) { - if params == nil && len(exclude) == 0 { - return idx.Search(x, k) - } - - var selector *C.FaissIDSelector - if len(exclude) > 0 { - excludeSelector, err := NewIDSelectorNot(exclude) - if err != nil { - return nil, nil, err - } - selector = excludeSelector.Get() - defer excludeSelector.Delete() - } - - searchParams, 
err := NewSearchParams(idx, params, selector, nil) - if err != nil { - return nil, nil, err - } - defer searchParams.Delete() - - distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) - - return -} - -func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64, - params json.RawMessage) (distances []float32, labels []int64, err error, -) { - includeSelector, err := NewIDSelectorBatch(include) - if err != nil { - return nil, nil, err - } - defer includeSelector.Delete() - - searchParams, err := NewSearchParams(idx, params, includeSelector.Get(), nil) - if err != nil { - return nil, nil, err - } - defer searchParams.Delete() - - distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) - return -} - -func (idx *faissIndex) Reconstruct(key int64) (recons []float32, err error) { - rv := make([]float32, idx.D()) - if c := C.faiss_Index_reconstruct( - idx.idx, - C.idx_t(key), - (*C.float)(&rv[0]), - ); c != 0 { - err = getLastError() - } - - return rv, err -} - -func (idx *faissIndex) ReconstructBatch(keys []int64, recons []float32) ([]float32, error) { - var err error - n := int64(len(keys)) - if c := C.faiss_Index_reconstruct_batch( - idx.idx, - C.idx_t(n), - (*C.idx_t)(&keys[0]), - (*C.float)(&recons[0]), - ); c != 0 { - err = getLastError() - } - - return recons, err -} - -func (i *IndexImpl) MergeFrom(other Index, add_id int64) error { - if impl, ok := other.(*IndexImpl); ok { - return i.Index.MergeFrom(impl.Index, add_id) - } - return fmt.Errorf("merge not support") -} - -func (idx *faissIndex) MergeFrom(other Index, add_id int64) (err error) { - otherIdx, ok := other.(*faissIndex) +func (idx *indexImpl) MergeFrom(other Index, add_id int64) (err error) { + otherIdx, ok := other.(*indexImpl) if !ok { return fmt.Errorf("merge api not supported") } @@ -460,122 +104,57 @@ func (idx *faissIndex) MergeFrom(other Index, add_id int64) (err error) { return err } -func (idx *faissIndex) RangeSearch(x []float32, radius float32) ( - 
*RangeSearchResult, error, -) { - n := len(x) / idx.D() - var rsr *C.FaissRangeSearchResult - if c := C.faiss_RangeSearchResult_new(&rsr, C.idx_t(n)); c != 0 { - return nil, getLastError() - } - if c := C.faiss_Index_range_search( - idx.idx, - C.idx_t(n), - (*C.float)(&x[0]), - C.float(radius), - rsr, - ); c != 0 { - return nil, getLastError() - } - return &RangeSearchResult{rsr}, nil -} - -func (idx *faissIndex) Reset() error { +func (idx *indexImpl) Reset() error { if c := C.faiss_Index_reset(idx.idx); c != 0 { return getLastError() } return nil } -func (idx *faissIndex) RemoveIDs(sel *IDSelector) (int, error) { - var nRemoved C.size_t - if c := C.faiss_Index_remove_ids(idx.idx, sel.sel, &nRemoved); c != 0 { - return 0, getLastError() - } - return int(nRemoved), nil -} - -func (idx *faissIndex) Close() { +func (idx *indexImpl) Close() { C.faiss_Index_free(idx.idx) } -func (idx *faissIndex) searchWithParams(x []float32, k int64, searchParams *C.FaissSearchParameters) ( - distances []float32, labels []int64, err error, -) { - n := len(x) / idx.D() - distances = make([]float32, int64(n)*k) - labels = make([]int64, int64(n)*k) - - if c := C.faiss_Index_search_with_params( - idx.idx, - C.idx_t(n), - (*C.float)(&x[0]), - C.idx_t(k), - searchParams, - (*C.float)(&distances[0]), - (*C.idx_t)(&labels[0]), - ); c != 0 { - err = getLastError() - } - - return -} - -// ----------------------------------------------------------------------------- - -// RangeSearchResult is the result of a range search. -type RangeSearchResult struct { - rsr *C.FaissRangeSearchResult -} - -// Nq returns the number of queries. -func (r *RangeSearchResult) Nq() int { - return int(C.faiss_RangeSearchResult_nq(r.rsr)) -} - -// Lims returns a slice containing start and end indices for queries in the -// distances and labels slices returned by Labels. 
-func (r *RangeSearchResult) Lims() []int { - var lims *C.size_t - C.faiss_RangeSearchResult_lims(r.rsr, &lims) - length := r.Nq() + 1 - return (*[1 << 30]int)(unsafe.Pointer(lims))[:length:length] -} - -// Labels returns the unsorted IDs and respective distances for each query. -// The result for query i is labels[lims[i]:lims[i+1]]. -func (r *RangeSearchResult) Labels() (labels []int64, distances []float32) { - lims := r.Lims() - length := lims[len(lims)-1] - var clabels *C.idx_t - var cdist *C.float - C.faiss_RangeSearchResult_labels(r.rsr, &clabels, &cdist) - labels = (*[1 << 30]int64)(unsafe.Pointer(clabels))[:length:length] - distances = (*[1 << 30]float32)(unsafe.Pointer(cdist))[:length:length] - return +func (idx *indexImpl) Size() uint64 { + size := C.faiss_Index_size(idx.idx) + return uint64(size) } -// Delete frees the memory associated with r. -func (r *RangeSearchResult) Delete() { - C.faiss_RangeSearchResult_free(r.rsr) +func (idx *indexImpl) cPtr() *C.FaissIndex { + return idx.idx } -// IndexImpl is an abstract structure for an index. -type IndexImpl struct { - Index -} +// ----------------------------------------------------------------------------- // IndexFactory builds a composite index. // description is a comma-separated list of components. 
-func IndexFactory(d int, description string, metric int) (*IndexImpl, error) { - cdesc := C.CString(description) - defer C.free(unsafe.Pointer(cdesc)) - var idx faissIndex - c := C.faiss_index_factory(&idx.idx, C.int(d), cdesc, C.FaissMetricType(metric)) - if c != 0 { - return nil, getLastError() +func IndexFactory(d int, description string, metric int, indexType IndexType) (Index, error) { + + var cDescription *C.char + if description != "" { + cDescription = C.CString(description) + defer C.free(unsafe.Pointer(cDescription)) + } + + var rv Index + switch indexType { + case FloatIndexType: + var idx floatIndexImpl + c := C.faiss_index_factory(&idx.idx, C.int(d), cDescription, C.FaissMetricType(metric)) + if c != 0 { + return nil, getLastError() + } + rv = &idx + case BinaryIndexType: + var idx binaryIndexImpl + if c := C.faiss_index_binary_factory(&idx.bIdx, C.int(d), cDescription); c != 0 { + return nil, getLastError() + } + idx.idx = idx.castIndex() + rv = &idx } - return &IndexImpl{&idx}, nil + + return rv, nil } func SetOMPThreads(n uint) { diff --git a/index_binary.go b/index_binary.go new file mode 100644 index 0000000..6d7a68b --- /dev/null +++ b/index_binary.go @@ -0,0 +1,180 @@ +package faiss + +/* +#include +#include +#include +#include +#include +*/ +import "C" +import ( + "encoding/json" + "fmt" + "unsafe" +) + +type BinaryIndex interface { + Index + + bPtr() *C.FaissIndexBinary + + SetDirectMap(mapType int) error + SetNProbe(nprobe int32) + + Train(x []uint8) error + AddWithIDs(x []uint8, ids []int64) error + SearchBinary(x []uint8, k int64) ([]int32, []int64, error) + SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) ( + []int32, []int64, error) + SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, + params json.RawMessage) ([]int32, []int64, error) +} + +type binaryIndexImpl struct { + indexImpl + bIdx *C.FaissIndexBinary +} + +func (idx *binaryIndexImpl) bPtr() *C.FaissIndexBinary { + return idx.bIdx +} + 
+func (idx *binaryIndexImpl) SetDirectMap(mapType int) (err error) { + ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.bPtr()) + // If we have a binary IVF index + if ivfPtrBinary != nil { + if c := C.faiss_IndexBinaryIVF_set_direct_map( + ivfPtrBinary, + C.int(mapType), + ); c != 0 { + err = getLastError() + } + return err + } + + return fmt.Errorf("unable to set direct map") +} + +func (idx *binaryIndexImpl) SetNProbe(nprobe int32) { + ivfPtrBinary := C.faiss_IndexBinaryIVF_cast(idx.bPtr()) + if ivfPtrBinary == nil { + return + } + C.faiss_IndexBinaryIVF_set_nprobe(idx.bIdx, C.size_t(nprobe)) +} + +func (idx *binaryIndexImpl) Train(x []uint8) error { + n := (len(x) * 8) / idx.D() + if c := C.faiss_IndexBinary_train(idx.bIdx, C.idx_t(n), + (*C.uint8_t)(&x[0])); c != 0 { + return getLastError() + } + return nil +} + +func (idx *binaryIndexImpl) AddWithIDs(x []uint8, ids []int64) error { + n := (len(x) * 8) / idx.D() + if c := C.faiss_IndexBinary_add_with_ids(idx.bIdx, C.idx_t(n), + (*C.uint8_t)(&x[0]), (*C.idx_t)(&ids[0])); c != 0 { + return getLastError() + } + return nil +} + +func (idx *binaryIndexImpl) SearchBinary(x []uint8, k int64) ( + []int32, []int64, error) { + nq := (len(x) * 8) / idx.D() + distances := make([]int32, int64(nq)*k) + labels := make([]int64, int64(nq)*k) + + if c := C.faiss_IndexBinary_search( + idx.bIdx, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + return nil, nil, getLastError() + } + return distances, labels, nil +} + +func (idx *binaryIndexImpl) SearchBinaryWithIDs(x []uint8, k int64, include []int64, + params json.RawMessage) ([]int32, []int64, error) { + nq := (len(x) * 8) / idx.D() + distances := make([]int32, int64(nq)*k) + labels := make([]int64, int64(nq)*k) + + includeSelector, err := NewIDSelectorBatch(include) + if err != nil { + return nil, nil, err + } + defer includeSelector.Delete() + + searchParams, err := NewSearchParams(idx, params, 
includeSelector.Get(), nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + if c := C.faiss_IndexBinary_search_with_params( + idx.bIdx, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + searchParams.sp, + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + return nil, nil, getLastError() + } + return distances, labels, nil +} + +func (idx *binaryIndexImpl) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, + params json.RawMessage) (distances []int32, labels []int64, err error) { + if len(exclude) == 0 && len(params) == 0 { + return idx.SearchBinary(x, k) + } + + nq := (len(x) * 8) / idx.D() + distances = make([]int32, int64(nq)*k) + labels = make([]int64, int64(nq)*k) + + var selector *C.FaissIDSelector + if len(exclude) > 0 { + excludeSelector, err := NewIDSelectorNot(exclude) + if err != nil { + return nil, nil, err + } + selector = excludeSelector.Get() + defer excludeSelector.Delete() + } + + searchParams, err := NewSearchParams(idx, params, selector, nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + if c := C.faiss_IndexBinary_search_with_params( + idx.bIdx, + C.idx_t(nq), + (*C.uint8_t)(&x[0]), + C.idx_t(k), + searchParams.sp, + (*C.int32_t)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + return distances, labels, err +} + +// Converts C.FaissIndexBinary to C.FaissIndex and returns pointer +func (idx *binaryIndexImpl) castIndex() *C.FaissIndex { + return (*C.FaissIndex)(unsafe.Pointer(idx.bIdx)) +} diff --git a/index_flat.go b/index_flat.go index b8a3c03..739c798 100644 --- a/index_flat.go +++ b/index_flat.go @@ -9,13 +9,21 @@ import "unsafe" // IndexFlat is an index that stores the full vectors and performs exhaustive // search. 
-type IndexFlat struct { - Index + +type FlatIndex interface { + FloatIndex + + Xb() []float32 + AsFlat() *C.FaissIndexFlat +} + +type flatIndexImpl struct { + floatIndexImpl } // NewIndexFlat creates a new flat index. -func NewIndexFlat(d int, metric int) (*IndexFlat, error) { - var idx faissIndex +func NewIndexFlat(d int, metric int) (FlatIndex, error) { + var idx flatIndexImpl if c := C.faiss_IndexFlat_new_with( &idx.idx, C.idx_t(d), @@ -23,22 +31,22 @@ func NewIndexFlat(d int, metric int) (*IndexFlat, error) { ); c != 0 { return nil, getLastError() } - return &IndexFlat{&idx}, nil + return &idx, nil } // NewIndexFlatIP creates a new flat index with the inner product metric type. -func NewIndexFlatIP(d int) (*IndexFlat, error) { +func NewIndexFlatIP(d int) (FlatIndex, error) { return NewIndexFlat(d, MetricInnerProduct) } // NewIndexFlatL2 creates a new flat index with the L2 metric type. -func NewIndexFlatL2(d int) (*IndexFlat, error) { +func NewIndexFlatL2(d int) (FlatIndex, error) { return NewIndexFlat(d, MetricL2) } // Xb returns the index's vectors. // The returned slice becomes invalid after any add or remove operation. -func (idx *IndexFlat) Xb() []float32 { +func (idx *flatIndexImpl) Xb() []float32 { var size C.size_t var ptr *C.float C.faiss_IndexFlat_xb(idx.cPtr(), &ptr, &size) @@ -47,10 +55,10 @@ func (idx *IndexFlat) Xb() []float32 { // AsFlat casts idx to a flat index. // AsFlat panics if idx is not a flat index. 
-func (idx *IndexImpl) AsFlat() *IndexFlat { +func (idx *flatIndexImpl) AsFlat() *C.FaissIndexFlat { ptr := C.faiss_IndexFlat_cast(idx.cPtr()) if ptr == nil { panic("index is not a flat index") } - return &IndexFlat{&faissIndex{ptr}} + return ptr } diff --git a/index_float.go b/index_float.go new file mode 100644 index 0000000..c592024 --- /dev/null +++ b/index_float.go @@ -0,0 +1,516 @@ +package faiss + +/* +#include +#include +#include +#include +#include +#include +#include +#include +*/ +import "C" +import ( + "encoding/json" + "fmt" + "sort" + "unsafe" +) + +// TODO-LIKITH some of the includes may not be necessary here + +// TODO-LIKITH check all of the below interfaces to figure out exactly why some of them do not have a binary counterpart +type FloatIndex interface { + Index + + // Train trains the index on a representative set of vectors. + Train(x []float32) error + + // Add adds vectors to the index. + Add(x []float32) error + + // AddWithIDs is like Add, but stores xids instead of sequential IDs. + AddWithIDs(x []float32, xids []int64) error + + // Applicable only to IVF indexes: Returns a map where the keys + // are cluster IDs and the values represent the count of input vectors that belong + // to each cluster. + // This method only considers the given vecIDs and does not account for all + // vectors in the index. 
+ // Example: + // If vecIDs = [1, 2, 3, 4, 5], and: + // - Vectors 1 and 2 belong to cluster 1 + // - Vectors 3, 4, and 5 belong to cluster 2 + // The output will be: map[1:2, 2:3] + ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) + + // Applicable only to IVF indexes: Returns the centroid IDs in decreasing order + // of proximity to query 'x' and their distance from 'x' + ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) ( + []int64, []float32, error) + + // Applicable only to IVF indexes: Returns the top k centroid cardinalities and + // their vectors in chosen order (descending or ascending) + ObtainKCentroidCardinalitiesFromIVFIndex(limit int, descending bool) ([]uint64, [][]float32, error) + + // Search queries the index with the vectors in x. + // Returns the IDs of the k nearest neighbors for each query vector and the + // corresponding distances. + Search(x []float32, k int64) (distances []float32, labels []int64, err error) + + SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) (distances []float32, + labels []int64, err error) + + SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32, + labels []int64, err error) + + // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs + SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, + minEligibleCentroids int, k int64, x, centroidDis []float32, + params json.RawMessage) ([]float32, []int64, error) + + Reconstruct(key int64) ([]float32, error) + + ReconstructBatch(keys []int64, recons []float32) ([]float32, error) + + MergeFrom(other Index, add_id int64) error + + // RemoveIDs removes the vectors specified by sel from the index. + // Returns the number of elements removed and error. + RemoveIDs(sel *IDSelector) (int, error) + + // RangeSearch queries the index with the vectors in x. + // Returns all vectors with distance < radius. 
+ RangeSearch(x []float32, radius float32) (*RangeSearchResult, error) + SetNProbe(nprobe int32) + GetNProbe() int32 + + DistCompute(queryData []float32, ids []int64, k int, distances []float32) error +} + +type floatIndexImpl struct { + indexImpl +} + +func (idx *floatIndexImpl) Train(x []float32) error { + n := len(x) / idx.D() + if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 { + return getLastError() + } + return nil +} + +func (idx *floatIndexImpl) Add(x []float32) error { + n := len(x) / idx.D() + if c := C.faiss_Index_add(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 { + return getLastError() + } + return nil +} + +func (idx *floatIndexImpl) AddWithIDs(x []float32, xids []int64) error { + n := len(x) / idx.D() + if c := C.faiss_Index_add_with_ids( + idx.idx, + C.idx_t(n), + (*C.float)(&x[0]), + (*C.idx_t)(&xids[0]), + ); c != 0 { + return getLastError() + } + return nil +} + +func (idx *floatIndexImpl) ObtainClusterVectorCountsFromIVFIndex(vecIDs []int64) (map[int64]int64, error) { + if !idx.IsIVFIndex() { + return nil, fmt.Errorf("index is not an IVF index") + } + clusterIDs := make([]int64, len(vecIDs)) + if c := C.faiss_get_lists_for_keys( + idx.idx, + (*C.idx_t)(unsafe.Pointer(&vecIDs[0])), + (C.size_t)(len(vecIDs)), + (*C.idx_t)(unsafe.Pointer(&clusterIDs[0])), + ); c != 0 { + return nil, getLastError() + } + rv := make(map[int64]int64, len(vecIDs)) + for _, v := range clusterIDs { + rv[v]++ + } + return rv, nil +} + +func (idx *floatIndexImpl) ObtainClustersWithDistancesFromIVFIndex(x []float32, centroidIDs []int64) ( + []int64, []float32, error) { + // Selector to include only the centroids whose IDs are part of 'centroidIDs'. 
+ includeSelector, err := NewIDSelectorBatch(centroidIDs) + if err != nil { + return nil, nil, err + } + defer includeSelector.Delete() + + params, err := NewSearchParams(idx, json.RawMessage{}, includeSelector.Get(), nil) + if err != nil { + return nil, nil, err + } + defer params.Delete() + + // Populate these with the centroids and their distances. + centroids := make([]int64, len(centroidIDs)) + centroidDistances := make([]float32, len(centroidIDs)) + + n := len(x) / idx.D() + + c := C.faiss_Search_closest_eligible_centroids( + idx.idx, + (C.idx_t)(n), + (*C.float)(&x[0]), + (C.idx_t)(len(centroidIDs)), + (*C.float)(&centroidDistances[0]), + (*C.idx_t)(&centroids[0]), + params.sp) + if c != 0 { + return nil, nil, getLastError() + } + + return centroids, centroidDistances, nil +} + +func getIndicesOfKCentroidCardinalities(cardinalities []C.size_t, k int, descending bool) []int { + n := len(cardinalities) + indices := make([]int, n) + for i := range indices { + indices[i] = i + } + + // Sort only the indices based on cardinality values + sort.Slice(indices, func(i, j int) bool { + if descending { + return cardinalities[indices[i]] > cardinalities[indices[j]] + } + return cardinalities[indices[i]] < cardinalities[indices[j]] + }) + if k >= n { + return indices + } + + return indices[:k] +} + +func (idx *floatIndexImpl) ObtainKCentroidCardinalitiesFromIVFIndex(limit int, descending bool) ( + []uint64, [][]float32, error) { + if limit <= 0 { + return nil, nil, nil + } + + nlist := int(C.faiss_IndexIVF_nlist(idx.idx)) + if nlist == 0 { + return nil, nil, nil + } + + centroidCardinalities := make([]C.size_t, nlist) + + // Allocate a flat buffer for all centroids, then slice it per centroid + d := idx.D() + flatCentroids := make([]float32, nlist*d) + + // Call the C function to fill centroid vectors and cardinalities + c := C.faiss_IndexIVF_get_centroids_and_cardinality( + idx.idx, + (*C.float)(&flatCentroids[0]), + (*C.size_t)(&centroidCardinalities[0]), + nil, + ) + if c != 0 { 
+ return nil, nil, getLastError() + } + + topIndices := getIndicesOfKCentroidCardinalities( + centroidCardinalities, + min(limit, nlist), + descending) + + rvCardinalities := make([]uint64, len(topIndices)) + rvCentroids := make([][]float32, len(topIndices)) + + for i, idx := range topIndices { + rvCardinalities[i] = uint64(centroidCardinalities[idx]) + rvCentroids[i] = flatCentroids[idx*d : (idx+1)*d] + } + + return rvCardinalities, rvCentroids, nil + +} + +func (idx *floatIndexImpl) Search(x []float32, k int64) ( + distances []float32, labels []int64, err error, +) { + n := len(x) / idx.D() + distances = make([]float32, int64(n)*k) + labels = make([]int64, int64(n)*k) + if c := C.faiss_Index_search( + idx.idx, + C.idx_t(n), + (*C.float)(&x[0]), + C.idx_t(k), + (*C.float)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + return +} + +func (idx *floatIndexImpl) searchWithParams(x []float32, k int64, searchParams *C.FaissSearchParameters) ( + distances []float32, labels []int64, err error, +) { + n := len(x) / idx.D() + distances = make([]float32, int64(n)*k) + labels = make([]int64, int64(n)*k) + + if c := C.faiss_Index_search_with_params( + idx.idx, + C.idx_t(n), + (*C.float)(&x[0]), + C.idx_t(k), + searchParams, + (*C.float)(&distances[0]), + (*C.idx_t)(&labels[0]), + ); c != 0 { + err = getLastError() + } + + return +} + +func (idx *floatIndexImpl) SearchWithoutIDs(x []float32, k int64, exclude []int64, params json.RawMessage) ( + distances []float32, labels []int64, err error, +) { + if params == nil && len(exclude) == 0 { + return idx.Search(x, k) + } + + var selector *C.FaissIDSelector + if len(exclude) > 0 { + excludeSelector, err := NewIDSelectorNot(exclude) + if err != nil { + return nil, nil, err + } + selector = excludeSelector.Get() + defer excludeSelector.Delete() + } + + searchParams, err := NewSearchParams(idx, params, selector, nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + 
distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) + + return +} + +func (idx *floatIndexImpl) SearchWithIDs(x []float32, k int64, include []int64, + params json.RawMessage) (distances []float32, labels []int64, err error, +) { + includeSelector, err := NewIDSelectorBatch(include) + if err != nil { + return nil, nil, err + } + defer includeSelector.Delete() + + searchParams, err := NewSearchParams(idx, params, includeSelector.Get(), nil) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + distances, labels, err = idx.searchWithParams(x, k, searchParams.sp) + return +} + +func (idx *floatIndexImpl) SearchClustersFromIVFIndex(selector Selector, + eligibleCentroidIDs []int64, minEligibleCentroids int, k int64, x, + centroidDis []float32, params json.RawMessage) ([]float32, []int64, error) { + + tempParams := &defaultSearchParamsIVF{ + Nlist: len(eligibleCentroidIDs), + // Have to override nprobe so that more clusters will be searched for this + // query, if required. 
+ Nprobe: minEligibleCentroids, + } + + searchParams, err := NewSearchParams(idx, params, selector.Get(), tempParams) + if err != nil { + return nil, nil, err + } + defer searchParams.Delete() + + n := len(x) / idx.D() + + distances := make([]float32, int64(n)*k) + labels := make([]int64, int64(n)*k) + + effectiveNprobe := getNProbeFromSearchParams(searchParams) + eligibleCentroidIDs = eligibleCentroidIDs[:effectiveNprobe] + centroidDis = centroidDis[:effectiveNprobe] + + if c := C.faiss_IndexIVF_search_preassigned_with_params( + idx.idx, + (C.idx_t)(n), + (*C.float)(&x[0]), + (C.idx_t)(k), + (*C.idx_t)(&eligibleCentroidIDs[0]), + (*C.float)(¢roidDis[0]), + (*C.float)(&distances[0]), + (*C.idx_t)(&labels[0]), + (C.int)(0), + searchParams.sp); c != 0 { + return nil, nil, getLastError() + } + + return distances, labels, nil +} + +func (idx *floatIndexImpl) Reconstruct(key int64) (recons []float32, err error) { + rv := make([]float32, idx.D()) + if c := C.faiss_Index_reconstruct( + idx.idx, + C.idx_t(key), + (*C.float)(&rv[0]), + ); c != 0 { + err = getLastError() + } + + return rv, err +} + +func (idx *floatIndexImpl) ReconstructBatch(keys []int64, recons []float32) ([]float32, error) { + var err error + n := int64(len(keys)) + if c := C.faiss_Index_reconstruct_batch( + idx.idx, + C.idx_t(n), + (*C.idx_t)(&keys[0]), + (*C.float)(&recons[0]), + ); c != 0 { + err = getLastError() + } + + return recons, err +} + +func (idx *floatIndexImpl) MergeFrom(other Index, add_id int64) (err error) { + otherIdx, ok := other.(*floatIndexImpl) + if !ok { + return fmt.Errorf("merge api not supported") + } + + if c := C.faiss_Index_merge_from( + idx.idx, + otherIdx.idx, + (C.idx_t)(add_id), + ); c != 0 { + err = getLastError() + } + + return err +} + +func (idx *floatIndexImpl) RemoveIDs(sel *IDSelector) (int, error) { + var nRemoved C.size_t + if c := C.faiss_Index_remove_ids(idx.idx, sel.sel, &nRemoved); c != 0 { + return 0, getLastError() + } + return int(nRemoved), nil +} + +func 
(idx *floatIndexImpl) RangeSearch(x []float32, radius float32) ( + *RangeSearchResult, error, +) { + n := len(x) / idx.D() + var rsr *C.FaissRangeSearchResult + if c := C.faiss_RangeSearchResult_new(&rsr, C.idx_t(n)); c != 0 { + return nil, getLastError() + } + if c := C.faiss_Index_range_search( + idx.idx, + C.idx_t(n), + (*C.float)(&x[0]), + C.float(radius), + rsr, + ); c != 0 { + return nil, getLastError() + } + return &RangeSearchResult{rsr}, nil +} + +// pass nprobe to be set as index time option for IVF indexes only. +// varying nprobe impacts recall but with an increase in latency. +func (idx *floatIndexImpl) SetNProbe(nprobe int32) { + ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) + if ivfPtr == nil { + return + } + C.faiss_IndexIVF_set_nprobe(ivfPtr, C.size_t(nprobe)) +} + +func (idx *floatIndexImpl) GetNProbe() int32 { + ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) + if ivfPtr == nil { + return 0 + } + return int32(C.faiss_IndexIVF_nprobe(ivfPtr)) +} + +func (idx *floatIndexImpl) DistCompute(queryData []float32, ids []int64, k int, distances []float32) error { + if c := C.faiss_Index_dist_compute(idx.idx, (*C.float)(&queryData[0]), + (*C.idx_t)(&ids[0]), (C.size_t)(k), (*C.float)(&distances[0])); c != 0 { + return getLastError() + } + + return nil +} + +// ---------------------------------------------------------------- + +// RangeSearchResult is the result of a range search. +type RangeSearchResult struct { + rsr *C.FaissRangeSearchResult +} + +// Nq returns the number of queries. +func (r *RangeSearchResult) Nq() int { + return int(C.faiss_RangeSearchResult_nq(r.rsr)) +} + +// Lims returns a slice containing start and end indices for queries in the +// distances and labels slices returned by Labels. 
+func (r *RangeSearchResult) Lims() []int { + var lims *C.size_t + C.faiss_RangeSearchResult_lims(r.rsr, &lims) + length := r.Nq() + 1 + return (*[1 << 30]int)(unsafe.Pointer(lims))[:length:length] +} + +// Labels returns the unsorted IDs and respective distances for each query. +// The result for query i is labels[lims[i]:lims[i+1]]. +func (r *RangeSearchResult) Labels() (labels []int64, distances []float32) { + lims := r.Lims() + length := lims[len(lims)-1] + var clabels *C.idx_t + var cdist *C.float + C.faiss_RangeSearchResult_labels(r.rsr, &clabels, &cdist) + labels = (*[1 << 30]int64)(unsafe.Pointer(clabels))[:length:length] + distances = (*[1 << 30]float32)(unsafe.Pointer(cdist))[:length:length] + return +} + +// Delete frees the memory associated with r. +func (r *RangeSearchResult) Delete() { + C.faiss_RangeSearchResult_free(r.rsr) +} diff --git a/index_io.go b/index_io.go index 608f4d7..edbe5c7 100644 --- a/index_io.go +++ b/index_io.go @@ -8,9 +8,17 @@ package faiss */ import "C" import ( + "fmt" "unsafe" ) +const ( + IOFlagMmap = C.FAISS_IO_FLAG_MMAP + IOFlagReadOnly = C.FAISS_IO_FLAG_READ_ONLY + IOFlagReadMmap = C.FAISS_IO_FLAG_READ_MMAP | C.FAISS_IO_FLAG_ONDISK_IVF + IOFlagSkipPrefetch = C.FAISS_IO_FLAG_SKIP_PREFETCH +) + // WriteIndex writes an index to a file. func WriteIndex(idx Index, filename string) error { cfname := C.CString(filename) @@ -21,18 +29,47 @@ func WriteIndex(idx Index, filename string) error { return nil } -func WriteIndexIntoBuffer(idx Index) ([]byte, error) { +// ReadIndex reads an index from a file. 
+func ReadIndex(filename string, ioflags int) (Index, error) { + cfname := C.CString(filename) + defer C.free(unsafe.Pointer(cfname)) + var idx indexImpl + if c := C.faiss_read_index_fname(cfname, C.int(ioflags), &idx.idx); c != 0 { + return nil, getLastError() + } + return &idx, nil +} + +func WriteIndexIntoBuffer(idx Index, typ IndexType) ([]byte, error) { // the values to be returned by the faiss APIs tempBuf := (*C.uchar)(nil) bufSize := C.size_t(0) - if c := C.faiss_write_index_buf( - idx.cPtr(), - &bufSize, - &tempBuf, - ); c != 0 { - C.faiss_free_buf(&tempBuf) - return nil, getLastError() + switch typ { + case FloatIndexType: + if c := C.faiss_write_index_buf( + idx.cPtr(), + &bufSize, + &tempBuf, + ); c != 0 { + C.faiss_free_buf(&tempBuf) + return nil, getLastError() + } + case BinaryIndexType: + bIdx, ok := idx.(BinaryIndex) + if !ok { + return nil, fmt.Errorf("failed to get binary index pointer") + } + if c := C.faiss_write_index_binary_buf( + bIdx.bPtr(), + &bufSize, + &tempBuf, + ); c != 0 { + C.faiss_free_buf(&tempBuf) + return nil, getLastError() + } + default: + return nil, fmt.Errorf("unsupported index type for writing to buffer") } // at this point, the idx has a valid ref count. furthermore, the index is @@ -79,17 +116,31 @@ func WriteIndexIntoBuffer(idx Index) ([]byte, error) { return rv, nil } -func ReadIndexFromBuffer(buf []byte, ioflags int) (*IndexImpl, error) { +func ReadIndexFromBuffer(buf []byte, ioflags int, typ IndexType) (Index, error) { ptr := (*C.uchar)(unsafe.Pointer(&buf[0])) size := C.size_t(len(buf)) - // the idx var has C.FaissIndex within the struct which is nil as of now. - var idx faissIndex - if c := C.faiss_read_index_buf(ptr, - size, - C.int(ioflags), - &idx.idx); c != 0 { - return nil, getLastError() + var rv Index + switch typ { + case FloatIndexType: + // the idx var has C.FaissIndex within the struct which is nil as of now. 
+ var idx floatIndexImpl + if c := C.faiss_read_index_buf(ptr, + size, + C.int(ioflags), + &idx.idx); c != 0 { + return nil, getLastError() + } + rv = &idx + case BinaryIndexType: + var bIdx binaryIndexImpl + if c := C.faiss_read_index_binary_buf(ptr, + size, + C.int(ioflags), + &bIdx.bIdx); c != 0 { + return nil, getLastError() + } + rv = &bIdx } ptr = nil @@ -98,23 +149,5 @@ func ReadIndexFromBuffer(buf []byte, ioflags int) (*IndexImpl, error) { // for the freshly created faiss::index becomes 1 (held by idx.idx of type C.FaissIndex) // this is allocated on the C heap, so not available for golang's GC. hence needs // to be cleaned up after the index is longer being used - to be done at zap layer. - return &IndexImpl{&idx}, nil -} - -const ( - IOFlagMmap = C.FAISS_IO_FLAG_MMAP - IOFlagReadOnly = C.FAISS_IO_FLAG_READ_ONLY - IOFlagReadMmap = C.FAISS_IO_FLAG_READ_MMAP | C.FAISS_IO_FLAG_ONDISK_IVF - IOFlagSkipPrefetch = C.FAISS_IO_FLAG_SKIP_PREFETCH -) - -// ReadIndex reads an index from a file. 
-func ReadIndex(filename string, ioflags int) (*IndexImpl, error) { - cfname := C.CString(filename) - defer C.free(unsafe.Pointer(cfname)) - var idx faissIndex - if c := C.faiss_read_index_fname(cfname, C.int(ioflags), &idx.idx); c != 0 { - return nil, getLastError() - } - return &IndexImpl{&idx}, nil + return rv, nil } diff --git a/index_ivf.go b/index_ivf.go index 38f023a..f96d851 100644 --- a/index_ivf.go +++ b/index_ivf.go @@ -12,7 +12,18 @@ import ( "fmt" ) -func (idx *IndexImpl) SetDirectMap(mapType int) (err error) { +type IVFIndex interface { + FloatIndex + + SetDirectMap(mapType int) error + GetSubIndex() (Index, error) +} + +type ivfIndexImpl struct { + floatIndexImpl +} + +func (idx *ivfIndexImpl) SetDirectMap(mapType int) (err error) { ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) if ivfPtr == nil { @@ -27,7 +38,7 @@ func (idx *IndexImpl) SetDirectMap(mapType int) (err error) { return err } -func (idx *IndexImpl) GetSubIndex() (*IndexImpl, error) { +func (idx *ivfIndexImpl) GetSubIndex() (Index, error) { ptr := C.faiss_IndexIDMap2_cast(idx.cPtr()) if ptr == nil { @@ -39,23 +50,5 @@ func (idx *IndexImpl) GetSubIndex() (*IndexImpl, error) { return nil, fmt.Errorf("couldn't retrieve the sub index") } - return &IndexImpl{&faissIndex{subIdx}}, nil -} - -// pass nprobe to be set as index time option for IVF indexes only. -// varying nprobe impacts recall but with an increase in latency. -func (idx *IndexImpl) SetNProbe(nprobe int32) { - ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) - if ivfPtr == nil { - return - } - C.faiss_IndexIVF_set_nprobe(ivfPtr, C.size_t(nprobe)) -} - -func (idx *IndexImpl) GetNProbe() int32 { - ivfPtr := C.faiss_IndexIVF_cast(idx.cPtr()) - if ivfPtr == nil { - return 0 - } - return int32(C.faiss_IndexIVF_nprobe(ivfPtr)) + return &indexImpl{subIdx}, nil }