Skip to content

Commit ef9e45a

Browse files
authored
Fix CAGRA-HNSW serialization format (#1108)
Signed-off-by: Mickael Ide <[email protected]>
1 parent 810a0c8 commit ef9e45a

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

src/common/cuvs/proto/cuvs_index.cuh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,20 @@ struct cuvs_index {
268268
bool include_dataset = true) {
269269
auto const& underlying_index = index.get_vector_index();
270270
if constexpr (vector_index_kind == cuvs_index_kind::cagra) {
271+
size_t metric_type;
272+
if (underlying_index.metric() == cuvs::distance::DistanceType::L2Expanded) {
273+
metric_type = 0;
274+
} else if (underlying_index.metric() == cuvs::distance::DistanceType::InnerProduct) {
275+
metric_type = 1;
276+
} else if (underlying_index.metric() == cuvs::distance::DistanceType::CosineExpanded) {
277+
metric_type = 2;
278+
}
279+
280+
os.write(reinterpret_cast<char*>(&metric_type), sizeof(metric_type));
281+
size_t data_size = underlying_index.dim() * sizeof(float);
282+
os.write(reinterpret_cast<char*>(&data_size), sizeof(data_size));
283+
size_t dim = underlying_index.dim();
284+
os.write(reinterpret_cast<char*>(&dim), sizeof(dim));
271285
return cuvs::neighbors::cagra::serialize_to_hnswlib(res, os, underlying_index);
272286
}
273287
}

tests/ut/test_gpu_search.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,15 @@ TEST_CASE("Test All GPU Index", "[search]") {
6565
return json;
6666
};
6767

68+
auto cagra_hnsw_gen = [](auto&& upstream_gen) {
69+
return [upstream_gen]() {
70+
knowhere::Json json = upstream_gen();
71+
json[knowhere::indexparam::ADAPT_FOR_CPU] = true;
72+
json[knowhere::indexparam::EF] = 128;
73+
return json;
74+
};
75+
};
76+
6877
auto refined_gen = [](auto&& upstream_gen) {
6978
return [upstream_gen]() {
7079
knowhere::Json json = upstream_gen();
@@ -210,6 +219,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
210219
make_tuple(knowhere::IndexEnum::INDEX_CUVS_IVFPQ, ivfpq_gen),
211220
make_tuple(knowhere::IndexEnum::INDEX_CUVS_IVFPQ, refined_gen(ivfpq_gen)),
212221
make_tuple(knowhere::IndexEnum::INDEX_CUVS_CAGRA, cagra_gen),
222+
make_tuple(knowhere::IndexEnum::INDEX_CUVS_CAGRA, cagra_hnsw_gen(cagra_gen)),
213223
}));
214224

215225
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();

0 commit comments

Comments
 (0)