Skip to content
Open
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/blevesearch/go-faiss

go 1.21
go 1.22

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was changed in a recent patch so update.

toolchain go1.23.0
206 changes: 184 additions & 22 deletions index.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package faiss

/*
#include <stdlib.h>
#include <stdint.h>
#include <faiss/c_api/Index_c.h>
#include <faiss/c_api/IndexIVF_c.h>
#include <faiss/c_api/IndexBinary_c.h>
#include <faiss/c_api/IndexIVF_c_ex.h>
#include <faiss/c_api/Index_c_ex.h>
#include <faiss/c_api/impl/AuxIndexStructures_c.h>
#include <faiss/c_api/index_factory_c.h>
#include <faiss/c_api/MetaIndexes_c.h>
#include <faiss/c_api/IndexBinary_c.h>
*/
import "C"
import (
Expand Down Expand Up @@ -36,13 +39,13 @@ type Index interface {
MetricType() int

// Train trains the index on a representative set of vectors.
Train(x []float32) error
Train(x interface{}) error

// Add adds vectors to the index.
Add(x []float32) error
Add(x interface{}) error

// AddWithIDs is like Add, but stores xids instead of sequential IDs.
AddWithIDs(x []float32, xids []int64) error
AddWithIDs(x interface{}, xids []int64) error

// Returns true if the index is an IVF index.
IsIVFIndex() bool
Expand Down Expand Up @@ -75,6 +78,12 @@ type Index interface {
SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32,
labels []int64, err error)

SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32,
labels []int64, err error)

SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32,
labels []int64, err error)

// Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs
SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64,
minEligibleCentroids int, k int64, x, centroidDis []float32,
Expand Down Expand Up @@ -104,50 +113,105 @@ type Index interface {
Size() uint64

cPtr() *C.FaissIndex

cPtrBinary() *C.FaissIndexBinary

DistCompute(queryData []float32, ids []int64, k int, distances []float32) error
}

type faissIndex struct {
idx *C.FaissIndex
idx *C.FaissIndex
idxBinary *C.FaissIndexBinary
}

func (idx *faissIndex) cPtr() *C.FaissIndex {
return idx.idx
}

func (idx *faissIndex) DistCompute(queryData []float32, ids []int64, k int, distances []float32) error {
if c := C.faiss_Index_dist_compute(idx.idx, (*C.float)(&queryData[0]),
(*C.idx_t)(&ids[0]), (C.size_t)(k), (*C.float)(&distances[0])); c != 0 {
return getLastError()
}

return nil
}

func (idx *faissIndex) cPtrBinary() *C.FaissIndexBinary {
return idx.idxBinary
}

func (idx *faissIndex) Size() uint64 {
size := C.faiss_Index_size(idx.idx)
return uint64(size)
}

func (idx *faissIndex) D() int {
return int(C.faiss_Index_d(idx.idx))
if idx.idx != nil {
return int(C.faiss_Index_d(idx.idx))
}
return int(C.faiss_IndexBinary_d(idx.idxBinary))
}

func (idx *faissIndex) IsTrained() bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If index is binary then calling this method would panic right as you should be calling

return C.faiss_IndexBinary_is_trained(idx.idx) != 0

https://github.com/blevesearch/faiss/blob/b3d4e00a69425b95e0b283da7801efc9f66b580d/c_api/IndexBinary_c.h#L35

return C.faiss_Index_is_trained(idx.idx) != 0
}

func (idx *faissIndex) Ntotal() int64 {
if idx.idxBinary != nil {
return int64(C.faiss_IndexBinary_ntotal(idx.idxBinary))
}
return int64(C.faiss_Index_ntotal(idx.idx))
}

func (idx *faissIndex) MetricType() int {
return int(C.faiss_Index_metric_type(idx.idx))
}

func (idx *faissIndex) Train(x []float32) error {
n := len(x) / idx.D()
if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 {
return getLastError()
func (idx *faissIndex) Train(x interface{}) error {
floatVec, ok := x.([]float32)
if ok {
n := len(floatVec) / idx.D()
if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&floatVec[0])); c != 0 {
return getLastError()
}
} else {
c, ok := x.([]uint8)
if ok {
n := (len(c) * 8) / idx.D()
if c := C.faiss_IndexBinary_train(idx.idxBinary, C.idx_t(n), (*C.uint8_t)(&c[0])); c != 0 {
return getLastError()
}
}
}
return nil
}

