diff --git a/utilities/secondary_index/faiss_ivf_index_test.cc b/utilities/secondary_index/faiss_ivf_index_test.cc
index ead9bc45dcc..3b623800890 100644
--- a/utilities/secondary_index/faiss_ivf_index_test.cc
+++ b/utilities/secondary_index/faiss_ivf_index_test.cc
@@ -290,6 +290,164 @@ TEST(FaissIVFIndexTest, Basic) {
   }
 }
 
+TEST(FaissIVFIndexTest, Compare) {
+  // Verifies that similarity searches served through FaissIVFIndex return the
+  // same ids, in the same order, as an identically trained and populated
+  // standalone FAISS index.
+  //
+  // Train two copies of the same index; hand over one to FaissIVFIndex and use
+  // the other one as a baseline for comparison
+  constexpr size_t dim = 128;
+  auto quantizer_cmp = std::make_unique<faiss::IndexFlatL2>(dim);
+  auto quantizer = std::make_unique<faiss::IndexFlatL2>(dim);
+
+  constexpr size_t num_lists = 16;
+  auto index_cmp = std::make_unique<faiss::IndexIVFFlat>(quantizer_cmp.get(),
+                                                         dim, num_lists);
+  auto index =
+      std::make_unique<faiss::IndexIVFFlat>(quantizer.get(), dim, num_lists);
+
+  {
+    constexpr faiss::idx_t num_train = 1024;
+    std::vector<float> embeddings_train(dim * num_train);
+    faiss::float_rand(embeddings_train.data(), dim * num_train, 42);
+
+    index_cmp->train(num_train, embeddings_train.data());
+    index->train(num_train, embeddings_train.data());
+  }
+
+  const std::string db_name = test::PerThreadDBPath("faiss_ivf_index_test");
+  EXPECT_OK(DestroyDB(db_name, Options()));
+
+  Options options;
+  options.create_if_missing = true;
+
+  TransactionDBOptions txn_db_options;
+  txn_db_options.secondary_indices.emplace_back(std::make_shared<FaissIVFIndex>(
+      std::move(index), kDefaultWideColumnName.ToString()));
+
+  TransactionDB* db = nullptr;
+  ASSERT_OK(TransactionDB::Open(options, txn_db_options, db_name, &db));
+
+  std::unique_ptr<TransactionDB> db_guard(db);
+
+  ColumnFamilyOptions cf1_opts;
+  ColumnFamilyHandle* cfh1 = nullptr;
+  ASSERT_OK(db->CreateColumnFamily(cf1_opts, "cf1", &cfh1));
+  std::unique_ptr<ColumnFamilyHandle> cfh1_guard(cfh1);
+
+  ColumnFamilyOptions cf2_opts;
+  ColumnFamilyHandle* cfh2 = nullptr;
+  ASSERT_OK(db->CreateColumnFamily(cf2_opts, "cf2", &cfh2));
+  std::unique_ptr<ColumnFamilyHandle> cfh2_guard(cfh2);
+
+  const auto& secondary_index = txn_db_options.secondary_indices.back();
+  secondary_index->SetPrimaryColumnFamily(cfh1);
+  secondary_index->SetSecondaryColumnFamily(cfh2);
+
+  // Add the same set of database vectors to both indices
+  constexpr faiss::idx_t num_db = 4096;
+
+  {
+    std::vector<float> embeddings_db(dim * num_db);
+    faiss::float_rand(embeddings_db.data(), dim * num_db, 123);
+
+    for (faiss::idx_t i = 0; i < num_db; ++i) {
+      const float* const embedding = embeddings_db.data() + i * dim;
+
+      index_cmp->add(1, embedding);
+
+      // Key each vector by the decimal form of its insertion position; the
+      // search loop below parses iterator keys as decimal ids and compares
+      // them with the ids reported by the baseline index
+      const std::string primary_key = std::to_string(i);
+      ASSERT_OK(db->Put(WriteOptions(), cfh1, primary_key,
+                        Slice(reinterpret_cast<const char*>(embedding),
+                              dim * sizeof(float))));
+    }
+  }
+
+  // Search both indices with the same set of query vectors and make sure the
+  // results match
+  {
+    constexpr faiss::idx_t num_query = 32;
+    std::vector<float> embeddings_query(dim * num_query);
+    faiss::float_rand(embeddings_query.data(), dim * num_query, 456);
+
+    for (size_t neighbors : {1, 2, 4}) {
+      for (size_t probes : {1, 2, 4}) {
+        std::unique_ptr<Iterator> underlying_it(
+            db->NewIterator(ReadOptions(), cfh2));
+
+        SecondaryIndexReadOptions read_options;
+        read_options.similarity_search_neighbors = neighbors;
+        read_options.similarity_search_probes = probes;
+
+        std::unique_ptr<Iterator> it =
+            txn_db_options.secondary_indices.back()->NewIterator(
+                read_options, std::move(underlying_it));
+
+        // Parses the current key of `it` as a decimal id; yields -1 when
+        // std::from_chars reports an error
+        auto get_id = [&]() -> faiss::idx_t {
+          Slice key = it->key();
+          faiss::idx_t id = -1;
+
+          if (std::from_chars(key.data(), key.data() + key.size(), id).ec !=
+              std::errc()) {
+            return -1;
+          }
+
+          return id;
+        };
+
+        for (faiss::idx_t i = 0; i < num_query; ++i) {
+          const float* const embedding = embeddings_query.data() + i * dim;
+
+          std::vector<float> distances(neighbors, 0.0f);
+          std::vector<faiss::idx_t> ids(neighbors, -1);
+
+          faiss::SearchParametersIVF params;
+          params.nprobe = probes;
+
+          index_cmp->search(1, embedding, neighbors, distances.data(),
+                            ids.data(), &params);
+
+          // Count the baseline results, stopping at the first id equal to -1
+          // (the value `ids` was pre-filled with)
+          size_t num_found_cmp = 0;
+          for (faiss::idx_t id : ids) {
+            if (id == -1) {
+              break;
+            }
+
+            ++num_found_cmp;
+          }
+
+          size_t num_found = 0;
+          for (it->Seek(Slice(reinterpret_cast<const char*>(embedding),
+                              dim * sizeof(float)));
+               it->Valid(); it->Next()) {
+            // Fail cleanly rather than index past the end of `ids` if the
+            // iterator yields more entries than the baseline has slots
+            ASSERT_LT(num_found, ids.size());
+
+            const faiss::idx_t id = get_id();
+            ASSERT_GE(id, 0);
+            ASSERT_LT(id, num_db);
+            ASSERT_EQ(id, ids[num_found]);
+
+            ++num_found;
+          }
+
+          ASSERT_OK(it->status());
+          ASSERT_EQ(num_found, num_found_cmp);
+        }
+      }
+    }
+  }
+}
+
 }  // namespace ROCKSDB_NAMESPACE
 
 int main(int argc, char** argv) {