Skip to content

Commit

Permalink
Fix Blocking Memory Store Usage in Streamed Mode
Browse files Browse the repository at this point in the history
This commit addresses issues with using the blocking memory store in streamed responses.

* Ensures the execution captures whether the caller is running on a worker thread.
* Switches to worker threads for every emission and completion event when the caller is using a worker thread.
* Relies on executeBlocking to propagate the context automatically when possible.

Note:
* The blocking memory store cannot be used when invoked on the event loop. It now requires that the caller must be on a worker thread.
  • Loading branch information
cescoffier committed Dec 13, 2024
1 parent 5ec8e05 commit 0690a53
Show file tree
Hide file tree
Showing 7 changed files with 379 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1284,13 +1284,15 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
String responseAugmenterClassName = AiServicesMethodBuildItem.gatherResponseAugmenter(method);

// Detect if tools execution may block the caller thread.
boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames);
boolean switchToWorkerThreadForToolExecution = detectIfToolExecutionRequiresAWorkerThread(method, tools,
methodToolClassNames);

return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, switchToWorkerThread,
inputGuardrails, outputGuardrails, accumulatorClassName, responseAugmenterClassName);
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames,
switchToWorkerThreadForToolExecution, inputGuardrails, outputGuardrails, accumulatorClassName,
responseAugmenterClassName);
}

private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List<ToolMethodBuildItem> tools,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package io.quarkiverse.langchain4j.test.streaming;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import jakarta.enterprise.context.control.ActivateRequestContext;
import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.arc.Arc;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.common.vertx.VertxContext;
import io.smallrye.mutiny.Multi;
import io.vertx.core.Context;
import io.vertx.core.Vertx;

public class BlockingMemoryStoreOnStreamedResponseTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(StreamTestUtils.class));

@Inject
MyAiService service;

@RepeatedTest(100) // Verify that the order is preserved.
@ActivateRequestContext
void testFromWorkerThread() {
// We are on a worker thread.
List<String> list = service.hi("123", "Say hello").collect().asList().await().indefinitely();
// We cannot guarantee the order, as we do not have a context.
assertThat(list).containsExactly("Hi!", " ", "World!");

list = service.hi("123", "Second message").collect().asList().await().indefinitely();
assertThat(list).containsExactly("OK!");
}

@BeforeEach
void cleanup() {
StreamTestUtils.FakeMemoryStore.DC_DATA = null;
}

@RepeatedTest(10)
void testFromDuplicatedContextThread() throws InterruptedException {
Context context = VertxContext.getOrCreateDuplicatedContext(vertx);
CountDownLatch latch = new CountDownLatch(1);
context.executeBlocking(v -> {
try {
Arc.container().requestContext().activate();
var value = UUID.randomUUID().toString();
StreamTestUtils.FakeMemoryStore.DC_DATA = value;
Vertx.currentContext().putLocal("DC_DATA", value);
List<String> list = service.hi("123", "Say hello").collect().asList().await().indefinitely();
assertThat(list).containsExactly("Hi!", " ", "World!");
Arc.container().requestContext().deactivate();

Arc.container().requestContext().activate();

list = service.hi("123", "Second message").collect().asList().await().indefinitely();
assertThat(list).containsExactly("OK!");
latch.countDown();

} finally {
Arc.container().requestContext().deactivate();
Vertx.currentContext().removeLocal("DC_DATA");

}
}, false);
assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue();
}

@Inject
Vertx vertx;

@Test
void testFromEventLoopThread() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
Context context = vertx.getOrCreateContext();
context.runOnContext(v -> {
Arc.container().requestContext().activate();
try {
service.hi("123", "Say hello").collect().asList()
.subscribe().asCompletionStage();
} catch (Exception e) {
assertThat(e)
.isNotNull()
.hasMessageContaining("Expected to be able to block");
} finally {
Arc.container().requestContext().deactivate();
latch.countDown();
}
});
latch.await();
}

