Skip to content

Commit

Permalink
Add vector search support for fields in nested objects
Browse files Browse the repository at this point in the history
  • Loading branch information
aprudhomme committed Dec 18, 2024
1 parent 8ead293 commit f12be15
Show file tree
Hide file tree
Showing 9 changed files with 549 additions and 10 deletions.
44 changes: 37 additions & 7 deletions src/main/java/com/yelp/nrtsearch/server/field/VectorFieldDef.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import com.yelp.nrtsearch.server.grpc.Field;
import com.yelp.nrtsearch.server.grpc.KnnQuery;
import com.yelp.nrtsearch.server.grpc.VectorIndexingOptions;
import com.yelp.nrtsearch.server.query.vector.NrtDiversifyingChildrenByteKnnVectorQuery;
import com.yelp.nrtsearch.server.query.vector.NrtDiversifyingChildrenFloatKnnVectorQuery;
import com.yelp.nrtsearch.server.query.vector.NrtKnnByteVectorQuery;
import com.yelp.nrtsearch.server.query.vector.NrtKnnFloatVectorQuery;
import com.yelp.nrtsearch.server.vector.ByteVectorType;
Expand All @@ -51,6 +53,7 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;

Expand Down Expand Up @@ -113,9 +116,15 @@ enum VectorSearchType {
* @param k number of nearest neighbors to return
* @param filterQuery filter query
* @param numCandidates number of candidates to search per segment
* @param parentBitSetProducer parent bit set producer for searching nested fields, or null
* @return knn query
*/
abstract Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandidates);
abstract Query getTypeKnnQuery(
KnnQuery knnQuery,
int k,
Query filterQuery,
int numCandidates,
BitSetProducer parentBitSetProducer);

