Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default terminate after max recall count #790

Merged
merged 2 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clientlib/src/main/proto/yelp/nrtsearch/luceneserver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,8 @@ message LiveSettingsRequest {
int32 defaultSearchTimeoutCheckEvery = 13;
//Terminate after value to use when not specified in the search request.
int32 defaultTerminateAfter = 14;
//Terminate after max recall count value to use when not specified in the search request.
int32 defaultTerminateAfterMaxRecallCount = 15;
}

/* Response from Server to liveSettings */
Expand Down Expand Up @@ -1157,6 +1159,8 @@ message IndexLiveSettings {
google.protobuf.BoolValue parallelFetchByField = 16;
// The number of documents/fields per parallel fetch task, default: 50
google.protobuf.Int32Value parallelFetchChunkSize = 17;
// Terminate after max recall count value to use when not specified in the search request, or 0 for none, default: 0
google.protobuf.Int32Value defaultTerminateAfterMaxRecallCount = 18;
}

message IndexStateInfo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ private LiveSettingsResponse handleAsLiveSettingsV2(
.setValue(liveSettingsRequest.getDefaultTerminateAfter())
.build());
}
if (liveSettingsRequest.getDefaultTerminateAfterMaxRecallCount() >= 0) {
settingsBuilder.setDefaultTerminateAfterMaxRecallCount(
Int32Value.newBuilder()
.setValue(liveSettingsRequest.getDefaultTerminateAfterMaxRecallCount())
.build());
}
try {
updatedSettings = indexStateManager.updateLiveSettings(settingsBuilder.build(), false);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ public class ImmutableIndexState extends IndexState {
.setDefaultSearchTimeoutSec(DoubleValue.newBuilder().setValue(0).build())
.setDefaultSearchTimeoutCheckEvery(Int32Value.newBuilder().setValue(0).build())
.setDefaultTerminateAfter(Int32Value.newBuilder().setValue(0).build())
.setDefaultTerminateAfterMaxRecallCount(Int32Value.newBuilder().setValue(0).build())
.setMaxMergePreCopyDurationSec(UInt64Value.newBuilder().setValue(0))
.setVerboseMetrics(BoolValue.newBuilder().setValue(false).build())
.setParallelFetchByField(BoolValue.newBuilder().setValue(false).build())
Expand All @@ -167,6 +168,7 @@ public class ImmutableIndexState extends IndexState {
private final double defaultSearchTimeoutSec;
private final int defaultSearchTimeoutCheckEvery;
private final int defaultTerminateAfter;
private final int defaultTerminateAfterMaxRecallCount;
private final long maxMergePreCopyDurationSec;
private final boolean verboseMetrics;
private final ParallelFetchConfig parallelFetchConfig;
Expand Down Expand Up @@ -262,6 +264,8 @@ public ImmutableIndexState(
defaultSearchTimeoutCheckEvery =
mergedLiveSettingsWithLocal.getDefaultSearchTimeoutCheckEvery().getValue();
defaultTerminateAfter = mergedLiveSettingsWithLocal.getDefaultTerminateAfter().getValue();
defaultTerminateAfterMaxRecallCount =
mergedLiveSettingsWithLocal.getDefaultTerminateAfterMaxRecallCount().getValue();
maxMergePreCopyDurationSec =
mergedLiveSettingsWithLocal.getMaxMergePreCopyDurationSec().getValue();
verboseMetrics = mergedLiveSettingsWithLocal.getVerboseMetrics().getValue();
Expand Down Expand Up @@ -718,6 +722,11 @@ public int getDefaultTerminateAfter() {
return defaultTerminateAfter;
}

@Override
public int getDefaultTerminateAfterMaxRecallCount() {
return defaultTerminateAfterMaxRecallCount;
}

@Override
public int getDefaultSearchTimeoutCheckEvery() {
return defaultSearchTimeoutCheckEvery;
Expand Down Expand Up @@ -809,6 +818,9 @@ static void validateLiveSettings(IndexLiveSettings liveSettings) {
if (liveSettings.getDefaultTerminateAfter().getValue() < 0) {
throw new IllegalArgumentException("defaultTerminateAfter must be >= 0");
}
if (liveSettings.getDefaultTerminateAfterMaxRecallCount().getValue() < 0) {
throw new IllegalArgumentException("defaultTerminateAfterMaxRecallCount must be >= 0");
}
if (liveSettings.getMaxMergePreCopyDurationSec().getValue() < 0) {
throw new IllegalArgumentException("maxMergePreCopyDurationSec must be >= 0");
}
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/com/yelp/nrtsearch/server/index/IndexState.java
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,9 @@ public abstract IndexWriterConfig getIndexWriterConfig(
/** Get the default terminate after. */
public abstract int getDefaultTerminateAfter();

/** Get the default terminate after max recall count. */
public abstract int getDefaultTerminateAfterMaxRecallCount();

/** Get the default search timeout check every. */
public abstract int getDefaultSearchTimeoutCheckEvery();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ public int getTerminateAfter() {
return terminateAfter;
}

/** Max documents to count beyond terminateAfter. */
public int getTerminateAfterMaxRecallCount() {
return terminateAfterMaxRecallCount;
}

/**
* {@link Collector} implementation that wraps another collector and terminates collection after a
* certain global count of documents is reached.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ <C extends Collector> CollectorManager<? extends Collector, SearcherResult> wrap
int terminateAfterMaxRecallCount =
request.getTerminateAfterMaxRecallCount() > 0
? request.getTerminateAfterMaxRecallCount()
: 0;
: indexState.getDefaultTerminateAfterMaxRecallCount();
if (terminateAfter > 0) {
wrapped =
new TerminateAfterWrapper<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ public class LiveSettingsV2Command implements Callable<Integer> {
description = "Terminate after to use when not provided by the request")
private Integer defaultTerminateAfter;

@CommandLine.Option(
names = {"--defaultTerminateAfterMaxRecallCount"},
description = "Terminate after max recall count to use when not provided by the request")
private Integer defaultTerminateAfterMaxRecallCount;

@CommandLine.Option(
names = {"--maxMergePreCopyDurationSec"},
description = "Maximum time allowed for merge precopy in seconds")
Expand Down Expand Up @@ -192,6 +197,10 @@ public Integer call() throws Exception {
liveSettingsBuilder.setDefaultTerminateAfter(
Int32Value.newBuilder().setValue(defaultTerminateAfter).build());
}
if (defaultTerminateAfterMaxRecallCount != null) {
liveSettingsBuilder.setDefaultTerminateAfterMaxRecallCount(
Int32Value.newBuilder().setValue(defaultTerminateAfterMaxRecallCount).build());
}
if (maxMergePreCopyDurationSec != null) {
liveSettingsBuilder.setMaxMergePreCopyDurationSec(
UInt64Value.newBuilder().setValue(maxMergePreCopyDurationSec));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,8 @@ public void testSetIndexLiveSettings() throws IOException {
IndexLiveSettings.newBuilder()
.setDefaultTerminateAfter(
Int32Value.newBuilder().setValue(1000).build())
.setDefaultTerminateAfterMaxRecallCount(
Int32Value.newBuilder().setValue(1000).build())
.setSegmentsPerTier(Int32Value.newBuilder().setValue(4).build())
.setSliceMaxSegments(Int32Value.newBuilder().setValue(50).build())
.setDefaultSearchTimeoutSec(
Expand All @@ -784,6 +786,7 @@ public void testSetIndexLiveSettings() throws IOException {
IndexLiveSettings expectedSettings =
ImmutableIndexState.DEFAULT_INDEX_LIVE_SETTINGS.toBuilder()
.setDefaultTerminateAfter(Int32Value.newBuilder().setValue(1000).build())
.setDefaultTerminateAfterMaxRecallCount(Int32Value.newBuilder().setValue(1000).build())
.setSegmentsPerTier(Int32Value.newBuilder().setValue(4).build())
.setSliceMaxSegments(Int32Value.newBuilder().setValue(50).build())
.setDefaultSearchTimeoutSec(DoubleValue.newBuilder().setValue(5.1).build())
Expand Down Expand Up @@ -1647,6 +1650,7 @@ public void testLiveSettingsV1All() throws IOException {
.setDefaultSearchTimeoutSec(13.0)
.setDefaultSearchTimeoutCheckEvery(500)
.setDefaultTerminateAfter(5000)
.setDefaultTerminateAfterMaxRecallCount(6000)
.build();

LiveSettingsResponse response = primaryClient.getBlockingStub().liveSettings(request);
Expand All @@ -1665,6 +1669,7 @@ public void testLiveSettingsV1All() throws IOException {
.setDefaultSearchTimeoutSec(DoubleValue.newBuilder().setValue(13.0).build())
.setDefaultSearchTimeoutCheckEvery(Int32Value.newBuilder().setValue(500).build())
.setDefaultTerminateAfter(Int32Value.newBuilder().setValue(5000).build())
.setDefaultTerminateAfterMaxRecallCount(Int32Value.newBuilder().setValue(6000).build())
.setMaxMergePreCopyDurationSec(UInt64Value.newBuilder().setValue(0))
.setVerboseMetrics(BoolValue.newBuilder().setValue(false).build())
.setParallelFetchByField(BoolValue.newBuilder().setValue(false).build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,26 @@ public void testDefaultTerminateAfter_invalid() throws IOException {
assertLiveSettingException(expectedMsg, b -> b.setDefaultTerminateAfter(wrap(-1)));
}

@Test
public void testDefaultTerminateAfterMaxRecallCount_default() throws IOException {
assertEquals(0, getIndexState(getEmptyState()).getDefaultTerminateAfterMaxRecallCount());
}

@Test
public void testDefaultTerminateAfterMaxRecallCount_set() throws IOException {
verifyIntLiveSetting(
100,
ImmutableIndexState::getDefaultTerminateAfterMaxRecallCount,
b -> b.setDefaultTerminateAfterMaxRecallCount(wrap(100)));
}

@Test
public void testDefaultTerminateAfterMaxRecallCount_invalid() throws IOException {
String expectedMsg = "defaultTerminateAfterMaxRecallCount must be >= 0";
assertLiveSettingException(
expectedMsg, b -> b.setDefaultTerminateAfterMaxRecallCount(wrap(-1)));
}

@Test
public void testMaxMergePreCopyDurationSec_default() throws IOException {
assertEquals(0, getIndexState(getEmptyState()).getMaxMergePreCopyDurationSec());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,21 @@ public void testNumHitsToCollect() {

@Test
public void testHasTerminateAfterWrapper() {
SearchRequest request = SearchRequest.newBuilder().setTopHits(10).setTerminateAfter(5).build();
SearchRequest request =
SearchRequest.newBuilder()
.setTopHits(10)
.setTerminateAfter(5)
.setTerminateAfterMaxRecallCount(10)
.build();
TestDocCollector docCollector = new TestDocCollector(request);
assertTrue(docCollector.getManager() instanceof TestDocCollector.TestCollectorManager);
assertTrue(docCollector.getWrappedManager() instanceof TerminateAfterWrapper);
assertEquals(
5, ((TerminateAfterWrapper<?>) docCollector.getWrappedManager()).getTerminateAfter());
assertEquals(
10,
((TerminateAfterWrapper<?>) docCollector.getWrappedManager())
.getTerminateAfterMaxRecallCount());
}

@Test
Expand Down Expand Up @@ -246,6 +255,37 @@ public void testOverrideDefaultTerminateAfter() {
75, ((TerminateAfterWrapper<?>) docCollector.getWrappedManager()).getTerminateAfter());
}

@Test
public void testUsesDefaultTerminateAfterMaxRecallCount() {
IndexState indexState = Mockito.mock(IndexState.class);
when(indexState.getDefaultTerminateAfter()).thenReturn(100);
when(indexState.getDefaultTerminateAfterMaxRecallCount()).thenReturn(1000);

SearchRequest request = SearchRequest.newBuilder().setTopHits(10).build();
TestDocCollector docCollector = new TestDocCollector(request, indexState);
assertEquals(
1000,
((TerminateAfterWrapper<?>) docCollector.getWrappedManager())
.getTerminateAfterMaxRecallCount());
}

@Test
public void testOverrideDefaultTerminateAfterMaxRecallCount() {
IndexState indexState = Mockito.mock(IndexState.class);
when(indexState.getDefaultTerminateAfter()).thenReturn(100);
when(indexState.getDefaultTerminateAfterMaxRecallCount()).thenReturn(1000);

SearchRequest request =
SearchRequest.newBuilder().setTopHits(10).setTerminateAfterMaxRecallCount(75).build();
TestDocCollector docCollector = new TestDocCollector(request, indexState);
assertTrue(docCollector.getManager() instanceof TestDocCollector.TestCollectorManager);
assertTrue(docCollector.getWrappedManager() instanceof TerminateAfterWrapper);
assertEquals(
75,
((TerminateAfterWrapper<?>) docCollector.getWrappedManager())
.getTerminateAfterMaxRecallCount());
}

@Test
public void testWithAllWrappers() {
SearchRequest request =
Expand Down
Loading