@RegisterAiService(streamingChatLanguageModelSupplier = StreamTestUtils.FakeStreamedChatModelSupplier.class, chatMemoryProviderSupplier = StreamTestUtils.FakeMemoryProviderSupplier.class)
public interface MyAiService {

Multi<String> hi(@MemoryId String id, @UserMessage String query);

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package io.quarkiverse.langchain4j.test.streaming;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Supplier;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import io.quarkus.arc.Arc;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.vertx.core.Vertx;

/**
* Utility class for streaming tests.
*/
public class StreamTestUtils {

public static class FakeMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
@Override
public ChatMemoryProvider get() {
return new ChatMemoryProvider() {
@Override
public ChatMemory get(Object memoryId) {
return new MessageWindowChatMemory.Builder()
.id(memoryId)
.maxMessages(10)
.chatMemoryStore(new FakeMemoryStore())
.build();
}
};
}
}

public static class FakeStreamedChatModelSupplier implements Supplier<StreamingChatLanguageModel> {

@Override
public StreamingChatLanguageModel get() {
return new FakeStreamedChatModel();
}
}

public static class FakeStreamedChatModel implements StreamingChatLanguageModel {

@Override
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
Vertx vertx = Arc.container().select(Vertx.class).get();
var ctxt = vertx.getOrCreateContext();

if (messages.size() > 1) {
var last = (UserMessage) messages.get(messages.size() - 1);
if (last.singleText().equalsIgnoreCase("Second message")) {
if (messages.size() < 3) {
ctxt.runOnContext(x -> handler.onError(new IllegalStateException("Error - no memory")));
return;
} else {
ctxt.runOnContext(x -> {
handler.onNext("OK!");
handler.onComplete(Response.from(AiMessage.from("")));
});
return;
}
}
}

ctxt.runOnContext(x1 -> {
handler.onNext("Hi!");
ctxt.runOnContext(x2 -> {
handler.onNext(" ");
ctxt.runOnContext(x3 -> {
handler.onNext("World!");
ctxt.runOnContext(x -> handler.onComplete(Response.from(AiMessage.from(""))));
});
});
});
}
}