private static VectorSimilarityFunction getSimilarityFunction(String vectorSimilarity) {
VectorSimilarityFunction similarityFunction = SIMILARITY_FUNCTION_MAP.get(vectorSimilarity);
Expand Down Expand Up @@ -307,7 +316,8 @@ public void parseDocumentField(
}

@Override
public Query getKnnQuery(KnnQuery knnQuery, Query filterQuery) {
public Query getKnnQuery(
KnnQuery knnQuery, Query filterQuery, BitSetProducer parentBitSetProducer) {
if (!isSearchable()) {
throw new IllegalArgumentException("Vector field is not searchable: " + getName());
}
Expand All @@ -324,7 +334,7 @@ public Query getKnnQuery(KnnQuery knnQuery, Query filterQuery) {
throw new IllegalArgumentException("Vector search numCandidates > " + NUM_CANDIDATES_LIMIT);
}

return getTypeKnnQuery(knnQuery, k, filterQuery, numCandidates);
return getTypeKnnQuery(knnQuery, k, filterQuery, numCandidates, parentBitSetProducer);
}

/** Field class for 'FLOAT' vector field type. */
Expand Down Expand Up @@ -420,7 +430,12 @@ static byte[] convertFloatArrToBytes(float[] floatArr) {
}

@Override
Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandidates) {
Query getTypeKnnQuery(
KnnQuery knnQuery,
int k,
Query filterQuery,
int numCandidates,
BitSetProducer parentBitSetProducer) {
if (knnQuery.getQueryVectorCount() != getVectorDimensions()) {
throw new IllegalArgumentException(
"Invalid query vector size, expected: "
Expand All @@ -433,7 +448,12 @@ Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandid
queryVector[i] = knnQuery.getQueryVector(i);
}
validateVectorForSearch(queryVector);
return new NrtKnnFloatVectorQuery(getName(), queryVector, k, filterQuery, numCandidates);
if (parentBitSetProducer != null) {
return new NrtDiversifyingChildrenFloatKnnVectorQuery(
getName(), queryVector, filterQuery, k, numCandidates, parentBitSetProducer);
} else {
return new NrtKnnFloatVectorQuery(getName(), queryVector, k, filterQuery, numCandidates);
}
}

/**
Expand Down Expand Up @@ -560,7 +580,12 @@ byte[] parseVectorFieldToByteArr(String fieldValueJson) {
}

@Override
Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandidates) {
Query getTypeKnnQuery(
KnnQuery knnQuery,
int k,
Query filterQuery,
int numCandidates,
BitSetProducer parentBitSetProducer) {
if (knnQuery.getQueryByteVector().size() != getVectorDimensions()) {
throw new IllegalArgumentException(
"Invalid query byte vector size, expected: "
Expand All @@ -570,7 +595,12 @@ Query getTypeKnnQuery(KnnQuery knnQuery, int k, Query filterQuery, int numCandid
}
byte[] queryVector = knnQuery.getQueryByteVector().toByteArray();
validateVectorForSearch(queryVector);
return new NrtKnnByteVectorQuery(getName(), queryVector, k, filterQuery, numCandidates);
if (parentBitSetProducer != null) {
return new NrtDiversifyingChildrenByteKnnVectorQuery(
getName(), queryVector, filterQuery, k, numCandidates, parentBitSetProducer);
} else {
return new NrtKnnByteVectorQuery(getName(), queryVector, k, filterQuery, numCandidates);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.yelp.nrtsearch.server.field.FieldDef;
import com.yelp.nrtsearch.server.grpc.KnnQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;

/** Trait interface for {@link FieldDef} types that can be queried by a {@link KnnQuery}. */
public interface VectorQueryable {
Expand All @@ -26,7 +27,9 @@ public interface VectorQueryable {
*
* @param knnQuery knn query configuration
* @param filterQuery query to filter knn search, or null
* @param parentBitSetProducer bit set producer for parent documents when searching a nested
* field, or null
* @return lucene knn query
*/
Query getKnnQuery(KnnQuery knnQuery, Query filterQuery);
Query getKnnQuery(KnnQuery knnQuery, Query filterQuery, BitSetProducer parentBitSetProducer);
}
24 changes: 24 additions & 0 deletions src/main/java/com/yelp/nrtsearch/server/index/IndexState.java
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,30 @@ public String resolveQueryNestedPath(String path) {
throw new IllegalArgumentException("Nested path is not a nested object field: " + path);
}

/**
* Get the base path for the nested document containing the field at the given path. For fields in
* the base document, this returns _root. The base nested path for _root is null.
*
* @param path field path
* @return nested base path, or null
*/
public static String getFieldBaseNestedPath(String path, IndexState indexState) {
Objects.requireNonNull(path, "path cannot be null");
if (path.equals(IndexState.ROOT)) {
return null;
}

String currentPath = path;
while (currentPath.contains(".")) {
currentPath = currentPath.substring(0, currentPath.lastIndexOf("."));
FieldDef fieldDef = indexState.getFieldOrThrow(currentPath);
if (fieldDef instanceof ObjectFieldDef objFieldDef && objFieldDef.isNestedDoc()) {
return currentPath;
}
}
return IndexState.ROOT;
}

/** Get index state info. */
public abstract IndexStateInfo getIndexStateInfo();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.query.vector;

import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;

/**
* A {@link DiversifyingChildrenByteKnnVectorQuery} that has its functionality slightly modified.
* The {@link TotalHits} from the vector search are available after the query has been rewritten
* using the {@link WithVectorTotalHits} interface. The results merging has also been modified to
* produce the top k hits from the top numCandidates hits from each leaf.
*/
public class NrtDiversifyingChildrenByteKnnVectorQuery
extends DiversifyingChildrenByteKnnVectorQuery implements WithVectorTotalHits {
private final int topHits;
private TotalHits totalHits;

public NrtDiversifyingChildrenByteKnnVectorQuery(
String field,
byte[] target,
Query filter,
int k,
int numCandidates,
BitSetProducer parentsFilter) {
super(field, target, filter, numCandidates, parentsFilter);
this.topHits = k;
}

@Override
public TotalHits getTotalHits() {
return totalHits;
}

@Override
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
TopDocs topDocs = TopDocs.merge(topHits, perLeafResults);
totalHits = topDocs.totalHits;
return topDocs;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.query.vector;

import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;

/**
* A {@link DiversifyingChildrenFloatKnnVectorQuery} that has its functionality slightly modified.
* The {@link TotalHits} from the vector search are available after the query has been rewritten
* using the {@link WithVectorTotalHits} interface. The results merging has also been modified to
* produce the top k hits from the top numCandidates hits from each leaf.
*/
public class NrtDiversifyingChildrenFloatKnnVectorQuery
extends DiversifyingChildrenFloatKnnVectorQuery implements WithVectorTotalHits {
private final int topHits;
private TotalHits totalHits;

public NrtDiversifyingChildrenFloatKnnVectorQuery(
String field,
float[] target,
Query filter,
int k,
int numCandidates,
BitSetProducer parentsFilter) {
super(field, target, filter, numCandidates, parentsFilter);
this.topHits = k;
}

@Override
public TotalHits getTotalHits() {
return totalHits;
}

@Override
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
TopDocs topDocs = TopDocs.merge(topHits, perLeafResults);
totalHits = topDocs.totalHits;
return topDocs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@
import org.apache.lucene.queryparser.classic.QueryParserBase;
import org.apache.lucene.queryparser.simple.SimpleQueryParser;
import org.apache.lucene.search.*;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.QueryBitSetProducer;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.apache.lucene.util.QueryBuilder;

/**
Expand Down Expand Up @@ -270,13 +273,32 @@ private static Query buildKnnQuery(KnnQuery knnQuery, IndexState indexState) {
throw new IllegalArgumentException("Field does not support vector search: " + field);
}

// Path to nested document containing this field
String fieldNestedPath = IndexState.getFieldBaseNestedPath(field, indexState);
// Path to parent document, this will be null if the field is in the root document
String parentNestedPath = IndexState.getFieldBaseNestedPath(fieldNestedPath, indexState);

Query filterQuery;
if (knnQuery.hasFilter()) {
filterQuery = QueryNodeMapper.getInstance().getQuery(knnQuery.getFilter(), indexState);
} else {
filterQuery = null;
}
return vectorQueryable.getKnnQuery(knnQuery, filterQuery);

BitSetProducer parentBitSetProducer = null;
if (parentNestedPath != null) {
Query parentQuery =
QueryNodeMapper.getInstance().getNestedPathQuery(indexState, parentNestedPath);
parentBitSetProducer = new QueryBitSetProducer(parentQuery);
if (filterQuery != null) {
// Filter query is applied to the parent document only
filterQuery =
QueryNodeMapper.getInstance()
.applyQueryNestedPath(filterQuery, indexState, parentNestedPath);
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSetProducer);
}
}
return vectorQueryable.getKnnQuery(knnQuery, filterQuery, parentBitSetProducer);
}

/**
Expand Down
Loading

0 comments on commit f12be15

Please sign in to comment.