func (idx *faissIndex) Add(x []float32) error {
n := len(x) / idx.D()
if c := C.faiss_Index_add(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 {
return getLastError()
func (idx *faissIndex) Add(x interface{}) error {
floatVec, ok := x.([]float32)
if ok {
n := len(floatVec) / idx.D()
if c := C.faiss_Index_add(
idx.idx,
C.idx_t(n),
(*C.float)(&floatVec[0]),
); c != 0 {
return getLastError()
}
} else {
c, ok := x.([]uint8)
if ok {
n := (len(c) * 8) / idx.D()
if c := C.faiss_IndexBinary_add(
idx.idxBinary,
C.idx_t(n),
(*C.uint8_t)(&c[0]),
); c != 0 {
return getLastError()
}
}
}

return nil
}

Expand Down Expand Up @@ -257,16 +321,33 @@ func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector,
return distances, labels, nil
}

func (idx *faissIndex) AddWithIDs(x []float32, xids []int64) error {
n := len(x) / idx.D()
if c := C.faiss_Index_add_with_ids(
idx.idx,
C.idx_t(n),
(*C.float)(&x[0]),
(*C.idx_t)(&xids[0]),
); c != 0 {
return getLastError()
func (idx *faissIndex) AddWithIDs(x interface{}, xids []int64) error {
floatVec, ok := x.([]float32)
if ok {
n := len(floatVec) / idx.D()
if c := C.faiss_Index_add_with_ids(
idx.idx,
C.idx_t(n),
(*C.float)(&floatVec[0]),
(*C.idx_t)(&xids[0]),
); c != 0 {
return getLastError()
}
} else {
c, ok := x.([]uint8)
if ok {
n := (len(c) * 8) / idx.D()
if c := C.faiss_IndexBinary_add_with_ids(
idx.idxBinary,
C.idx_t(n),
(*C.uint8_t)(&c[0]),
(*C.idx_t)(&xids[0]),
); c != 0 {
return getLastError()
}
}
}

return nil
}

Expand Down Expand Up @@ -318,6 +399,75 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p
return
}

func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • please add all binary index related methods(and the struct) to a new file, after splitting the interface thanks
  • The method naming is off, youre just searching the binary vector with a given K - No IDs for an include selector is given, Rename to Just SearchBinary

params json.RawMessage) (distances []int32, labels []int64, err error,
) {
d := idx.D()
nq := (len(x) * 8) / d

distances = make([]int32, int64(nq)*k)
labels = make([]int64, int64(nq)*k)

searchParams, err := NewSearchParams(idx, params, nil, nil)
if err != nil {
return nil, nil, err
}
defer searchParams.Delete()

if c := C.faiss_IndexBinary_search_with_params(
idx.idxBinary,
C.idx_t(nq),
(*C.uint8_t)(&x[0]),
C.idx_t(k),
searchParams.sp,
(*C.int32_t)(&distances[0]),
(*C.idx_t)(&labels[0]),
); c != 0 {
err = getLastError()
}

return distances, labels, nil
}

func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64,params json.RawMessage) (distances []int32,
labels []int64, err error) {
d := idx.D()
nq := (len(x) * 8) / d

distances = make([]int32, int64(nq)*k)
labels = make([]int64, int64(nq)*k)

var selector *C.FaissIDSelector
if len(exclude) > 0 {
excludeSelector, err := NewIDSelectorNot(exclude)
if err != nil {
return nil, nil, err
}
selector = excludeSelector.Get()
defer excludeSelector.Delete()
}

searchParams, err := NewSearchParams(idx, params, selector, nil)
if err != nil {
return nil, nil, err
}
defer searchParams.Delete()

if c := C.faiss_IndexBinary_search_with_params(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code block can be reused. Add searchBinaryWithParams method similar to searchWithParams

idx.idxBinary,
C.idx_t(nq),
(*C.uint8_t)(&x[0]),
C.idx_t(k),
searchParams.sp,
(*C.int32_t)(&distances[0]),
(*C.idx_t)(&labels[0]),
); c != 0 {
err = getLastError()
}

return distances, labels, err
}

func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64,
params json.RawMessage) (distances []float32, labels []int64, err error,
) {
Expand Down Expand Up @@ -426,6 +576,7 @@ func (idx *faissIndex) RemoveIDs(sel *IDSelector) (int, error) {

func (idx *faissIndex) Close() {
C.faiss_Index_free(idx.idx)
C.faiss_IndexBinary_free(idx.idxBinary)
}

func (idx *faissIndex) searchWithParams(x []float32, k int64, searchParams *C.FaissSearchParameters) (
Expand Down Expand Up @@ -507,6 +658,17 @@ func IndexFactory(d int, description string, metric int) (*IndexImpl, error) {
return &IndexImpl{&idx}, nil
}

func IndexBinaryFactory(d int, description string, metric int) (*IndexImpl, error) {
cdesc := C.CString(description)
defer C.free(unsafe.Pointer(cdesc))
var idx faissIndex
c := C.faiss_index_binary_factory(&idx.idxBinary, C.int(d), cdesc)
if c != 0 {
return nil, getLastError()
}
return &IndexImpl{&idx}, nil
}

func SetOMPThreads(n uint) {
C.faiss_set_omp_threads(C.uint(n))
}
2 changes: 1 addition & 1 deletion index_flat.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ func (idx *IndexImpl) AsFlat() *IndexFlat {
if ptr == nil {
panic("index is not a flat index")
}
return &IndexFlat{&faissIndex{ptr}}
return &IndexFlat{&faissIndex{idx: ptr}}
}
Loading