Skip to content

Commit 792e70d

Browse files
Copilot and lvca authored
Add filtered vector search support to LSMVectorIndex (#3072)
* Initial plan * Add filtered search support to LSMVectorIndex with RID-based filtering Co-authored-by: lvca <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: lvca <[email protected]>
1 parent 9ee773a commit 792e70d

File tree

2 files changed

+162
-2
lines changed

2 files changed

+162
-2
lines changed

engine/src/main/java/com/arcadedb/index/vector/LSMVectorIndex.java

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,40 @@ public interface GraphBuildCallback {
140140
void onGraphBuildProgress(String phase, int processedNodes, int totalNodes, long vectorAccesses);
141141
}
142142

143+
/**
144+
* Custom Bits implementation for filtering vector search by RID.
145+
* Maps graph ordinals to vector IDs, then checks if the corresponding RID is in the allowed set.
146+
*/
147+
private class RIDBitsFilter implements Bits {
148+
private final Set<RID> allowedRIDs;
149+
private final int[] ordinalToVectorIdSnapshot;
150+
private final VectorLocationIndex vectorIndexSnapshot;
151+
152+
RIDBitsFilter(final Set<RID> allowedRIDs, final int[] ordinalToVectorIdSnapshot, final VectorLocationIndex vectorIndexSnapshot) {
153+
this.allowedRIDs = allowedRIDs;
154+
this.ordinalToVectorIdSnapshot = ordinalToVectorIdSnapshot;
155+
this.vectorIndexSnapshot = vectorIndexSnapshot;
156+
}
157+
158+
@Override
159+
public boolean get(final int ordinal) {
160+
// Check if ordinal is within bounds
161+
if (ordinal < 0 || ordinal >= ordinalToVectorIdSnapshot.length)
162+
return false;
163+
164+
// Map ordinal to vector ID
165+
final int vectorId = ordinalToVectorIdSnapshot[ordinal];
166+
167+
// Get the RID for this vector ID
168+
final VectorLocationIndex.VectorLocation loc = vectorIndexSnapshot.getLocation(vectorId);
169+
if (loc == null || loc.deleted)
170+
return false;
171+
172+
// Check if this RID is in the allowed set
173+
return allowedRIDs.contains(loc.rid);
174+
}
175+
}
176+
143177
/**
144178
* Comparable wrapper for float[] to use in transaction tracking.
145179
* Vectors are compared by their hash code for uniqueness in the transaction map.
@@ -1803,6 +1837,21 @@ private MutablePage createNewVectorDataPage(final int pageNum) {
18031837
* @return List of pairs containing RID and similarity score
18041838
*/
18051839
public List<Pair<RID, Float>> findNeighborsFromVector(final float[] queryVector, final int k) {
1840+
return findNeighborsFromVector(queryVector, k, null);
1841+
}
1842+
1843+
/**
1844+
* Search for k nearest neighbors to the given vector within a filtered set of RIDs.
1845+
* This method allows restricting the search space to specific records, useful for
1846+
* filtering by user ID, category, or other criteria during graph traversal.
1847+
*
1848+
* @param queryVector The query vector to search for
1849+
* @param k The number of neighbors to return
1850+
* @param allowedRIDs Optional set of RIDs to restrict search to (null or empty means no filtering)
1851+
*
1852+
* @return List of pairs containing RID and similarity score
1853+
*/
1854+
public List<Pair<RID, Float>> findNeighborsFromVector(final float[] queryVector, final int k, final Set<RID> allowedRIDs) {
18061855
if (queryVector == null)
18071856
throw new IllegalArgumentException("Query vector cannot be null");
18081857

@@ -1859,14 +1908,18 @@ public List<Pair<RID, Float>> findNeighborsFromVector(final float[] queryVector,
18591908
this // Pass LSM index reference for quantization support
18601909
);
18611910

1862-
// Perform search
1911+
// Perform search with optional RID filtering
1912+
final Bits bitsFilter = (allowedRIDs != null && !allowedRIDs.isEmpty())
1913+
? new RIDBitsFilter(allowedRIDs, ordinalToVectorId, vectorIndex)
1914+
: Bits.ALL;
1915+
18631916
final SearchResult searchResult = GraphSearcher.search(
18641917
queryVectorFloat,
18651918
k,
18661919
vectors,
18671920
metadata.similarityFunction,
18681921
graphIndex,
1869-
Bits.ALL
1922+
bitsFilter
18701923
);
18711924

18721925
LogManager.instance().log(this, Level.INFO,

engine/src/test/java/com/arcadedb/index/vector/LSMVectorIndexTest.java

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2101,6 +2101,113 @@ CREATE INDEX ON CycleTest (vec) LSM_VECTOR
21012101
deleteDirectory(dbDir);
21022102
}
21032103

2104+
@Test
2105+
void filteredSearchByRID() {
2106+
database.transaction(() -> {
2107+
// Create the schema
2108+
database.command("sql", "CREATE DOCUMENT TYPE FilteredDoc");
2109+
database.command("sql", "CREATE PROPERTY FilteredDoc.id STRING");
2110+
database.command("sql", "CREATE PROPERTY FilteredDoc.category STRING");
2111+
database.command("sql", "CREATE PROPERTY FilteredDoc.embedding ARRAY_OF_FLOATS");
2112+
2113+
// Create the LSM_VECTOR index
2114+
database.command("sql", """
2115+
CREATE INDEX ON FilteredDoc (embedding) LSM_VECTOR
2116+
METADATA {
2117+
"dimensions": 3,
2118+
"similarity": "COSINE",
2119+
"maxConnections": 8,
2120+
"beamWidth": 50
2121+
}""");
2122+
});
2123+
2124+
// Create test data with different categories
2125+
final List<com.arcadedb.database.RID> categoryARIDs = new ArrayList<>();
2126+
final List<com.arcadedb.database.RID> categoryBRIDs = new ArrayList<>();
2127+
2128+
database.transaction(() -> {
2129+
for (int i = 0; i < 20; i++) {
2130+
final var doc = database.newDocument("FilteredDoc");
2131+
doc.set("id", "doc" + i);
2132+
doc.set("category", i < 10 ? "A" : "B");
2133+
2134+
// Create vectors with some pattern based on category
2135+
final float[] vector = new float[3];
2136+
if (i < 10) {
2137+
// Category A: vectors around [1, 1, 1]
2138+
vector[0] = 1.0f + (i * 0.1f);
2139+
vector[1] = 1.0f + (i * 0.1f);
2140+
vector[2] = 1.0f + (i * 0.1f);
2141+
} else {
2142+
// Category B: vectors around [10, 10, 10]
2143+
vector[0] = 10.0f + ((i - 10) * 0.1f);
2144+
vector[1] = 10.0f + ((i - 10) * 0.1f);
2145+
vector[2] = 10.0f + ((i - 10) * 0.1f);
2146+
}
2147+
doc.set("embedding", vector);
2148+
2149+
final com.arcadedb.database.RID rid = doc.save().getIdentity();
2150+
if (i < 10) {
2151+
categoryARIDs.add(rid);
2152+
} else {
2153+
categoryBRIDs.add(rid);
2154+
}
2155+
}
2156+
});
2157+
2158+
// Get the index
2159+
final com.arcadedb.index.TypeIndex typeIndex = (com.arcadedb.index.TypeIndex) database.getSchema()
2160+
.getIndexByName("FilteredDoc[embedding]");
2161+
final LSMVectorIndex index = (LSMVectorIndex) typeIndex.getIndexesOnBuckets()[0];
2162+
2163+
database.transaction(() -> {
2164+
// Query vector close to category A
2165+
final float[] queryVector = {1.5f, 1.5f, 1.5f};
2166+
2167+
// Test 1: Search without filter - should return results from both categories
2168+
final List<com.arcadedb.utility.Pair<com.arcadedb.database.RID, Float>> unfilteredResults =
2169+
index.findNeighborsFromVector(queryVector, 10);
2170+
assertThat(unfilteredResults).as("Unfiltered search should return results").isNotEmpty();
2171+
assertThat(unfilteredResults.size()).as("Should return up to 10 results").isLessThanOrEqualTo(10);
2172+
2173+
// Test 2: Search with filter for category A only
2174+
final Set<com.arcadedb.database.RID> allowedRIDs = new HashSet<>(categoryARIDs);
2175+
final List<com.arcadedb.utility.Pair<com.arcadedb.database.RID, Float>> filteredResults =
2176+
index.findNeighborsFromVector(queryVector, 10, allowedRIDs);
2177+
2178+
assertThat(filteredResults).as("Filtered search should return results").isNotEmpty();
2179+
assertThat(filteredResults.size()).as("Should return at most 10 results").isLessThanOrEqualTo(10);
2180+
2181+
// Verify all results are from the allowed set
2182+
for (final var result : filteredResults) {
2183+
assertThat(allowedRIDs).as("Result RID should be in allowed set").contains(result.getFirst());
2184+
}
2185+
2186+
// Test 3: Search with filter for category B only
2187+
final Set<com.arcadedb.database.RID> categoryBSet = new HashSet<>(categoryBRIDs);
2188+
final List<com.arcadedb.utility.Pair<com.arcadedb.database.RID, Float>> categoryBResults =
2189+
index.findNeighborsFromVector(queryVector, 10, categoryBSet);
2190+
2191+
// Since query vector is close to category A, but we filter to category B,
2192+
// we should still get results (from category B), just with higher distances
2193+
assertThat(categoryBResults).as("Filtered search for category B should return results").isNotEmpty();
2194+
2195+
for (final var result : categoryBResults) {
2196+
assertThat(categoryBSet).as("Result RID should be from category B").contains(result.getFirst());
2197+
}
2198+
2199+
// Test 4: Empty filter should work like unfiltered
2200+
final List<com.arcadedb.utility.Pair<com.arcadedb.database.RID, Float>> emptyFilterResults =
2201+
index.findNeighborsFromVector(queryVector, 10, new HashSet<>());
2202+
assertThat(emptyFilterResults).as("Empty filter should return results like unfiltered").isNotEmpty();
2203+
2204+
// Test 5: Null filter should work like unfiltered
2205+
final List<com.arcadedb.utility.Pair<com.arcadedb.database.RID, Float>> nullFilterResults =
2206+
index.findNeighborsFromVector(queryVector, 10, null);
2207+
assertThat(nullFilterResults).as("Null filter should return results like unfiltered").isNotEmpty();
2208+
});
2209+
}
2210+
21042211
/**
21052212
* Helper method to recursively delete a directory using Files.walk() API
21062213
*/

0 commit comments

Comments
 (0)