Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vector search support for fields in nested objects #796

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions src/main/java/com/yelp/nrtsearch/server/field/VectorFieldDef.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import com.yelp.nrtsearch.server.grpc.Field;
import com.yelp.nrtsearch.server.grpc.KnnQuery;
import com.yelp.nrtsearch.server.grpc.VectorIndexingOptions;
import com.yelp.nrtsearch.server.query.vector.NrtDiversifyingChildrenByteKnnVectorQuery;
import com.yelp.nrtsearch.server.query.vector.NrtDiversifyingChildrenFloatKnnVectorQuery;
import com.yelp.nrtsearch.server.query.vector.NrtKnnByteVectorQuery;
import com.yelp.nrtsearch.server.query.vector.NrtKnnFloatVectorQuery;
import com.yelp.nrtsearch.server.vector.ByteVectorType;
Expand All @@ -51,6 +53,7 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;

Expand Down Expand Up @@ -113,9 +116,15 @@ enum VectorSearchType {
* @param k number of nearest neighbors to return
* @param filterQuery filter query
* @param numCandidates number of candidates to search per segment
* @param parentBitSetProducer parent bit set producer for searching nested fields, or null
* @return knn query
*/
abstract Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandidates);
abstract Query getTypeKnnQuery(
KnnQuery knnQuery,
int k,
Query filterQuery,
int numCandidates,
BitSetProducer parentBitSetProducer);