public static class FakeMemoryStore implements ChatMemoryStore {

public static String DC_DATA;

private static final Map<Object, List<ChatMessage>> memories = new ConcurrentHashMap<>();

private void checkDuplicatedContext() {
if (DC_DATA != null) {
if (!DC_DATA.equals(Vertx.currentContext().getLocal("DC_DATA"))) {
throw new AssertionError("Expected to be in the same context");
}
}
}

@Override
public List<ChatMessage> getMessages(Object memoryId) {
if (!Infrastructure.canCallerThreadBeBlocked()) {
throw new AssertionError("Expected to be able to block");
}
checkDuplicatedContext();
return memories.computeIfAbsent(memoryId, x -> new CopyOnWriteArrayList<>());
}

@Override
public void updateMessages(Object memoryId, List<ChatMessage> messages) {
if (!Infrastructure.canCallerThreadBeBlocked()) {
throw new AssertionError("Expected to be able to block");
}
memories.put(memoryId, messages);
}

@Override
public void deleteMessages(Object memoryId) {
if (!Infrastructure.canCallerThreadBeBlocked()) {
throw new AssertionError("Expected to be able to block");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public final class AiServiceMethodCreateInfo {
private OutputTokenAccumulator accumulator;

private final LazyValue<Integer> guardrailsMaxRetry;
private final boolean switchToWorkerThread;
private final boolean switchToWorkerThreadForToolExecution;

@RecordableConstructor
public AiServiceMethodCreateInfo(String interfaceName, String methodName,
Expand All @@ -73,7 +73,7 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName,
Optional<SpanInfo> spanInfo,
ResponseSchemaInfo responseSchemaInfo,
List<String> toolClassNames,
boolean switchToWorkerThread,
boolean switchToWorkerThreadForToolExecution,
List<String> inputGuardrailsClassNames,
List<String> outputGuardrailsClassNames,
String outputTokenAccumulatorClassName,
Expand Down Expand Up @@ -107,7 +107,7 @@ public Integer get() {
.orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT);
}
});
this.switchToWorkerThread = switchToWorkerThread;
this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution;
this.responseAugmenterClassName = responseAugmenterClassName;
}

Expand Down Expand Up @@ -237,8 +237,8 @@ public String getUserMessageTemplate() {
return userMessageTemplateOpt.orElse("");
}

public boolean isSwitchToWorkerThread() {
return switchToWorkerThread;
public boolean isSwitchToWorkerThreadForToolExecution() {
return switchToWorkerThreadForToolExecution;
}

public void setResponseAugmenter(Class<? extends AiResponseAugmenter<?>> augmenter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ public Object implement(Input input) {

private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Object[] methodArgs,
QuarkusAiServiceContext context, Audit audit) {
boolean isRunningOnWorkerThread = !Context.isOnEventLoopThread();
Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null);
Optional<SystemMessage> systemMessage = prepareSystemMessage(methodCreateInfo, methodArgs,
context.hasChatMemory() ? context.chatMemory(memoryId).messages() : Collections.emptyList());
Expand Down Expand Up @@ -217,7 +218,7 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
List<ChatMessage> messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed);
var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications,
finalToolExecutors, ar.contents(), context, memoryId,
methodCreateInfo.isSwitchToWorkerThread());
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
new ResponseAugmenterParams((UserMessage) augmentedUserMessage,
memory, ar, methodCreateInfo.getUserMessageTemplate(),
Expand Down Expand Up @@ -268,7 +269,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) {
var stream = new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors,
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId,
methodCreateInfo.isSwitchToWorkerThread());
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
new ResponseAugmenterParams(actualUserMessage,
chatMemory, actualAugmentationResult, methodCreateInfo.getUserMessageTemplate(),
Expand All @@ -277,7 +278,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,

return new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors,
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId,
methodCreateInfo.isSwitchToWorkerThread())
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread)
.plug(s -> GuardrailsSupport.accumulate(s, methodCreateInfo))
.map(chunk -> {
OutputGuardrailResult result;
Expand Down Expand Up @@ -786,19 +787,22 @@ private static class TokenStreamMulti extends AbstractMulti<String> implements M
private final List<Content> contents;
private final QuarkusAiServiceContext context;
private final Object memoryId;
private final boolean mustSwitchToWorkerThread;
private final boolean switchToWorkerThreadForToolExecution;
private final boolean isCallerRunningOnWorkerThread;

public TokenStreamMulti(List<ChatMessage> messagesToSend, List<ToolSpecification> toolSpecifications,
Map<String, ToolExecutor> toolExecutors,
List<Content> contents, QuarkusAiServiceContext context, Object memoryId, boolean mustSwitchToWorkerThread) {
List<Content> contents, QuarkusAiServiceContext context, Object memoryId,
boolean switchToWorkerThreadForToolExecution, boolean isCallerRunningOnWorkerThread) {
// We need to pass and store the parameters to the constructor because we need to re-create a stream on every subscription.
this.messagesToSend = messagesToSend;
this.toolSpecifications = toolSpecifications;
this.toolsExecutors = toolExecutors;
this.contents = contents;
this.context = context;
this.memoryId = memoryId;
this.mustSwitchToWorkerThread = mustSwitchToWorkerThread;
this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution;
this.isCallerRunningOnWorkerThread = isCallerRunningOnWorkerThread;
}

@Override
Expand All @@ -811,19 +815,20 @@ public void subscribe(MultiSubscriber<? super String> subscriber) {

private void createTokenStream(UnicastProcessor<String> processor) {
Context ctxt = null;
if (mustSwitchToWorkerThread) {
if (switchToWorkerThreadForToolExecution || isCallerRunningOnWorkerThread) {
// we create or retrieve the current context, to use `executeBlocking` when required.
ctxt = VertxContext.getOrCreateDuplicatedContext();
}

var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications,
toolsExecutors, contents, context, memoryId, ctxt, mustSwitchToWorkerThread);
toolsExecutors, contents, context, memoryId, ctxt, switchToWorkerThreadForToolExecution,
isCallerRunningOnWorkerThread);
TokenStream tokenStream = stream
.onNext(processor::onNext)
.onComplete(message -> processor.onComplete())
.onError(processor::onError);
// This is equivalent to "run subscription on worker thread"
if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) {
if (switchToWorkerThreadForToolExecution && Context.isOnEventLoopThread()) {
ctxt.executeBlocking(new Callable<Void>() {
@Override
public Void call() {
Expand Down
Loading

0 comments on commit 0690a53

Please sign in to comment.