Skip to content

Commit

Permalink
PR changes
Browse files Browse the repository at this point in the history
  • Loading branch information
swethakann committed Nov 26, 2024
1 parent c340234 commit 7c39462
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ public int getTerminateAfter() {
return terminateAfter;
}

/** Max documents to count beyond terminateAfter. */
public int getTerminateAfterMaxRecallCount() {
return terminateAfterMaxRecallCount;
}

/**
* {@link Collector} implementation that wraps another collector and terminates collection after a
* certain global count of documents is reached.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ <C extends Collector> CollectorManager<? extends Collector, SearcherResult> wrap
int terminateAfterMaxRecallCount =
request.getTerminateAfterMaxRecallCount() > 0
? request.getTerminateAfterMaxRecallCount()
: 0;
: indexState.getDefaultTerminateAfterMaxRecallCount();
if (terminateAfter > 0) {
wrapped =
new TerminateAfterWrapper<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,21 @@ public void testNumHitsToCollect() {

@Test
public void testHasTerminateAfterWrapper() {
SearchRequest request = SearchRequest.newBuilder().setTopHits(10).setTerminateAfter(5).build();
SearchRequest request =
SearchRequest.newBuilder()
.setTopHits(10)
.setTerminateAfter(5)
.setTerminateAfterMaxRecallCount(10)
.build();
TestDocCollector docCollector = new TestDocCollector(request);
assertTrue(docCollector.getManager() instanceof TestDocCollector.TestCollectorManager);
assertTrue(docCollector.getWrappedManager() instanceof TerminateAfterWrapper);
assertEquals(
5, ((TerminateAfterWrapper<?>) docCollector.getWrappedManager()).getTerminateAfter());
assertEquals(
10,
((TerminateAfterWrapper<?>) docCollector.getWrappedManager())
.getTerminateAfterMaxRecallCount());
}

@Test
Expand Down Expand Up @@ -246,6 +255,37 @@ public void testOverrideDefaultTerminateAfter() {
75, ((TerminateAfterWrapper<?>) docCollector.getWrappedManager()).getTerminateAfter());
}

@Test
public void testUsesDefaultTerminateAfterMaxRecallCount() {
IndexState indexState = Mockito.mock(IndexState.class);
when(indexState.getDefaultTerminateAfter()).thenReturn(100);
when(indexState.getDefaultTerminateAfterMaxRecallCount()).thenReturn(1000);

SearchRequest request = SearchRequest.newBuilder().setTopHits(10).build();
TestDocCollector docCollector = new TestDocCollector(request, indexState);
assertEquals(
1000,
((TerminateAfterWrapper<?>) docCollector.getWrappedManager())
.getTerminateAfterMaxRecallCount());
}

@Test
public void testOverrideDefaultTerminateAfterMaxRecallCount() {
IndexState indexState = Mockito.mock(IndexState.class);
when(indexState.getDefaultTerminateAfter()).thenReturn(100);
when(indexState.getDefaultTerminateAfterMaxRecallCount()).thenReturn(1000);

SearchRequest request =
SearchRequest.newBuilder().setTopHits(10).setTerminateAfterMaxRecallCount(75).build();
TestDocCollector docCollector = new TestDocCollector(request, indexState);
assertTrue(docCollector.getManager() instanceof TestDocCollector.TestCollectorManager);
assertTrue(docCollector.getWrappedManager() instanceof TerminateAfterWrapper);
assertEquals(
75,
((TerminateAfterWrapper<?>) docCollector.getWrappedManager())
.getTerminateAfterMaxRecallCount());
}

@Test
public void testWithAllWrappers() {
SearchRequest request =
Expand Down

0 comments on commit 7c39462

Please sign in to comment.