From 77d4663447f485729d13a12c8fea5b51127d35f6 Mon Sep 17 00:00:00 2001
From: Levi Tamasi <ltamasi@meta.com>
Date: Wed, 15 Jan 2025 18:03:59 -0800
Subject: [PATCH] Extend the test coverage of FaissIVFIndex (#13300)

Summary:
Pull Request resolved: https://github.com/facebook/rocksdb/pull/13300

The patch adds a new unit test for `FaissIVFIndex` that compares its results with a regular in-memory FAISS index. Specifically, it trains two identical IVF indices using the same training vectors, passes the ownership of one to `FaissIVFIndex`, adds the same set of database vectors to both, and then queries them using the same query vectors (with a variety of values for number of neighbors and number of probes).

Reviewed By: jaykorean

Differential Revision: D68233815

fbshipit-source-id: 7577a65c03c7b811707a4dbcd81e69ed85202a51
---
 .../secondary_index/faiss_ivf_index_test.cc   | 143 ++++++++++++++++++
 1 file changed, 143 insertions(+)

diff --git a/utilities/secondary_index/faiss_ivf_index_test.cc b/utilities/secondary_index/faiss_ivf_index_test.cc
index ead9bc45dcc..3b623800890 100644
--- a/utilities/secondary_index/faiss_ivf_index_test.cc
+++ b/utilities/secondary_index/faiss_ivf_index_test.cc
@@ -290,6 +290,149 @@ TEST(FaissIVFIndexTest, Basic) {
   }
 }
 
+TEST(FaissIVFIndexTest, Compare) {
+  // Train two copies of the same index; hand over one to FaissIVFIndex and use
+  // the other one as a baseline for comparison
+  constexpr size_t dim = 128;
+  auto quantizer_cmp = std::make_unique<faiss::IndexFlatL2>(dim);
+  auto quantizer = std::make_unique<faiss::IndexFlatL2>(dim);
+
+  constexpr size_t num_lists = 16;
+  auto index_cmp = std::make_unique<faiss::IndexIVFFlat>(quantizer_cmp.get(),
+                                                         dim, num_lists);
+  auto index =
+      std::make_unique<faiss::IndexIVFFlat>(quantizer.get(), dim, num_lists);
+
+  {
+    constexpr faiss::idx_t num_train = 1024;
+    std::vector<float> embeddings_train(dim * num_train);
+    faiss::float_rand(embeddings_train.data(), dim * num_train, 42);
+
+    index_cmp->train(num_train, embeddings_train.data());
+    index->train(num_train, embeddings_train.data());
+  }
+
+  const std::string db_name = test::PerThreadDBPath("faiss_ivf_index_test");
+  EXPECT_OK(DestroyDB(db_name, Options()));
+
+  Options options;
+  options.create_if_missing = true;
+
+  TransactionDBOptions txn_db_options;
+  txn_db_options.secondary_indices.emplace_back(std::make_shared<FaissIVFIndex>(
+      std::move(index), kDefaultWideColumnName.ToString()));
+
+  TransactionDB* db = nullptr;
+  ASSERT_OK(TransactionDB::Open(options, txn_db_options, db_name, &db));
+
+  std::unique_ptr<TransactionDB> db_guard(db);
+
+  ColumnFamilyOptions cf1_opts;
+  ColumnFamilyHandle* cfh1 = nullptr;
+  ASSERT_OK(db->CreateColumnFamily(cf1_opts, "cf1", &cfh1));
+  std::unique_ptr<ColumnFamilyHandle> cfh1_guard(cfh1);
+
+  ColumnFamilyOptions cf2_opts;
+  ColumnFamilyHandle* cfh2 = nullptr;
+  ASSERT_OK(db->CreateColumnFamily(cf2_opts, "cf2", &cfh2));
+  std::unique_ptr<ColumnFamilyHandle> cfh2_guard(cfh2);
+
+  const auto& secondary_index = txn_db_options.secondary_indices.back();
+  secondary_index->SetPrimaryColumnFamily(cfh1);
+  secondary_index->SetSecondaryColumnFamily(cfh2);
+
+  // Add the same set of database vectors to both indices
+  constexpr faiss::idx_t num_db = 4096;
+
+  {
+    std::vector<float> embeddings_db(dim * num_db);
+    faiss::float_rand(embeddings_db.data(), dim * num_db, 123);
+
+    for (faiss::idx_t i = 0; i < num_db; ++i) {
+      const float* const embedding = embeddings_db.data() + i * dim;
+
+      index_cmp->add(1, embedding);
+
+      const std::string primary_key = std::to_string(i);
+      ASSERT_OK(db->Put(WriteOptions(), cfh1, primary_key,
+                        Slice(reinterpret_cast<const char*>(embedding),
+                              dim * sizeof(float))));
+    }
+  }
+
+  // Search both indices with the same set of query vectors and make sure the
+  // results match
+  {
+    constexpr faiss::idx_t num_query = 32;
+    std::vector<float> embeddings_query(dim * num_query);
+    faiss::float_rand(embeddings_query.data(), dim * num_query, 456);
+
+    for (size_t neighbors : {1, 2, 4}) {
+      for (size_t probes : {1, 2, 4}) {
+        std::unique_ptr<Iterator> underlying_it(
+            db->NewIterator(ReadOptions(), cfh2));
+
+        SecondaryIndexReadOptions read_options;
+        read_options.similarity_search_neighbors = neighbors;
+        read_options.similarity_search_probes = probes;
+
+        std::unique_ptr<Iterator> it =
+            txn_db_options.secondary_indices.back()->NewIterator(
+                read_options, std::move(underlying_it));
+
+        auto get_id = [&]() -> faiss::idx_t {
+          Slice key = it->key();
+          faiss::idx_t id = -1;
+
+          if (std::from_chars(key.data(), key.data() + key.size(), id).ec !=
+              std::errc()) {
+            return -1;
+          }
+
+          return id;
+        };
+
+        for (faiss::idx_t i = 0; i < num_query; ++i) {
+          const float* const embedding = embeddings_query.data() + i * dim;
+
+          std::vector<float> distances(neighbors, 0.0f);
+          std::vector<faiss::idx_t> ids(neighbors, -1);
+
+          faiss::SearchParametersIVF params;
+          params.nprobe = probes;
+
+          index_cmp->search(1, embedding, neighbors, distances.data(),
+                            ids.data(), &params);
+
+          size_t num_found_cmp = 0;
+          for (faiss::idx_t id : ids) {
+            if (id == -1) {
+              break;
+            }
+
+            ++num_found_cmp;
+          }
+
+          size_t num_found = 0;
+          for (it->Seek(Slice(reinterpret_cast<const char*>(embedding),
+                              dim * sizeof(float)));
+               it->Valid(); it->Next()) {
+            const faiss::idx_t id = get_id();
+            ASSERT_GE(id, 0);
+            ASSERT_LT(id, num_db);
+            ASSERT_EQ(id, ids[num_found]);
+
+            ++num_found;
+          }
+
+          ASSERT_OK(it->status());
+          ASSERT_EQ(num_found, num_found_cmp);
+        }
+      }
+    }
+  }
+}
+
 }  // namespace ROCKSDB_NAMESPACE
 
 int main(int argc, char** argv) {