Skip to content

Commit

Permalink
Add default terminate after max recall count (#790)
Browse files Browse the repository at this point in the history
Add default terminate after max recall count
  • Loading branch information
swethakann authored Nov 26, 2024
1 parent 545e5f8 commit 7f1d1b5
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 2 deletions.
4 changes: 4 additions & 0 deletions clientlib/src/main/proto/yelp/nrtsearch/luceneserver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@ message LiveSettingsRequest {
int32 defaultSearchTimeoutCheckEvery = 13;
// Terminate after value to use when not specified in the search request
int32 defaultTerminateAfter = 14;
//Terminate after max recall count value to use when not specified in the search request.
int32 defaultTerminateAfterMaxRecallCount = 15;
}

// Response from Server to liveSettings
Expand Down Expand Up @@ -1446,6 +1448,8 @@ message IndexLiveSettings {
google.protobuf.BoolValue parallelFetchByField = 16;
// The number of documents/fields per parallel fetch task, default: 50
google.protobuf.Int32Value parallelFetchChunkSize = 17;
// Terminate after max recall count value to use when not specified in the search request, or 0 for none, default: 0
google.protobuf.Int32Value defaultTerminateAfterMaxRecallCount = 18;
}

// Index state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ private LiveSettingsResponse handleAsLiveSettingsV2(
.setValue(liveSettingsRequest.getDefaultTerminateAfter())
.build());
}
if (liveSettingsRequest.getDefaultTerminateAfterMaxRecallCount() >= 0) {
settingsBuilder.setDefaultTerminateAfterMaxRecallCount(
Int32Value.newBuilder()
.setValue(liveSettingsRequest.getDefaultTerminateAfterMaxRecallCount())
.build());
}
try {
updatedSettings = indexStateManager.updateLiveSettings(settingsBuilder.build(), false);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ public class ImmutableIndexState extends IndexState {
.setDefaultSearchTimeoutSec(DoubleValue.newBuilder().setValue(0).build())
.setDefaultSearchTimeoutCheckEvery(Int32Value.newBuilder().setValue(0).build())
.setDefaultTerminateAfter(Int32Value.newBuilder().setValue(0).build())
.setDefaultTerminateAfterMaxRecallCount(Int32Value.newBuilder().setValue(0).build())
.setMaxMergePreCopyDurationSec(UInt64Value.newBuilder().setValue(0))
.setVerboseMetrics(BoolValue.newBuilder().setValue(false).build())
.setParallelFetchByField(BoolValue.newBuilder().setValue(false).build())
Expand All @@ -167,6 +168,7 @@ public class ImmutableIndexState extends IndexState {
private final double defaultSearchTimeoutSec;
private final int defaultSearchTimeoutCheckEvery;
private final int defaultTerminateAfter;
private final int defaultTerminateAfterMaxRecallCount;
private final long maxMergePreCopyDurationSec;
private final boolean verboseMetrics;
private final ParallelFetchConfig parallelFetchConfig;
Expand Down Expand Up @@ -262,6 +264,8 @@ public ImmutableIndexState(
defaultSearchTimeoutCheckEvery =
mergedLiveSettingsWithLocal.getDefaultSearchTimeoutCheckEvery().getValue();
defaultTerminateAfter = mergedLiveSettingsWithLocal.getDefaultTerminateAfter().getValue();
defaultTerminateAfterMaxRecallCount =
mergedLiveSettingsWithLocal.getDefaultTerminateAfterMaxRecallCount().getValue();
maxMergePreCopyDurationSec =
mergedLiveSettingsWithLocal.getMaxMergePreCopyDurationSec().getValue();
verboseMetrics = mergedLiveSettingsWithLocal.getVerboseMetrics().getValue();
Expand Down Expand Up @@ -718,6 +722,11 @@ public int getDefaultTerminateAfter() {
return defaultTerminateAfter;
}

@Override
public int getDefaultTerminateAfterMaxRecallCount() {
return defaultTerminateAfterMaxRecallCount;
}

@Override
public int getDefaultSearchTimeoutCheckEvery() {
return defaultSearchTimeoutCheckEvery;
Expand Down Expand Up @@ -809,6 +818,9 @@ static void validateLiveSettings(IndexLiveSettings liveSettings) {
if (liveSettings.getDefaultTerminateAfter().getValue() < 0) {
throw new IllegalArgumentException("defaultTerminateAfter must be >= 0");
}
if (liveSettings.getDefaultTerminateAfterMaxRecallCount().getValue() < 0) {
throw new IllegalArgumentException("defaultTerminateAfterMaxRecallCount must be >= 0");
}
if (liveSettings.getMaxMergePreCopyDurationSec().getValue() < 0) {
throw new IllegalArgumentException("maxMergePreCopyDurationSec must be >= 0");
}
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/com/yelp/nrtsearch/server/index/IndexState.java
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,9 @@ public abstract IndexWriterConfig getIndexWriterConfig(
/** Get the default terminate after. */
public abstract int getDefaultTerminateAfter();

/** Get the default terminate after max recall count. */
public abstract int getDefaultTerminateAfterMaxRecallCount();

/** Get the default search timeout check every. */
public abstract int getDefaultSearchTimeoutCheckEvery();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ public int getTerminateAfter() {
return terminateAfter;
}

/** Max documents to count beyond terminateAfter. */
public int getTerminateAfterMaxRecallCount() {
return terminateAfterMaxRecallCount;
}

/**
* {@link Collector} implementation that wraps another collector and terminates collection after a
* certain global count of documents is reached.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ <C extends Collector> CollectorManager<? extends Collector, SearcherResult> wrap
int terminateAfterMaxRecallCount =
request.getTerminateAfterMaxRecallCount() > 0
? request.getTerminateAfterMaxRecallCount()
: 0;
: indexState.getDefaultTerminateAfterMaxRecallCount();
if (terminateAfter > 0) {
wrapped =
new TerminateAfterWrapper<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ public class LiveSettingsV2Command implements Callable<Integer> {
description = "Terminate after to use when not provided by the request")
private Integer defaultTerminateAfter;

@CommandLine.Option(
names = {"--defaultTerminateAfterMaxRecallCount"},
description = "Terminate after max recall count to use when not provided by the request")
private Integer defaultTerminateAfterMaxRecallCount;

@CommandLine.Option(
names = {"--maxMergePreCopyDurationSec"},
description = "Maximum time allowed for merge precopy in seconds")
Expand Down Expand Up @@ -192,6 +197,10 @@ public Integer call() throws Exception {
liveSettingsBuilder.setDefaultTerminateAfter(
Int32Value.newBuilder().setValue(defaultTerminateAfter).build());
}
if (defaultTerminateAfterMaxRecallCount != null) {
liveSettingsBuilder.setDefaultTerminateAfterMaxRecallCount(
Int32Value.newBuilder().setValue(defaultTerminateAfterMaxRecallCount).build());
}
if (maxMergePreCopyDurationSec != null) {
liveSettingsBuilder.setMaxMergePreCopyDurationSec(
UInt64Value.newBuilder().setValue(maxMergePreCopyDurationSec));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,8 @@ public void testSetIndexLiveSettings() throws IOException {
IndexLiveSettings.newBuilder()
.setDefaultTerminateAfter(
Int32Value.newBuilder().setValue(1000).build())
.setDefaultTerminateAfterMaxRecallCount(
Int32Value.newBuilder().setValue(1000).build())
.setSegmentsPerTier(Int32Value.newBuilder().setValue(4).build())
.setSliceMaxSegments(Int32Value.newBuilder().setValue(50).build())
.setDefaultSearchTimeoutSec(
Expand All @@ -784,6 +786,7 @@ public void testSetIndexLiveSettings() throws IOException {
IndexLiveSettings expectedSettings =
ImmutableIndexState.DEFAULT_INDEX_LIVE_SETTINGS.toBuilder()
.setDefaultTerminateAfter(Int32Value.newBuilder().setValue(1000).build())
.setDefaultTerminateAfterMaxRecallCount(Int32Value.newBuilder().setValue(1000).build())
.setSegmentsPerTier(Int32Value.newBuilder().setValue(4).build())
.setSliceMaxSegments(Int32Value.newBuilder().setValue(50).build())
.setDefaultSearchTimeoutSec(DoubleValue.newBuilder().setValue(5.1).build())
Expand Down Expand Up @@ -1647,6 +1650,7 @@ public void testLiveSettingsV1All() throws IOException {
.setDefaultSearchTimeoutSec(13.0)
.setDefaultSearchTimeoutCheckEvery(500)
.setDefaultTerminateAfter(5000)
.setDefaultTerminateAfterMaxRecallCount(6000)
.build();

LiveSettingsResponse response = primaryClient.getBlockingStub().liveSettings(request);
Expand All @@ -1665,6 +1669,7 @@ public void testLiveSettingsV1All() throws IOException {
.setDefaultSearchTimeoutSec(DoubleValue.newBuilder().setValue(13.0).build())
.setDefaultSearchTimeoutCheckEvery(Int32Value.newBuilder().setValue(500).build())
.setDefaultTerminateAfter(Int32Value.newBuilder().setValue(5000).build())
.setDefaultTerminateAfterMaxRecallCount(Int32Value.newBuilder().setValue(6000).build())
.setMaxMergePreCopyDurationSec(UInt64Value.newBuilder().setValue(0))
.setVerboseMetrics(BoolValue.newBuilder().setValue(false).build())
.setParallelFetchByField(BoolValue.newBuilder().setValue(false).build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,26 @@ public void testDefaultTerminateAfter_invalid() throws IOException {
assertLiveSettingException(expectedMsg, b -> b.setDefaultTerminateAfter(wrap(-1)));
}

@Test
public void testDefaultTerminateAfterMaxRecallCount_default() throws IOException {
assertEquals(0, getIndexState(getEmptyState()).getDefaultTerminateAfterMaxRecallCount());
}

@Test
public void testDefaultTerminateAfterMaxRecallCount_set() throws IOException {
verifyIntLiveSetting(
100,
ImmutableIndexState::getDefaultTerminateAfterMaxRecallCount,
b -> b.setDefaultTerminateAfterMaxRecallCount(wrap(100)));
}

@Test
public void testDefaultTerminateAfterMaxRecallCount_invalid() throws IOException {
String expectedMsg = "defaultTerminateAfterMaxRecallCount must be >= 0";
assertLiveSettingException(
expectedMsg, b -> b.setDefaultTerminateAfterMaxRecallCount(wrap(-1)));
}

@Test
public void testMaxMergePreCopyDurationSec_default() throws IOException {
assertEquals(0, getIndexState(getEmptyState()).getMaxMergePreCopyDurationSec());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,21 @@ public void testNumHitsToCollect() {

@Test
public void testHasTerminateAfterWrapper() {
SearchRequest request = SearchRequest.newBuilder().setTopHits(10).setTerminateAfter(5).build();
SearchRequest request =
SearchRequest.newBuilder()
.setTopHits(10)
.setTerminateAfter(5)
.setTerminateAfterMaxRecallCount(10)
.build();
TestDocCollector docCollector = new TestDocCollector(request);
assertTrue(docCollector.getManager() instanceof TestDocCollector.TestCollectorManager);
assertTrue(docCollector.getWrappedManager() instanceof TerminateAfterWrapper);
assertEquals(
5, ((TerminateAfterWrapper<?>) docCollector.getWrappedManager()).getTerminateAfter());
assertEquals(
10,
((TerminateAfterWrapper<?>) docCollector.getWrappedManager())
.getTerminateAfterMaxRecallCount());
}

@Test
Expand Down Expand Up @@ -246,6 +255,37 @@ public void testOverrideDefaultTerminateAfter() {
75, ((TerminateAfterWrapper<?>) docCollector.getWrappedManager()).getTerminateAfter());
}

@Test
public void testUsesDefaultTerminateAfterMaxRecallCount() {
IndexState indexState = Mockito.mock(IndexState.class);
when(indexState.getDefaultTerminateAfter()).thenReturn(100);
when(indexState.getDefaultTerminateAfterMaxRecallCount()).thenReturn(1000);

SearchRequest request = SearchRequest.newBuilder().setTopHits(10).build();
TestDocCollector docCollector = new TestDocCollector(request, indexState);
assertEquals(
1000,
((TerminateAfterWrapper<?>) docCollector.getWrappedManager())
.getTerminateAfterMaxRecallCount());
}

@Test
public void testOverrideDefaultTerminateAfterMaxRecallCount() {
IndexState indexState = Mockito.mock(IndexState.class);
when(indexState.getDefaultTerminateAfter()).thenReturn(100);
when(indexState.getDefaultTerminateAfterMaxRecallCount()).thenReturn(1000);

SearchRequest request =
SearchRequest.newBuilder().setTopHits(10).setTerminateAfterMaxRecallCount(75).build();
TestDocCollector docCollector = new TestDocCollector(request, indexState);
assertTrue(docCollector.getManager() instanceof TestDocCollector.TestCollectorManager);
assertTrue(docCollector.getWrappedManager() instanceof TerminateAfterWrapper);
assertEquals(
75,
((TerminateAfterWrapper<?>) docCollector.getWrappedManager())
.getTerminateAfterMaxRecallCount());
}

@Test
public void testWithAllWrappers() {
SearchRequest request =
Expand Down

0 comments on commit 7f1d1b5

Please sign in to comment.