-
Notifications
You must be signed in to change notification settings - Fork 6
MB-62985 - Add functionality to support binary quantised vectors. #42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 9 commits
f7c8908
e577f27
73e9e4b
68a53a2
4be47b7
a33ef3e
42a99f0
0d9a4f3
5ad828d
329d981
bfb5436
b3fff7e
5774cac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| module github.com/blevesearch/go-faiss | ||
|
|
||
| go 1.21 | ||
| go 1.22 | ||
|
|
||
| toolchain go1.23.0 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,13 +2,16 @@ package faiss | |
|
|
||
| /* | ||
| #include <stdlib.h> | ||
| #include <stdint.h> | ||
| #include <faiss/c_api/Index_c.h> | ||
| #include <faiss/c_api/IndexIVF_c.h> | ||
| #include <faiss/c_api/IndexBinary_c.h> | ||
| #include <faiss/c_api/IndexIVF_c_ex.h> | ||
| #include <faiss/c_api/Index_c_ex.h> | ||
| #include <faiss/c_api/impl/AuxIndexStructures_c.h> | ||
| #include <faiss/c_api/index_factory_c.h> | ||
| #include <faiss/c_api/MetaIndexes_c.h> | ||
| #include <faiss/c_api/IndexBinary_c.h> | ||
| */ | ||
| import "C" | ||
| import ( | ||
|
|
@@ -36,13 +39,13 @@ type Index interface { | |
| MetricType() int | ||
|
|
||
| // Train trains the index on a representative set of vectors. | ||
| Train(x []float32) error | ||
| Train(x interface{}) error | ||
metonymic-smokey marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| // Add adds vectors to the index. | ||
| Add(x []float32) error | ||
| Add(x interface{}) error | ||
|
|
||
| // AddWithIDs is like Add, but stores xids instead of sequential IDs. | ||
| AddWithIDs(x []float32, xids []int64) error | ||
| AddWithIDs(x interface{}, xids []int64) error | ||
|
|
||
| // Returns true if the index is an IVF index. | ||
| IsIVFIndex() bool | ||
|
|
@@ -75,6 +78,12 @@ type Index interface { | |
| SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32, | ||
| labels []int64, err error) | ||
|
|
||
| SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32, | ||
| labels []int64, err error) | ||
|
|
||
| SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32, | ||
| labels []int64, err error) | ||
|
|
||
| // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs | ||
| SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64, | ||
| minEligibleCentroids int, k int64, x, centroidDis []float32, | ||
|
|
@@ -104,50 +113,105 @@ type Index interface { | |
| Size() uint64 | ||
|
|
||
| cPtr() *C.FaissIndex | ||
|
|
||
| cPtrBinary() *C.FaissIndexBinary | ||
|
|
||
| DistCompute(queryData []float32, ids []int64, k int, distances []float32) error | ||
| } | ||
|
|
||
| type faissIndex struct { | ||
| idx *C.FaissIndex | ||
| idx *C.FaissIndex | ||
| idxBinary *C.FaissIndexBinary | ||
| } | ||
|
|
||
| func (idx *faissIndex) cPtr() *C.FaissIndex { | ||
| return idx.idx | ||
| } | ||
|
|
||
| func (idx *faissIndex) DistCompute(queryData []float32, ids []int64, k int, distances []float32) error { | ||
| if c := C.faiss_Index_dist_compute(idx.idx, (*C.float)(&queryData[0]), | ||
| (*C.idx_t)(&ids[0]), (C.size_t)(k), (*C.float)(&distances[0])); c != 0 { | ||
| return getLastError() | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| func (idx *faissIndex) cPtrBinary() *C.FaissIndexBinary { | ||
| return idx.idxBinary | ||
| } | ||
|
|
||
| func (idx *faissIndex) Size() uint64 { | ||
| size := C.faiss_Index_size(idx.idx) | ||
| return uint64(size) | ||
| } | ||
|
|
||
| func (idx *faissIndex) D() int { | ||
| return int(C.faiss_Index_d(idx.idx)) | ||
| if idx.idx != nil { | ||
| return int(C.faiss_Index_d(idx.idx)) | ||
| } | ||
| return int(C.faiss_IndexBinary_d(idx.idxBinary)) | ||
| } | ||
|
|
||
| func (idx *faissIndex) IsTrained() bool { | ||
|
||
| return C.faiss_Index_is_trained(idx.idx) != 0 | ||
| } | ||
|
|
||
| func (idx *faissIndex) Ntotal() int64 { | ||
| if idx.idxBinary != nil { | ||
| return int64(C.faiss_IndexBinary_ntotal(idx.idxBinary)) | ||
| } | ||
| return int64(C.faiss_Index_ntotal(idx.idx)) | ||
| } | ||
|
|
||
| func (idx *faissIndex) MetricType() int { | ||
| return int(C.faiss_Index_metric_type(idx.idx)) | ||
| } | ||
|
|
||
| func (idx *faissIndex) Train(x []float32) error { | ||
| n := len(x) / idx.D() | ||
| if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 { | ||
| return getLastError() | ||
| func (idx *faissIndex) Train(x interface{}) error { | ||
metonymic-smokey marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| floatVec, ok := x.([]float32) | ||
| if ok { | ||
| n := len(floatVec) / idx.D() | ||
| if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&floatVec[0])); c != 0 { | ||
| return getLastError() | ||
| } | ||
| } else { | ||
| c, ok := x.([]uint8) | ||
| if ok { | ||
| n := (len(c) * 8) / idx.D() | ||
| if c := C.faiss_IndexBinary_train(idx.idxBinary, C.idx_t(n), (*C.uint8_t)(&c[0])); c != 0 { | ||
| return getLastError() | ||
| } | ||
| } | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| func (idx *faissIndex) Add(x []float32) error { | ||
| n := len(x) / idx.D() | ||
| if c := C.faiss_Index_add(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 { | ||
| return getLastError() | ||
| func (idx *faissIndex) Add(x interface{}) error { | ||
metonymic-smokey marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| floatVec, ok := x.([]float32) | ||
| if ok { | ||
| n := len(floatVec) / idx.D() | ||
| if c := C.faiss_Index_add( | ||
| idx.idx, | ||
| C.idx_t(n), | ||
| (*C.float)(&floatVec[0]), | ||
| ); c != 0 { | ||
| return getLastError() | ||
| } | ||
| } else { | ||
| c, ok := x.([]uint8) | ||
| if ok { | ||
| n := (len(c) * 8) / idx.D() | ||
| if c := C.faiss_IndexBinary_add( | ||
| idx.idxBinary, | ||
| C.idx_t(n), | ||
| (*C.uint8_t)(&c[0]), | ||
| ); c != 0 { | ||
| return getLastError() | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
|
|
@@ -257,16 +321,33 @@ func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector, | |
| return distances, labels, nil | ||
| } | ||
|
|
||
| func (idx *faissIndex) AddWithIDs(x []float32, xids []int64) error { | ||
| n := len(x) / idx.D() | ||
| if c := C.faiss_Index_add_with_ids( | ||
| idx.idx, | ||
| C.idx_t(n), | ||
| (*C.float)(&x[0]), | ||
| (*C.idx_t)(&xids[0]), | ||
| ); c != 0 { | ||
| return getLastError() | ||
| func (idx *faissIndex) AddWithIDs(x interface{}, xids []int64) error { | ||
| floatVec, ok := x.([]float32) | ||
| if ok { | ||
| n := len(floatVec) / idx.D() | ||
| if c := C.faiss_Index_add_with_ids( | ||
| idx.idx, | ||
| C.idx_t(n), | ||
| (*C.float)(&floatVec[0]), | ||
| (*C.idx_t)(&xids[0]), | ||
| ); c != 0 { | ||
| return getLastError() | ||
| } | ||
| } else { | ||
| c, ok := x.([]uint8) | ||
| if ok { | ||
| n := (len(c) * 8) / idx.D() | ||
| if c := C.faiss_IndexBinary_add_with_ids( | ||
| idx.idxBinary, | ||
| C.idx_t(n), | ||
| (*C.uint8_t)(&c[0]), | ||
| (*C.idx_t)(&xids[0]), | ||
| ); c != 0 { | ||
| return getLastError() | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
|
|
@@ -318,6 +399,75 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p | |
| return | ||
| } | ||
|
|
||
| func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, | ||
|
||
| params json.RawMessage) (distances []int32, labels []int64, err error, | ||
| ) { | ||
| d := idx.D() | ||
| nq := (len(x) * 8) / d | ||
|
|
||
| distances = make([]int32, int64(nq)*k) | ||
| labels = make([]int64, int64(nq)*k) | ||
|
|
||
| searchParams, err := NewSearchParams(idx, params, nil, nil) | ||
| if err != nil { | ||
| return nil, nil, err | ||
| } | ||
| defer searchParams.Delete() | ||
|
|
||
| if c := C.faiss_IndexBinary_search_with_params( | ||
| idx.idxBinary, | ||
| C.idx_t(nq), | ||
| (*C.uint8_t)(&x[0]), | ||
| C.idx_t(k), | ||
| searchParams.sp, | ||
| (*C.int32_t)(&distances[0]), | ||
| (*C.idx_t)(&labels[0]), | ||
| ); c != 0 { | ||
| err = getLastError() | ||
| } | ||
|
|
||
| return distances, labels, nil | ||
| } | ||
|
|
||
| func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64,params json.RawMessage) (distances []int32, | ||
| labels []int64, err error) { | ||
| d := idx.D() | ||
| nq := (len(x) * 8) / d | ||
|
|
||
| distances = make([]int32, int64(nq)*k) | ||
| labels = make([]int64, int64(nq)*k) | ||
|
|
||
| var selector *C.FaissIDSelector | ||
| if len(exclude) > 0 { | ||
| excludeSelector, err := NewIDSelectorNot(exclude) | ||
| if err != nil { | ||
| return nil, nil, err | ||
| } | ||
| selector = excludeSelector.Get() | ||
| defer excludeSelector.Delete() | ||
| } | ||
|
|
||
| searchParams, err := NewSearchParams(idx, params, selector, nil) | ||
| if err != nil { | ||
| return nil, nil, err | ||
| } | ||
| defer searchParams.Delete() | ||
|
|
||
| if c := C.faiss_IndexBinary_search_with_params( | ||
|
||
| idx.idxBinary, | ||
| C.idx_t(nq), | ||
| (*C.uint8_t)(&x[0]), | ||
| C.idx_t(k), | ||
| searchParams.sp, | ||
| (*C.int32_t)(&distances[0]), | ||
| (*C.idx_t)(&labels[0]), | ||
| ); c != 0 { | ||
| err = getLastError() | ||
| } | ||
|
|
||
| return distances, labels, err | ||
| } | ||
|
|
||
| func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64, | ||
| params json.RawMessage) (distances []float32, labels []int64, err error, | ||
| ) { | ||
|
|
@@ -426,6 +576,7 @@ func (idx *faissIndex) RemoveIDs(sel *IDSelector) (int, error) { | |
|
|
||
| func (idx *faissIndex) Close() { | ||
| C.faiss_Index_free(idx.idx) | ||
| C.faiss_IndexBinary_free(idx.idxBinary) | ||
| } | ||
|
|
||
| func (idx *faissIndex) searchWithParams(x []float32, k int64, searchParams *C.FaissSearchParameters) ( | ||
|
|
@@ -507,6 +658,17 @@ func IndexFactory(d int, description string, metric int) (*IndexImpl, error) { | |
| return &IndexImpl{&idx}, nil | ||
| } | ||
|
|
||
| func IndexBinaryFactory(d int, description string, metric int) (*IndexImpl, error) { | ||
metonymic-smokey marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| cdesc := C.CString(description) | ||
| defer C.free(unsafe.Pointer(cdesc)) | ||
| var idx faissIndex | ||
| c := C.faiss_index_binary_factory(&idx.idxBinary, C.int(d), cdesc) | ||
| if c != 0 { | ||
| return nil, getLastError() | ||
| } | ||
| return &IndexImpl{&idx}, nil | ||
| } | ||
|
|
||
| func SetOMPThreads(n uint) { | ||
| C.faiss_set_omp_threads(C.uint(n)) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was changed in a recent patch, so please update it.