Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,37 +73,12 @@ public Completable compact(Session session, BaseSessionService sessionService) {
logger.debug("Running tail retention event compaction for session {}", session.id());

return Maybe.just(session.events())
.filter(this::shouldCompact)
.flatMap(events -> getCompactionEvents(events))
.flatMap(this::getCompactionEvents)
.flatMap(summarizer::summarizeEvents)
.flatMapSingle(e -> sessionService.appendEvent(session, e))
.ignoreElement();
}

/**
 * Decides whether the session history is large enough to be worth compacting.
 *
 * <p>The decision is driven by the most recent prompt token count found in the events; when no
 * such count is available it is treated as {@code 0}.
 *
 * @param events the session events to inspect.
 * @return {@code true} if the latest prompt token count is strictly greater than the configured
 *     token threshold, {@code false} otherwise.
 */
private boolean shouldCompact(List<Event> events) {
  // TODO b/480013930 - Add a way to estimate the prompt token if the usage metadata is not
  // available.
  int latestPromptTokens = getLatestPromptTokenCount(events).orElse(0);
  boolean exceedsThreshold = latestPromptTokens > tokenThreshold;

  if (!exceedsThreshold) {
    logger.debug(
        "Skipping compaction. Prompt token count {} is within threshold {}",
        latestPromptTokens,
        tokenThreshold);
  }
  return exceedsThreshold;
}

/**
 * Finds the prompt token count reported by the most recent event that carries one.
 *
 * @param events the session events, scanned from the last element towards the first.
 * @return the first prompt token count encountered in that scan, or an empty {@link Optional}
 *     if no event has usage metadata with a prompt token count.
 */
private Optional<Integer> getLatestPromptTokenCount(List<Event> events) {
  // Walk backwards so that the newest event with a reported count wins.
  for (int index = events.size() - 1; index >= 0; index--) {
    Optional<Integer> promptTokens =
        events
            .get(index)
            .usageMetadata()
            .flatMap(GenerateContentResponseUsageMetadata::promptTokenCount);
    if (promptTokens.isPresent()) {
      return promptTokens;
    }
  }
  return Optional.empty();
}

/**
* Identifies events to be compacted based on the tail retention strategy.
*
Expand Down Expand Up @@ -161,8 +136,19 @@ private Optional<Integer> getLatestPromptTokenCount(List<Event> events) {
* together. The new compaction event will cover the range from the start of the included
* compaction event (C2, T=1) to the end of the new events (E4, T=4).
* </ol>
*
* @param events The list of events to process.
*/
private Maybe<List<Event>> getCompactionEvents(List<Event> events) {
Optional<Integer> count = getLatestPromptTokenCount(events);
if (count.isPresent() && count.get() <= tokenThreshold) {
logger.debug(
"Skipping compaction. Prompt token count {} is within threshold {}",
count.get(),
tokenThreshold);
return Maybe.empty();
}

long compactionEndTimestamp = Long.MIN_VALUE;
Event lastCompactionEvent = null;
List<Event> eventsToSummarize = new ArrayList<>();
Expand Down Expand Up @@ -195,11 +181,6 @@ private Maybe<List<Event>> getCompactionEvents(List<Event> events) {
}
}

// If there are not enough events to summarize, we can return early.
if (eventsToSummarize.size() <= retentionSize) {
return Maybe.empty();
}

