From f12be15333b61b398439cb4c4bc99701488a2244 Mon Sep 17 00:00:00 2001 From: Andrew Prudhomme Date: Wed, 18 Dec 2024 15:25:33 -0800 Subject: [PATCH] Add vector search support for fields in nested objects --- .../server/field/VectorFieldDef.java | 44 ++++- .../field/properties/VectorQueryable.java | 5 +- .../nrtsearch/server/index/IndexState.java | 24 +++ ...iversifyingChildrenByteKnnVectorQuery.java | 57 ++++++ ...versifyingChildrenFloatKnnVectorQuery.java | 57 ++++++ .../server/search/SearchRequestProcessor.java | 24 ++- .../server/field/VectorFieldDefTest.java | 187 +++++++++++++++++- .../server/index/IndexStateTest.java | 122 ++++++++++++ .../registerFieldsNestedVectorSearch.json | 39 ++++ 9 files changed, 549 insertions(+), 10 deletions(-) create mode 100644 src/main/java/com/yelp/nrtsearch/server/query/vector/NrtDiversifyingChildrenByteKnnVectorQuery.java create mode 100644 src/main/java/com/yelp/nrtsearch/server/query/vector/NrtDiversifyingChildrenFloatKnnVectorQuery.java create mode 100644 src/test/java/com/yelp/nrtsearch/server/index/IndexStateTest.java create mode 100644 src/test/resources/field/registerFieldsNestedVectorSearch.json diff --git a/src/main/java/com/yelp/nrtsearch/server/field/VectorFieldDef.java b/src/main/java/com/yelp/nrtsearch/server/field/VectorFieldDef.java index a969baad9..7574906a0 100644 --- a/src/main/java/com/yelp/nrtsearch/server/field/VectorFieldDef.java +++ b/src/main/java/com/yelp/nrtsearch/server/field/VectorFieldDef.java @@ -25,6 +25,8 @@ import com.yelp.nrtsearch.server.grpc.Field; import com.yelp.nrtsearch.server.grpc.KnnQuery; import com.yelp.nrtsearch.server.grpc.VectorIndexingOptions; +import com.yelp.nrtsearch.server.query.vector.NrtDiversifyingChildrenByteKnnVectorQuery; +import com.yelp.nrtsearch.server.query.vector.NrtDiversifyingChildrenFloatKnnVectorQuery; import com.yelp.nrtsearch.server.query.vector.NrtKnnByteVectorQuery; import com.yelp.nrtsearch.server.query.vector.NrtKnnFloatVectorQuery; import com.yelp.nrtsearch.server.vector.ByteVectorType; @@ -51,6 +53,7 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.VectorUtil; @@ -113,9 +116,15 @@ enum VectorSearchType { * @param k number of nearest neighbors to return * @param filterQuery filter query * @param numCandidates number of candidates to search per segment + * @param parentBitSetProducer parent bit set producer for searching nested fields, or null * @return knn query */ - abstract Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandidates); + abstract Query getTypeKnnQuery( + KnnQuery knnQuery, + int k, + Query filterQuery, + int numCandidates, + BitSetProducer parentBitSetProducer); private static VectorSimilarityFunction getSimilarityFunction(String vectorSimilarity) { VectorSimilarityFunction similarityFunction = SIMILARITY_FUNCTION_MAP.get(vectorSimilarity); @@ -307,7 +316,8 @@ public void parseDocumentField( } @Override - public Query getKnnQuery(KnnQuery knnQuery, Query filterQuery) { + public Query getKnnQuery( + KnnQuery knnQuery, Query filterQuery, BitSetProducer parentBitSetProducer) { if (!isSearchable()) { throw new IllegalArgumentException("Vector field is not searchable: " + getName()); } @@ -324,7 +334,7 @@ public Query getKnnQuery(KnnQuery knnQuery, Query filterQuery) { throw new IllegalArgumentException("Vector search numCandidates > " + NUM_CANDIDATES_LIMIT); } - return getTypeKnnQuery(knnQuery, k, filterQuery, numCandidates); + return getTypeKnnQuery(knnQuery, k, filterQuery, numCandidates, parentBitSetProducer); } /** Field class for 'FLOAT' vector field type. */ @@ -420,7 +430,12 @@ static byte[] convertFloatArrToBytes(float[] floatArr) { } @Override - Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandidates) { + Query getTypeKnnQuery( + KnnQuery knnQuery, + int k, + Query filterQuery, + int numCandidates, + BitSetProducer parentBitSetProducer) { if (knnQuery.getQueryVectorCount() != getVectorDimensions()) { throw new IllegalArgumentException( "Invalid query vector size, expected: " @@ -433,7 +448,12 @@ Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandid queryVector[i] = knnQuery.getQueryVector(i); } validateVectorForSearch(queryVector); - return new NrtKnnFloatVectorQuery(getName(), queryVector, k, filterQuery, numCandidates); + if (parentBitSetProducer != null) { + return new NrtDiversifyingChildrenFloatKnnVectorQuery( + getName(), queryVector, filterQuery, k, numCandidates, parentBitSetProducer); + } else { + return new NrtKnnFloatVectorQuery(getName(), queryVector, k, filterQuery, numCandidates); + } } /** @@ -560,7 +580,12 @@ byte[] parseVectorFieldToByteArr(String fieldValueJson) { } @Override - Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandidates) { + Query getTypeKnnQuery( + KnnQuery knnQuery, + int k, + Query filterQuery, + int numCandidates, + BitSetProducer parentBitSetProducer) { if (knnQuery.getQueryByteVector().size() != getVectorDimensions()) { throw new IllegalArgumentException( "Invalid query byte vector size, expected: " @@ -570,7 +595,12 @@ Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandid } byte[] queryVector = knnQuery.getQueryByteVector().toByteArray(); validateVectorForSearch(queryVector); - return new NrtKnnByteVectorQuery(getName(), queryVector, k, filterQuery, numCandidates); + if (parentBitSetProducer != null) { + return new NrtDiversifyingChildrenByteKnnVectorQuery( + getName(), queryVector, filterQuery, k, numCandidates, parentBitSetProducer); + } else { + return new NrtKnnByteVectorQuery(getName(), queryVector, k, filterQuery, numCandidates); + } } /** diff --git a/src/main/java/com/yelp/nrtsearch/server/field/properties/VectorQueryable.java b/src/main/java/com/yelp/nrtsearch/server/field/properties/VectorQueryable.java index 18f86d71c..a833e1878 100644 --- a/src/main/java/com/yelp/nrtsearch/server/field/properties/VectorQueryable.java +++ b/src/main/java/com/yelp/nrtsearch/server/field/properties/VectorQueryable.java @@ -18,6 +18,7 @@ import com.yelp.nrtsearch.server.field.FieldDef; import com.yelp.nrtsearch.server.grpc.KnnQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; /** Trait interface for {@link FieldDef} types that can be queried by a {@link KnnQuery}. */ public interface VectorQueryable { @@ -26,7 +27,9 @@ public interface VectorQueryable { * * @param knnQuery knn query configuration * @param filterQuery query to filter knn search, or null + * @param parentBitSetProducer bit set producer for parent documents when searching a nested + * field, or null * @return lucene knn query */ - Query getKnnQuery(KnnQuery knnQuery, Query filterQuery); + Query getKnnQuery(KnnQuery knnQuery, Query filterQuery, BitSetProducer parentBitSetProducer); } diff --git a/src/main/java/com/yelp/nrtsearch/server/index/IndexState.java b/src/main/java/com/yelp/nrtsearch/server/index/IndexState.java index 75452bd82..e2e3f15fd 100644 --- a/src/main/java/com/yelp/nrtsearch/server/index/IndexState.java +++ b/src/main/java/com/yelp/nrtsearch/server/index/IndexState.java @@ -317,6 +317,30 @@ public String resolveQueryNestedPath(String path) { throw new IllegalArgumentException("Nested path is not a nested object field: " + path); } + /** + * Get the base path for the nested document containing the field at the given path. For fields in + * the base document, this returns _root. The base nested path for _root is null. + * + * @param path field path + * @return nested base path, or null + */ + public static String getFieldBaseNestedPath(String path, IndexState indexState) { + Objects.requireNonNull(path, "path cannot be null"); + if (path.equals(IndexState.ROOT)) { + return null; + } + + String currentPath = path; + while (currentPath.contains(".")) { + currentPath = currentPath.substring(0, currentPath.lastIndexOf(".")); + FieldDef fieldDef = indexState.getFieldOrThrow(currentPath); + if (fieldDef instanceof ObjectFieldDef objFieldDef && objFieldDef.isNestedDoc()) { + return currentPath; + } + } + return IndexState.ROOT; + } + /** Get index state info. */ public abstract IndexStateInfo getIndexStateInfo(); diff --git a/src/main/java/com/yelp/nrtsearch/server/query/vector/NrtDiversifyingChildrenByteKnnVectorQuery.java b/src/main/java/com/yelp/nrtsearch/server/query/vector/NrtDiversifyingChildrenByteKnnVectorQuery.java new file mode 100644 index 000000000..ebfeb53de --- /dev/null +++ b/src/main/java/com/yelp/nrtsearch/server/query/vector/NrtDiversifyingChildrenByteKnnVectorQuery.java @@ -0,0 +1,57 @@ +/* + * Copyright 2024 Yelp Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.yelp.nrtsearch.server.query.vector; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; + +/** + * A {@link DiversifyingChildrenByteKnnVectorQuery} that has its functionality slightly modified. + * The {@link TotalHits} from the vector search are available after the query has been rewritten + * using the {@link WithVectorTotalHits} interface. The results merging has also been modified to + * produce the top k hits from the top numCandidates hits from each leaf. + */ +public class NrtDiversifyingChildrenByteKnnVectorQuery + extends DiversifyingChildrenByteKnnVectorQuery implements WithVectorTotalHits { + private final int topHits; + private TotalHits totalHits; + + public NrtDiversifyingChildrenByteKnnVectorQuery( + String field, + byte[] target, + Query filter, + int k, + int numCandidates, + BitSetProducer parentsFilter) { + super(field, target, filter, numCandidates, parentsFilter); + this.topHits = k; + } + + @Override + public TotalHits getTotalHits() { + return totalHits; + } + + @Override + protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { + TopDocs topDocs = TopDocs.merge(topHits, perLeafResults); + totalHits = topDocs.totalHits; + return topDocs; + } +} diff --git a/src/main/java/com/yelp/nrtsearch/server/query/vector/NrtDiversifyingChildrenFloatKnnVectorQuery.java b/src/main/java/com/yelp/nrtsearch/server/query/vector/NrtDiversifyingChildrenFloatKnnVectorQuery.java new file mode 100644 index 000000000..9759a61bf --- /dev/null +++ b/src/main/java/com/yelp/nrtsearch/server/query/vector/NrtDiversifyingChildrenFloatKnnVectorQuery.java @@ -0,0 +1,57 @@ +/* + * Copyright 2024 Yelp Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.yelp.nrtsearch.server.query.vector; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; + +/** + * A {@link DiversifyingChildrenFloatKnnVectorQuery} that has its functionality slightly modified. + * The {@link TotalHits} from the vector search are available after the query has been rewritten + * using the {@link WithVectorTotalHits} interface. The results merging has also been modified to + * produce the top k hits from the top numCandidates hits from each leaf. + */ +public class NrtDiversifyingChildrenFloatKnnVectorQuery + extends DiversifyingChildrenFloatKnnVectorQuery implements WithVectorTotalHits { + private final int topHits; + private TotalHits totalHits; + + public NrtDiversifyingChildrenFloatKnnVectorQuery( + String field, + float[] target, + Query filter, + int k, + int numCandidates, + BitSetProducer parentsFilter) { + super(field, target, filter, numCandidates, parentsFilter); + this.topHits = k; + } + + @Override + public TotalHits getTotalHits() { + return totalHits; + } + + @Override + protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { + TopDocs topDocs = TopDocs.merge(topHits, perLeafResults); + totalHits = topDocs.totalHits; + return topDocs; + } +} diff --git a/src/main/java/com/yelp/nrtsearch/server/search/SearchRequestProcessor.java b/src/main/java/com/yelp/nrtsearch/server/search/SearchRequestProcessor.java index 86c7ae713..3d4cad04b 100644 --- a/src/main/java/com/yelp/nrtsearch/server/search/SearchRequestProcessor.java +++ b/src/main/java/com/yelp/nrtsearch/server/search/SearchRequestProcessor.java @@ -76,6 +76,9 @@ import org.apache.lucene.queryparser.classic.QueryParserBase; import org.apache.lucene.queryparser.simple.SimpleQueryParser; import org.apache.lucene.search.*; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.QueryBitSetProducer; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.apache.lucene.util.QueryBuilder; /** @@ -270,13 +273,32 @@ private static Query buildKnnQuery(KnnQuery knnQuery, IndexState indexState) { throw new IllegalArgumentException("Field does not support vector search: " + field); } + // Path to nested document containing this field + String fieldNestedPath = IndexState.getFieldBaseNestedPath(field, indexState); + // Path to parent document, this will be null if the field is in the root document + String parentNestedPath = IndexState.getFieldBaseNestedPath(fieldNestedPath, indexState); + Query filterQuery; if (knnQuery.hasFilter()) { filterQuery = QueryNodeMapper.getInstance().getQuery(knnQuery.getFilter(), indexState); } else { filterQuery = null; } - return vectorQueryable.getKnnQuery(knnQuery, filterQuery); + + BitSetProducer parentBitSetProducer = null; + if (parentNestedPath != null) { + Query parentQuery = + QueryNodeMapper.getInstance().getNestedPathQuery(indexState, parentNestedPath); + parentBitSetProducer = new QueryBitSetProducer(parentQuery); + if (filterQuery != null) { + // Filter query is applied to the parent document only + filterQuery = + QueryNodeMapper.getInstance() + .applyQueryNestedPath(filterQuery, indexState, parentNestedPath); + filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSetProducer); + } + } + return vectorQueryable.getKnnQuery(knnQuery, filterQuery, parentBitSetProducer); } /** diff --git a/src/test/java/com/yelp/nrtsearch/server/field/VectorFieldDefTest.java b/src/test/java/com/yelp/nrtsearch/server/field/VectorFieldDefTest.java index f9821f50b..1157f596a 100644 --- a/src/test/java/com/yelp/nrtsearch/server/field/VectorFieldDefTest.java +++ b/src/test/java/com/yelp/nrtsearch/server/field/VectorFieldDefTest.java @@ -54,6 +54,7 @@ public class VectorFieldDefTest extends ServerTestCase { @ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); public static final String VECTOR_SEARCH_INDEX_NAME = "vector_search_index"; + public static final String NESTED_VECTOR_SEARCH_INDEX_NAME = "nested_vector_search_index"; private static final String FIELD_NAME = "vector_field"; private static final String FIELD_TYPE = "VECTOR"; private static final List VECTOR_FIELD_VALUES = @@ -61,7 +62,7 @@ public class VectorFieldDefTest extends ServerTestCase { @Override protected List getIndices() { - return List.of(DEFAULT_TEST_INDEX, VECTOR_SEARCH_INDEX_NAME); + return List.of(DEFAULT_TEST_INDEX, VECTOR_SEARCH_INDEX_NAME, NESTED_VECTOR_SEARCH_INDEX_NAME); } private Map getFieldsMapForOneDocument(String value) { @@ -163,6 +164,39 @@ private void indexVectorSearchDocs() throws Exception { } } + private void indexNestedVectorSearchDocs() throws Exception { + List docs = new ArrayList<>(); + docs.add( + AddDocumentRequest.newBuilder() + .setIndexName(NESTED_VECTOR_SEARCH_INDEX_NAME) + .putFields("id", MultiValuedField.newBuilder().addValue("0").build()) + .putFields("filter_field", MultiValuedField.newBuilder().addValue("1").build()) + .putFields( + "nested_object", + MultiValuedField.newBuilder() + .addValue( + "{\"float_vector\": \"[0.25, 0.5, 0.1]\", \"byte_vector\": \"[-50, -5, 0]\"}") + .addValue( + "{\"float_vector\": \"[0.2, 0.4, 0.3]\", \"byte_vector\": \"[-8, -10, -1]\"}") + .build()) + .build()); + docs.add( + AddDocumentRequest.newBuilder() + .setIndexName(NESTED_VECTOR_SEARCH_INDEX_NAME) + .putFields("id", MultiValuedField.newBuilder().addValue("1").build()) + .putFields("filter_field", MultiValuedField.newBuilder().addValue("2").build()) + .putFields( + "nested_object", + MultiValuedField.newBuilder() + .addValue( + "{\"float_vector\": \"[0.75, 0.9, 0.6]\", \"byte_vector\": \"[50, 5, 0]\"}") + .addValue( + "{\"float_vector\": \"[0.7, 0.8, 0.9]\", \"byte_vector\": \"[8, 10, 1]\"}") + .build()) + .build()); + addDocuments(docs.stream()); + } + private String createVectorString(Random random, int size, boolean normalize) { List vector = new ArrayList<>(); for (int i = 0; i < size; ++i) { @@ -201,6 +235,8 @@ public FieldDefRequest getIndexDef(String name) throws IOException { return getFieldsFromResourceFile("/field/registerFieldsVector.json"); } else if (VECTOR_SEARCH_INDEX_NAME.equals(name)) { return getFieldsFromResourceFile("/field/registerFieldsVectorSearch.json"); + } else if (NESTED_VECTOR_SEARCH_INDEX_NAME.equals(name)) { + return getFieldsFromResourceFile("/field/registerFieldsNestedVectorSearch.json"); } throw new IllegalArgumentException("Unknown index name: " + name); } @@ -212,6 +248,8 @@ public void initIndex(String name) throws Exception { addDocuments(documents.stream()); } else if (VECTOR_SEARCH_INDEX_NAME.equals(name)) { indexVectorSearchDocs(); + } else if (NESTED_VECTOR_SEARCH_INDEX_NAME.equals(name)) { + indexNestedVectorSearchDocs(); } else { throw new IllegalArgumentException("Unknown index name: " + name); } @@ -1427,4 +1465,151 @@ public void testInvalid4BitsOddDimensions() { e.getMessage()); } } + + @Test + public void testNestedFloatVectorSearch() { + SearchResponse response = + getGrpcServer() + .getBlockingStub() + .search( + SearchRequest.newBuilder() + .setIndexName(NESTED_VECTOR_SEARCH_INDEX_NAME) + .setStartHit(0) + .setTopHits(10) + .addRetrieveFields("id") + .addKnn( + KnnQuery.newBuilder() + .setField("nested_object.float_vector") + .addAllQueryVector(List.of(0.6f, 0.5f, 0.75f)) + .setNumCandidates(10) + .setK(5) + .build()) + .build()); + assertEquals(2, response.getHitsCount()); + assertEquals("1", response.getHits(0).getFieldsOrThrow("id").getFieldValue(0).getTextValue()); + assertEquals( + VectorSimilarityFunction.COSINE.compare( + new float[] {0.6f, 0.5f, 0.75f}, new float[] {0.7f, 0.8f, 0.9f}), + response.getHits(0).getScore(), + 0.0001); + assertEquals("0", response.getHits(1).getFieldsOrThrow("id").getFieldValue(0).getTextValue()); + assertEquals( + VectorSimilarityFunction.COSINE.compare( + new float[] {0.6f, 0.5f, 0.75f}, new float[] {0.2f, 0.4f, 0.3f}), + response.getHits(1).getScore(), + 0.0001); + + assertEquals(1, response.getDiagnostics().getVectorDiagnosticsCount()); + VectorDiagnostics vectorDiagnostics = response.getDiagnostics().getVectorDiagnostics(0); + assertTrue(vectorDiagnostics.getSearchTimeMs() > 0.0); + assertEquals(4, vectorDiagnostics.getTotalHits().getValue()); + } + + @Test + public void testNestedFloatVectorSearchWithFilter() { + SearchResponse response = + getGrpcServer() + .getBlockingStub() + .search( + SearchRequest.newBuilder() + .setIndexName(NESTED_VECTOR_SEARCH_INDEX_NAME) + .setStartHit(0) + .setTopHits(10) + .addRetrieveFields("id") + .addKnn( + KnnQuery.newBuilder() + .setField("nested_object.float_vector") + .addAllQueryVector(List.of(0.6f, 0.5f, 0.75f)) + .setNumCandidates(10) + .setK(5) + .setFilter( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("filter_field") + .setIntValue(1) + .build()) + .build()) + .build()) + .build()); + assertEquals(1, response.getHitsCount()); + assertEquals("0", response.getHits(0).getFieldsOrThrow("id").getFieldValue(0).getTextValue()); + assertEquals( + VectorSimilarityFunction.COSINE.compare( + new float[] {0.6f, 0.5f, 0.75f}, new float[] {0.2f, 0.4f, 0.3f}), + response.getHits(0).getScore(), + 0.0001); + } + + @Test + public void testNestedByteVectorSearch() { + SearchResponse response = + getGrpcServer() + .getBlockingStub() + .search( + SearchRequest.newBuilder() + .setIndexName(NESTED_VECTOR_SEARCH_INDEX_NAME) + .setStartHit(0) + .setTopHits(10) + .addRetrieveFields("id") + .addKnn( + KnnQuery.newBuilder() + .setField("nested_object.byte_vector") + .setQueryByteVector(ByteString.copyFrom(new byte[] {1, 2, 3})) + .setNumCandidates(10) + .setK(5) + .build()) + .build()); + assertEquals(2, response.getHitsCount()); + assertEquals("1", response.getHits(0).getFieldsOrThrow("id").getFieldValue(0).getTextValue()); + assertEquals( + VectorSimilarityFunction.COSINE.compare(new byte[] {1, 2, 3}, new byte[] {8, 10, 1}), + response.getHits(0).getScore(), + 0.0001); + assertEquals("0", response.getHits(1).getFieldsOrThrow("id").getFieldValue(0).getTextValue()); + assertEquals( + VectorSimilarityFunction.COSINE.compare(new byte[] {1, 2, 3}, new byte[] {-50, -5, 0}), + response.getHits(1).getScore(), + 0.0001); + + assertEquals(1, response.getDiagnostics().getVectorDiagnosticsCount()); + VectorDiagnostics vectorDiagnostics = response.getDiagnostics().getVectorDiagnostics(0); + assertTrue(vectorDiagnostics.getSearchTimeMs() > 0.0); + assertEquals(4, vectorDiagnostics.getTotalHits().getValue()); + } + + @Test + public void testNestedByteVectorSearchWithFilter() { + SearchResponse response = + getGrpcServer() + .getBlockingStub() + .search( + SearchRequest.newBuilder() + .setIndexName(NESTED_VECTOR_SEARCH_INDEX_NAME) + .setStartHit(0) + .setTopHits(10) + .addRetrieveFields("id") + .addKnn( + KnnQuery.newBuilder() + .setField("nested_object.byte_vector") + .setQueryByteVector(ByteString.copyFrom(new byte[] {1, 2, 3})) + .setNumCandidates(10) + .setK(5) + .setFilter( + Query.newBuilder() + .setTermQuery( + TermQuery.newBuilder() + .setField("filter_field") + .setIntValue(1) + .build()) + .build()) + .build()) + .build()); + assertEquals(1, response.getHitsCount()); + assertEquals("0", response.getHits(0).getFieldsOrThrow("id").getFieldValue(0).getTextValue()); + assertEquals( + VectorSimilarityFunction.COSINE.compare(new byte[] {1, 2, 3}, new byte[] {-50, -5, 0}), + response.getHits(0).getScore(), + 0.0001); + } } diff --git a/src/test/java/com/yelp/nrtsearch/server/index/IndexStateTest.java b/src/test/java/com/yelp/nrtsearch/server/index/IndexStateTest.java new file mode 100644 index 000000000..133f486ab --- /dev/null +++ b/src/test/java/com/yelp/nrtsearch/server/index/IndexStateTest.java @@ -0,0 +1,122 @@ +/* + * Copyright 2024 Yelp Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.yelp.nrtsearch.server.index; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.yelp.nrtsearch.server.field.ObjectFieldDef; +import com.yelp.nrtsearch.server.field.TextFieldDef; +import org.junit.Test; + +public class IndexStateTest { + @Test + public void testGetFieldBaseNestedPath_null() { + IndexState mockState = mock(IndexState.class); + try { + IndexState.getFieldBaseNestedPath(null, mockState); + fail(); + } catch (NullPointerException e) { + // Expected + } + } + + @Test + public void testGetFieldBaseNestedPath_root() { + IndexState mockState = mock(IndexState.class); + String path = IndexState.getFieldBaseNestedPath(IndexState.ROOT, mockState); + assertNull(path); + } + + @Test + public void testGetFieldBaseNestedPath_inBaseDoc() { + IndexState mockState = mock(IndexState.class); + when(mockState.getFieldOrThrow("field")).thenReturn(mock(TextFieldDef.class)); + String path = IndexState.getFieldBaseNestedPath("field", mockState); + assertEquals(IndexState.ROOT, path); + } + + @Test + public void testGetFieldBaseNestedPath_objectInBaseDoc() { + IndexState mockState = mock(IndexState.class); + ObjectFieldDef mockObject = mock(ObjectFieldDef.class); + when(mockObject.isNestedDoc()).thenReturn(false); + when(mockState.getFieldOrThrow("object")).thenReturn(mockObject); + String path = IndexState.getFieldBaseNestedPath("object", mockState); + assertEquals(IndexState.ROOT, path); + } + + @Test + public void testGetFieldBaseNestedPath_nestedObjectInBaseDoc() { + IndexState mockState = mock(IndexState.class); + ObjectFieldDef mockObject = mock(ObjectFieldDef.class); + when(mockObject.isNestedDoc()).thenReturn(true); + when(mockState.getFieldOrThrow("object")).thenReturn(mockObject); + String path = IndexState.getFieldBaseNestedPath("object", mockState); + assertEquals(IndexState.ROOT, path); + } + + @Test + public void testGetFieldBaseNestedPath_fieldOfObject() { + IndexState mockState = mock(IndexState.class); + ObjectFieldDef mockObject = mock(ObjectFieldDef.class); + when(mockObject.isNestedDoc()).thenReturn(false); + when(mockState.getFieldOrThrow("object")).thenReturn(mockObject); + when(mockState.getFieldOrThrow("object.field")).thenReturn(mock(TextFieldDef.class)); + String path = IndexState.getFieldBaseNestedPath("object.field", mockState); + assertEquals(IndexState.ROOT, path); + } + + @Test + public void testGetFieldBaseNestedPath_fieldOfNestedObject() { + IndexState mockState = mock(IndexState.class); + ObjectFieldDef mockObject = mock(ObjectFieldDef.class); + when(mockObject.isNestedDoc()).thenReturn(true); + when(mockState.getFieldOrThrow("object")).thenReturn(mockObject); + when(mockState.getFieldOrThrow("object.field")).thenReturn(mock(TextFieldDef.class)); + String path = IndexState.getFieldBaseNestedPath("object.field", mockState); + assertEquals("object", path); + } + + @Test + public void testGetFieldBaseNestedPath_multipleNestedObjects() { + IndexState mockState = mock(IndexState.class); + ObjectFieldDef mockObject = mock(ObjectFieldDef.class); + when(mockObject.isNestedDoc()).thenReturn(true); + when(mockState.getFieldOrThrow("object1")).thenReturn(mockObject); + when(mockState.getFieldOrThrow("object1.object2")).thenReturn(mockObject); + when(mockState.getFieldOrThrow("object1.object2.field")).thenReturn(mock(TextFieldDef.class)); + String path = IndexState.getFieldBaseNestedPath("object1.object2.field", mockState); + assertEquals("object1.object2", path); + } + + @Test + public void testGetFieldBaseNestedPath_unknownField() { + IndexState mockState = mock(IndexState.class); + ObjectFieldDef mockObject = mock(ObjectFieldDef.class); + when(mockObject.isNestedDoc()).thenReturn(true); + when(mockState.getFieldOrThrow("object")).thenThrow(new IllegalArgumentException("error")); + try { + IndexState.getFieldBaseNestedPath("object.field", mockState); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("error", e.getMessage()); + } + } +} diff --git a/src/test/resources/field/registerFieldsNestedVectorSearch.json b/src/test/resources/field/registerFieldsNestedVectorSearch.json new file mode 100644 index 000000000..27dfac999 --- /dev/null +++ b/src/test/resources/field/registerFieldsNestedVectorSearch.json @@ -0,0 +1,39 @@ +{ + "indexName": "nested_vector_search_index", + "field": [ + { + "name": "id", + "type": "_ID", + "storeDocValues": true + }, + { + "name": "filter_field", + "type": "INT", + "search": true, + "storeDocValues": true + }, + { + "name": "nested_object", + "type": "OBJECT", + "nestedDoc": true, + "multiValued": true, + "childFields": [ + { + "name": "float_vector", + "type": "VECTOR", + "search": true, + "vectorDimensions": 3, + "vectorSimilarity": "cosine" + }, + { + "name": "byte_vector", + "type": "VECTOR", + "search": true, + "vectorDimensions": 3, + "vectorSimilarity": "cosine", + "vectorElementType": "VECTOR_ELEMENT_BYTE" + } + ] + } + ] +} \ No newline at end of file