private static VectorSimilarityFunction getSimilarityFunction(String vectorSimilarity) {
VectorSimilarityFunction similarityFunction = SIMILARITY_FUNCTION_MAP.get(vectorSimilarity);
Expand Down Expand Up @@ -307,7 +316,8 @@ public void parseDocumentField(
}

@Override
public Query getKnnQuery(KnnQuery knnQuery, Query filterQuery) {
public Query getKnnQuery(
KnnQuery knnQuery, Query filterQuery, BitSetProducer parentBitSetProducer) {
if (!isSearchable()) {
throw new IllegalArgumentException("Vector field is not searchable: " + getName());
}
Expand All @@ -324,7 +334,7 @@ public Query getKnnQuery(KnnQuery knnQuery, Query filterQuery) {
throw new IllegalArgumentException("Vector search numCandidates > " + NUM_CANDIDATES_LIMIT);
}

return getTypeKnnQuery(knnQuery, k, filterQuery, numCandidates);
return getTypeKnnQuery(knnQuery, k, filterQuery, numCandidates, parentBitSetProducer);
}

/** Field class for 'FLOAT' vector field type. */
Expand Down Expand Up @@ -420,7 +430,12 @@ static byte[] convertFloatArrToBytes(float[] floatArr) {
}

@Override
Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandidates) {
Query getTypeKnnQuery(
KnnQuery knnQuery,
int k,
Query filterQuery,
int numCandidates,
BitSetProducer parentBitSetProducer) {
if (knnQuery.getQueryVectorCount() != getVectorDimensions()) {
throw new IllegalArgumentException(
"Invalid query vector size, expected: "
Expand All @@ -433,7 +448,12 @@ Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandid
queryVector[i] = knnQuery.getQueryVector(i);
}
validateVectorForSearch(queryVector);
return new NrtKnnFloatVectorQuery(getName(), queryVector, k, filterQuery, numCandidates);
if (parentBitSetProducer != null) {
return new NrtDiversifyingChildrenFloatKnnVectorQuery(
getName(), queryVector, filterQuery, k, numCandidates, parentBitSetProducer);
} else {
return new NrtKnnFloatVectorQuery(getName(), queryVector, k, filterQuery, numCandidates);
}
}

/**
Expand Down Expand Up @@ -560,7 +580,12 @@ byte[] parseVectorFieldToByteArr(String fieldValueJson) {
}

@Override
Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandidates) {
Query getTypeKnnQuery(
KnnQuery knnQuery,
int k,
Query filterQuery,
int numCandidates,
BitSetProducer parentBitSetProducer) {
if (knnQuery.getQueryByteVector().size() != getVectorDimensions()) {
throw new IllegalArgumentException(
"Invalid query byte vector size, expected: "
Expand All @@ -570,7 +595,12 @@ Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandid
}
byte[] queryVector = knnQuery.getQueryByteVector().toByteArray();
validateVectorForSearch(queryVector);
return new NrtKnnByteVectorQuery(getName(), queryVector, k, filterQuery, numCandidates);
if (parentBitSetProducer != null) {
return new NrtDiversifyingChildrenByteKnnVectorQuery(
getName(), queryVector, filterQuery, k, numCandidates, parentBitSetProducer);
} else {
return new NrtKnnByteVectorQuery(getName(), queryVector, k, filterQuery, numCandidates);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.yelp.nrtsearch.server.field.FieldDef;
import com.yelp.nrtsearch.server.grpc.KnnQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;

/** Trait interface for {@link FieldDef} types that can be queried by a {@link KnnQuery}. */
public interface VectorQueryable {
Expand All @@ -26,7 +27,9 @@ public interface VectorQueryable {
*
* @param knnQuery knn query configuration
* @param filterQuery query to filter knn search, or null
* @param parentBitSetProducer bit set producer for parent documents when searching a nested
* field, or null
* @return lucene knn query
*/
Query getKnnQuery(KnnQuery knnQuery, Query filterQuery);
Query getKnnQuery(KnnQuery knnQuery, Query filterQuery, BitSetProducer parentBitSetProducer);
}
24 changes: 24 additions & 0 deletions src/main/java/com/yelp/nrtsearch/server/index/IndexState.java
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,30 @@ public String resolveQueryNestedPath(String path) {
throw new IllegalArgumentException("Nested path is not a nested object field: " + path);
}

/**
* Get the base path for the nested document containing the field at the given path. For fields in
* the base document, this returns _root. The base nested path for _root is null.
*
* @param path field path
* @return nested base path, or null
*/
public static String getFieldBaseNestedPath(String path, IndexState indexState) {
Objects.requireNonNull(path, "path cannot be null");
if (path.equals(IndexState.ROOT)) {
return null;
}

String currentPath = path;
while (currentPath.contains(".")) {
currentPath = currentPath.substring(0, currentPath.lastIndexOf("."));
FieldDef fieldDef = indexState.getFieldOrThrow(currentPath);
if (fieldDef instanceof ObjectFieldDef objFieldDef && objFieldDef.isNestedDoc()) {
return currentPath;
}
}
return IndexState.ROOT;
}

/** Get index state info. */
public abstract IndexStateInfo getIndexStateInfo();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.query.vector;

import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;

/**
* A {@link DiversifyingChildrenByteKnnVectorQuery} that has its functionality slightly modified.
* The {@link TotalHits} from the vector search are available after the query has been rewritten
* using the {@link WithVectorTotalHits} interface. The results merging has also been modified to
* produce the top k hits from the top numCandidates hits from each leaf.
*/
public class NrtDiversifyingChildrenByteKnnVectorQuery
extends DiversifyingChildrenByteKnnVectorQuery implements WithVectorTotalHits {
private final int topHits;
private TotalHits totalHits;

public NrtDiversifyingChildrenByteKnnVectorQuery(
String field,
byte[] target,
Query filter,
int k,
int numCandidates,
BitSetProducer parentsFilter) {
super(field, target, filter, numCandidates, parentsFilter);
this.topHits = k;
}

@Override
public TotalHits getTotalHits() {
return totalHits;
}

@Override
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
TopDocs topDocs = TopDocs.merge(topHits, perLeafResults);
totalHits = topDocs.totalHits;
return topDocs;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.query.vector;

import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;

/**
* A {@link DiversifyingChildrenFloatKnnVectorQuery} that has its functionality slightly modified.
* The {@link TotalHits} from the vector search are available after the query has been rewritten
* using the {@link WithVectorTotalHits} interface. The results merging has also been modified to
* produce the top k hits from the top numCandidates hits from each leaf.
*/
public class NrtDiversifyingChildrenFloatKnnVectorQuery
extends DiversifyingChildrenFloatKnnVectorQuery implements WithVectorTotalHits {
private final int topHits;
private TotalHits totalHits;

public NrtDiversifyingChildrenFloatKnnVectorQuery(
String field,
float[] target,
Query filter,
int k,
int numCandidates,
BitSetProducer parentsFilter) {
super(field, target, filter, numCandidates, parentsFilter);
this.topHits = k;
}

@Override
public TotalHits getTotalHits() {
return totalHits;
}

@Override
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
TopDocs topDocs = TopDocs.merge(topHits, perLeafResults);
totalHits = topDocs.totalHits;
return topDocs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@
import org.apache.lucene.queryparser.classic.QueryParserBase;
import org.apache.lucene.queryparser.simple.SimpleQueryParser;
import org.apache.lucene.search.*;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.QueryBitSetProducer;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.apache.lucene.util.QueryBuilder;

/**
Expand Down Expand Up @@ -270,13 +273,32 @@ private static Query buildKnnQuery(KnnQuery knnQuery, IndexState indexState) {
throw new IllegalArgumentException("Field does not support vector search: " + field);
}

// Path to nested document containing this field
String fieldNestedPath = IndexState.getFieldBaseNestedPath(field, indexState);
// Path to parent document, this will be null if the field is in the root document
String parentNestedPath = IndexState.getFieldBaseNestedPath(fieldNestedPath, indexState);

Query filterQuery;
if (knnQuery.hasFilter()) {
filterQuery = QueryNodeMapper.getInstance().getQuery(knnQuery.getFilter(), indexState);
} else {
filterQuery = null;
}
return vectorQueryable.getKnnQuery(knnQuery, filterQuery);

BitSetProducer parentBitSetProducer = null;
if (parentNestedPath != null) {
Query parentQuery =
QueryNodeMapper.getInstance().getNestedPathQuery(indexState, parentNestedPath);
parentBitSetProducer = new QueryBitSetProducer(parentQuery);
if (filterQuery != null) {
// Filter query is applied to the parent document only
filterQuery =
QueryNodeMapper.getInstance()
.applyQueryNestedPath(filterQuery, indexState, parentNestedPath);
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSetProducer);
}
}
return vectorQueryable.getKnnQuery(knnQuery, filterQuery, parentBitSetProducer);
}

/**
Expand Down
Loading
Loading