// Add the last compaction event to the list of events to summarize.
// This is to ensure that the last compaction event is included in the summary.
if (lastCompactionEvent != null) {
Expand All @@ -214,6 +195,22 @@ private Maybe<List<Event>> getCompactionEvents(List<Event> events) {

Collections.reverse(eventsToSummarize);

if (count.isEmpty()) {
int estimatedCount = estimateTokenCount(eventsToSummarize);
if (estimatedCount <= tokenThreshold) {
logger.debug(
"Skipping compaction. Estimated prompt token count {} is within threshold {}",
estimatedCount,
tokenThreshold);
return Maybe.empty();
}
}

// If there are not enough events to summarize, we can return early.
if (eventsToSummarize.size() <= retentionSize) {
return Maybe.empty();
}

// Apply retention: keep the most recent 'retentionSize' events out of the summary.
// We do this by removing them from the list of events to be summarized.
eventsToSummarize
Expand All @@ -222,6 +219,22 @@ private Maybe<List<Event>> getCompactionEvents(List<Event> events) {
return Maybe.just(eventsToSummarize);
}

/**
 * Roughly estimates the number of tokens occupied by the given events.
 *
 * <p>The estimate relies on the common rule of thumb that one token corresponds to about 4
 * characters of common English text (see https://platform.openai.com/tokenizer). It is only an
 * approximation and is not derived from any model tokenizer.
 *
 * @param events the events whose stringified content should be measured.
 * @return the total character count of all events divided by 4 (rounded down), capped at {@link
 *     Integer#MAX_VALUE}.
 */
private int estimateTokenCount(List<Event> events) {
  // Accumulate in a long: an int sum could wrap to a negative value for very large histories,
  // which would make the estimate look "within threshold" exactly when it is not.
  long totalCharacters =
      events.stream().mapToLong(event -> event.stringifyContent().length()).sum();
  long estimatedTokens = totalCharacters / 4;
  // Saturate instead of truncating so the int return type stays backward-compatible.
  return (int) Math.min(estimatedTokens, Integer.MAX_VALUE);
}

/**
 * Finds the prompt token count reported by the most recent event that carries one.
 *
 * @param events the session events, scanned from the last element towards the first.
 * @return the first prompt token count encountered in that scan, or an empty {@link Optional}
 *     if no event has usage metadata with a prompt token count.
 */
private Optional<Integer> getLatestPromptTokenCount(List<Event> events) {
  // Walk backwards so that the newest event with a reported count wins.
  for (int index = events.size() - 1; index >= 0; index--) {
    Optional<Integer> promptTokens =
        events
            .get(index)
            .usageMetadata()
            .flatMap(GenerateContentResponseUsageMetadata::promptTokenCount);
    if (promptTokens.isPresent()) {
      return promptTokens;
    }
  }
  return Optional.empty();
}

/**
 * Tells whether the given event is a compaction event.
 *
 * @param event the event to inspect.
 * @return {@code true} if the event has actions and those actions contain a compaction entry;
 *     {@code false} if the actions are {@code null} or carry no compaction.
 */
private static boolean isCompactEvent(Event event) {
  // An event without actions cannot describe a compaction.
  if (event.actions() == null) {
    return false;
  }
  return event.actions().compaction().isPresent();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,13 @@ public void constructor_negativeRetentionSize_throwsException() {
}

@Test
// TODO: b/480013930 - Add a test case for estimating the prompt token if the usage metadata is
// not available.
public void compaction_skippedWhenTokenUsageMissing() {
public void compaction_skippedWhenEstimatedTokenUsageBelowThreshold() {
// Threshold is 100.
// Event1: "Event1" -> length 6.
// Retain1: "Retain1" -> length 7.
// Retain2: "Retain2" -> length 7.
// Total length = 20. Estimated tokens = 20 / 4 = 5.
// 5 <= 100 -> Skip.
EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100);
ImmutableList<Event> events =
ImmutableList.of(
Expand All @@ -92,6 +96,34 @@ public void compaction_skippedWhenTokenUsageMissing() {
verify(mockSessionService, never()).appendEvent(any(), any());
}

@Test
public void compaction_happensWhenEstimatedTokenUsageAboveThreshold() {
  // None of the events carries usage metadata, so the token count has to be estimated from the
  // content length. The estimate covers all three events (the retained ones included):
  //   "Event1" (6) + "Retain1" (7) + "Retain2" (7) = 20 characters -> 20 / 4 = 5 tokens.
  // With a token threshold of 2, 5 > 2, so compaction is expected to run.
  ImmutableList<Event> history =
      ImmutableList.of(
          createEvent(1, "Event1"),
          createEvent(2, "Retain1"),
          createEvent(3, "Retain2"));
  Session session = Session.builder("id").events(history).build();

  // The summarizer produces a summary event, which the session service should then persist.
  Event summary = createEvent(4, "Summary");
  when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(summary));
  when(mockSessionService.appendEvent(any(), any())).thenReturn(Single.just(summary));

  // Retention size 2, token threshold 2.
  EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 2);
  compactor.compact(session, mockSessionService).blockingSubscribe();

  verify(mockSummarizer).summarizeEvents(any());
  verify(mockSessionService).appendEvent(eq(session), eq(summary));
}

@Test
public void compaction_skippedWhenTokenUsageBelowThreshold() {
// Threshold is 300, usage is 200.
Expand Down