Skip to content

Commit

Permalink
Merge pull request #1144 from cescoffier/fix-blocking-memory-store-in…
Browse files Browse the repository at this point in the history
…-streamed-response

Fix Blocking Memory Store Usage in Streamed Mode
  • Loading branch information
geoand authored Dec 13, 2024
2 parents 1ca2656 + 0690a53 commit 742b816
Show file tree
Hide file tree
Showing 7 changed files with 379 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1295,13 +1295,15 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
String responseAugmenterClassName = AiServicesMethodBuildItem.gatherResponseAugmenter(method);

// Detect if tools execution may block the caller thread.
boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames);
boolean switchToWorkerThreadForToolExecution = detectIfToolExecutionRequiresAWorkerThread(method, tools,
methodToolClassNames);

return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, switchToWorkerThread,
inputGuardrails, outputGuardrails, accumulatorClassName, responseAugmenterClassName);
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames,
switchToWorkerThreadForToolExecution, inputGuardrails, outputGuardrails, accumulatorClassName,
responseAugmenterClassName);
}

private Optional<JsonSchema> jsonSchemaFrom(java.lang.reflect.Type returnType) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package io.quarkiverse.langchain4j.test.streaming;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import jakarta.enterprise.context.control.ActivateRequestContext;
import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.arc.Arc;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.common.vertx.VertxContext;
import io.smallrye.mutiny.Multi;
import io.vertx.core.Context;
import io.vertx.core.Vertx;

/**
 * Exercises a streamed ({@link Multi}) AI service backed by a chat memory store that refuses to run
 * on a thread that must not be blocked (see {@code StreamTestUtils.FakeMemoryStore}).
 * <p>
 * The service is invoked from three kinds of callers: the plain test thread, a blocking task
 * running on a duplicated Vert.x context, and an event loop.
 */
public class BlockingMemoryStoreOnStreamedResponseTest {

    /** Upper bound, in seconds, for work dispatched to Vert.x threads to report back to the test thread. */
    private static final long TIMEOUT_SECONDS = 10;

    @RegisterExtension
    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
                    .addClasses(StreamTestUtils.class));

    @Inject
    MyAiService service;

    @Inject
    Vertx vertx;

    /** Resets the marker shared with the fake memory store so that one test cannot influence the next. */
    @BeforeEach
    void cleanup() {
        StreamTestUtils.FakeMemoryStore.DC_DATA = null;
    }

    @RepeatedTest(100) // Verify that the order is preserved.
    @ActivateRequestContext
    void testFromWorkerThread() {
        // We are on a worker thread.
        List<String> list = service.hi("123", "Say hello").collect().asList().await().indefinitely();
        // containsExactly verifies both the content and the order of the emitted chunks.
        assertThat(list).containsExactly("Hi!", " ", "World!");

        // The fake model only answers "OK!" when the previous exchange is present in the messages it receives.
        list = service.hi("123", "Second message").collect().asList().await().indefinitely();
        assertThat(list).containsExactly("OK!");
    }

    @RepeatedTest(10)
    void testFromDuplicatedContextThread() throws InterruptedException {
        Context context = VertxContext.getOrCreateDuplicatedContext(vertx);
        CountDownLatch latch = new CountDownLatch(1);
        // Failures happen on a Vert.x worker thread; they are captured here and reported from the test thread.
        AtomicReference<Throwable> failure = new AtomicReference<>();
        context.executeBlocking(v -> {
            try {
                Arc.container().requestContext().activate();
                // Tag the current context so the fake memory store can verify it is invoked on this very context.
                var value = UUID.randomUUID().toString();
                StreamTestUtils.FakeMemoryStore.DC_DATA = value;
                Vertx.currentContext().putLocal("DC_DATA", value);
                List<String> list = service.hi("123", "Say hello").collect().asList().await().indefinitely();
                assertThat(list).containsExactly("Hi!", " ", "World!");
                Arc.container().requestContext().deactivate();

                Arc.container().requestContext().activate();

                list = service.hi("123", "Second message").collect().asList().await().indefinitely();
                assertThat(list).containsExactly("OK!");
            } catch (Throwable t) {
                failure.set(t);
            } finally {
                try {
                    Arc.container().requestContext().deactivate();
                    Vertx.currentContext().removeLocal("DC_DATA");
                } finally {
                    // Always release the test thread, even on failure, so that the real cause is
                    // reported instead of a latch timeout.
                    latch.countDown();
                }
            }
        }, false);
        assertThat(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)).isTrue();
        Throwable error = failure.get();
        if (error != null) {
            throw new AssertionError("Invocation from the duplicated context failed", error);
        }
    }

    @Test
    void testFromEventLoopThread() throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);
        AtomicReference<Throwable> failure = new AtomicReference<>();
        Context context = vertx.getOrCreateContext();
        context.runOnContext(v -> {
            Arc.container().requestContext().activate();
            try {
                service.hi("123", "Say hello").collect().asList()
                        .subscribe().asCompletionStage()
                        .whenComplete((items, error) -> {
                            // Failure delivered asynchronously through the stream (error is null on success).
                            failure.set(error);
                            latch.countDown();
                        });
            } catch (Throwable t) {
                // The fake memory store reports the condition with an AssertionError, which is an
                // Error and not an Exception: Throwable must be caught to observe it.
                failure.set(t);
                latch.countDown();
            } finally {
                Arc.container().requestContext().deactivate();
            }
        });
        assertThat(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)).isTrue();
        // Asserted on the test thread so that a mismatch actually fails the test. The stack trace
        // is inspected so the message is found even when the failure has been wrapped.
        assertThat(failure.get())
                .isNotNull()
                .hasStackTraceContaining("Expected to be able to block");
    }

    @RegisterAiService(streamingChatLanguageModelSupplier = StreamTestUtils.FakeStreamedChatModelSupplier.class, chatMemoryProviderSupplier = StreamTestUtils.FakeMemoryProviderSupplier.class)
    public interface MyAiService {

        Multi<String> hi(@MemoryId String id, @UserMessage String query);

    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package io.quarkiverse.langchain4j.test.streaming;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Supplier;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import io.quarkus.arc.Arc;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.vertx.core.Vertx;

/**
 * Utility class for streaming tests.
 * <p>
 * Bundles a fake streaming chat model, a fake chat memory store and the suppliers used to plug
 * them into an AI service under test.
 */
public class StreamTestUtils {

    /** Supplies a provider that builds a 10-message window memory backed by a {@link FakeMemoryStore}. */
    public static class FakeMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
        @Override
        public ChatMemoryProvider get() {
            return new ChatMemoryProvider() {
                @Override
                public ChatMemory get(Object memoryId) {
                    return new MessageWindowChatMemory.Builder()
                            .id(memoryId)
                            .maxMessages(10)
                            .chatMemoryStore(new FakeMemoryStore())
                            .build();
                }
            };
        }
    }

    /** Supplies a new {@link FakeStreamedChatModel} on every call. */
    public static class FakeStreamedChatModelSupplier implements Supplier<StreamingChatLanguageModel> {

        @Override
        public StreamingChatLanguageModel get() {
            return new FakeStreamedChatModel();
        }
    }

    /** Streaming model that emits canned tokens from a Vert.x context instead of calling a real model. */
    public static class FakeStreamedChatModel implements StreamingChatLanguageModel {

        /**
         * Emits a canned answer through {@code handler}, always from tasks scheduled with
         * {@code runOnContext} on the context returned by {@code Vertx#getOrCreateContext()}.
         * <ul>
         * <li>When more than one message is sent and the last one is "Second message" (case-insensitive):
         * fails with "Error - no memory" if fewer than three messages were sent, otherwise emits
         * "OK!" and completes.</li>
         * <li>Otherwise emits "Hi!", " " and "World!", each from its own context task, then completes.</li>
         * </ul>
         *
         * @param messages the messages sent to the model; when more than one is sent, the last one
         *        is cast to {@link UserMessage}
         * @param handler the handler receiving the tokens, the completion or the error
         */
        @Override
        public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
            Vertx vertx = Arc.container().select(Vertx.class).get();
            var ctxt = vertx.getOrCreateContext();

            if (messages.size() > 1) {
                var last = (UserMessage) messages.get(messages.size() - 1);
                if (last.singleText().equalsIgnoreCase("Second message")) {
                    if (messages.size() < 3) {
                        // The first exchange is missing from the messages: the memory was not applied.
                        ctxt.runOnContext(x -> handler.onError(new IllegalStateException("Error - no memory")));
                    } else {
                        ctxt.runOnContext(x -> {
                            handler.onNext("OK!");
                            handler.onComplete(Response.from(AiMessage.from("")));
                        });
                    }
                    return;
                }
            }

            // Each token is emitted from a separate, chained context task.
            ctxt.runOnContext(x1 -> {
                handler.onNext("Hi!");
                ctxt.runOnContext(x2 -> {
                    handler.onNext(" ");
                    ctxt.runOnContext(x3 -> {
                        handler.onNext("World!");
                        ctxt.runOnContext(x -> handler.onComplete(Response.from(AiMessage.from(""))));
                    });
                });
            });
        }
    }

    /**
     * In-memory store that fails with an {@link AssertionError} when it is invoked from a thread
     * that cannot be blocked. All instances share the same static storage.
     */
    public static class FakeMemoryStore implements ChatMemoryStore {

        /**
         * Marker a test may set to request that {@link #getMessages(Object)} verifies it runs on a
         * Vert.x context holding the same value under the local key "DC_DATA"; {@code null}
         * disables the verification. Volatile because it is written and read from different threads.
         */
        public static volatile String DC_DATA;

        private static final Map<Object, List<ChatMessage>> memories = new ConcurrentHashMap<>();

        /** Fails when the calling thread is not allowed to block. */
        private static void checkCallerCanBlock() {
            if (!Infrastructure.canCallerThreadBeBlocked()) {
                throw new AssertionError("Expected to be able to block");
            }
        }

        /** Fails when a marker is set and the current Vert.x context does not carry the same value. */
        private void checkDuplicatedContext() {
            // Read the marker once: another thread may reset it between the null check and the comparison.
            String expected = DC_DATA;
            if (expected == null) {
                return;
            }
            // There may be no current context at all; report that as a context mismatch rather
            // than failing with a NullPointerException.
            var current = Vertx.currentContext();
            Object actual = null;
            if (current != null) {
                actual = current.getLocal("DC_DATA");
            }
            if (!expected.equals(actual)) {
                throw new AssertionError("Expected to be in the same context");
            }
        }

        /**
         * Returns the messages stored for {@code memoryId}, registering an empty list on first access.
         *
         * @throws AssertionError if the caller thread cannot be blocked, or if the context marker does not match
         */
        @Override
        public List<ChatMessage> getMessages(Object memoryId) {
            checkCallerCanBlock();
            checkDuplicatedContext();
            return memories.computeIfAbsent(memoryId, x -> new CopyOnWriteArrayList<>());
        }

        /**
         * Replaces the messages stored for {@code memoryId}.
         *
         * @throws AssertionError if the caller thread cannot be blocked
         */
        @Override
        public void updateMessages(Object memoryId, List<ChatMessage> messages) {
            checkCallerCanBlock();
            memories.put(memoryId, messages);
        }

        /**
         * Only verifies the calling thread; the stored messages are left untouched.
         *
         * @throws AssertionError if the caller thread cannot be blocked
         */
        @Override
        public void deleteMessages(Object memoryId) {
            // NOTE(review): nothing is removed here; presumably intentional for this fake - confirm.
            checkCallerCanBlock();
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public final class AiServiceMethodCreateInfo {
private OutputTokenAccumulator accumulator;

private final LazyValue<Integer> guardrailsMaxRetry;
private final boolean switchToWorkerThread;
private final boolean switchToWorkerThreadForToolExecution;

@RecordableConstructor
public AiServiceMethodCreateInfo(String interfaceName, String methodName,
Expand All @@ -74,7 +74,7 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName,
Optional<SpanInfo> spanInfo,
ResponseSchemaInfo responseSchemaInfo,
List<String> toolClassNames,
boolean switchToWorkerThread,
boolean switchToWorkerThreadForToolExecution,
List<String> inputGuardrailsClassNames,
List<String> outputGuardrailsClassNames,
String outputTokenAccumulatorClassName,
Expand Down Expand Up @@ -108,7 +108,7 @@ public Integer get() {
.orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT);
}
});
this.switchToWorkerThread = switchToWorkerThread;
this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution;
this.responseAugmenterClassName = responseAugmenterClassName;
}

Expand Down Expand Up @@ -238,8 +238,8 @@ public String getUserMessageTemplate() {
return userMessageTemplateOpt.orElse("");
}

public boolean isSwitchToWorkerThread() {
return switchToWorkerThread;
public boolean isSwitchToWorkerThreadForToolExecution() {
return switchToWorkerThreadForToolExecution;
}

public void setResponseAugmenter(Class<? extends AiResponseAugmenter<?>> augmenter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ public Object implement(Input input) {

private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Object[] methodArgs,
QuarkusAiServiceContext context, Audit audit) {
boolean isRunningOnWorkerThread = !Context.isOnEventLoopThread();
Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null);
Optional<SystemMessage> systemMessage = prepareSystemMessage(methodCreateInfo, methodArgs,
context.hasChatMemory() ? context.chatMemory(memoryId).messages() : Collections.emptyList());
Expand Down Expand Up @@ -227,7 +228,7 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
List<ChatMessage> messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed);
var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications,
finalToolExecutors, ar.contents(), context, memoryId,
methodCreateInfo.isSwitchToWorkerThread());
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
new ResponseAugmenterParams((UserMessage) augmentedUserMessage,
memory, ar, methodCreateInfo.getUserMessageTemplate(),
Expand Down Expand Up @@ -278,7 +279,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) {
var stream = new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors,
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId,
methodCreateInfo.isSwitchToWorkerThread());
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
new ResponseAugmenterParams(actualUserMessage,
chatMemory, actualAugmentationResult, methodCreateInfo.getUserMessageTemplate(),
Expand All @@ -287,7 +288,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,

return new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors,
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId,
methodCreateInfo.isSwitchToWorkerThread())
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread)
.plug(s -> GuardrailsSupport.accumulate(s, methodCreateInfo))
.map(chunk -> {
OutputGuardrailResult result;
Expand Down Expand Up @@ -785,19 +786,22 @@ private static class TokenStreamMulti extends AbstractMulti<String> implements M
private final List<Content> contents;
private final QuarkusAiServiceContext context;
private final Object memoryId;
private final boolean mustSwitchToWorkerThread;
private final boolean switchToWorkerThreadForToolExecution;
private final boolean isCallerRunningOnWorkerThread;

public TokenStreamMulti(List<ChatMessage> messagesToSend, List<ToolSpecification> toolSpecifications,
Map<String, ToolExecutor> toolExecutors,
List<Content> contents, QuarkusAiServiceContext context, Object memoryId, boolean mustSwitchToWorkerThread) {
List<Content> contents, QuarkusAiServiceContext context, Object memoryId,
boolean switchToWorkerThreadForToolExecution, boolean isCallerRunningOnWorkerThread) {
// We need to pass and store the parameters to the constructor because we need to re-create a stream on every subscription.
this.messagesToSend = messagesToSend;
this.toolSpecifications = toolSpecifications;
this.toolsExecutors = toolExecutors;
this.contents = contents;
this.context = context;
this.memoryId = memoryId;
this.mustSwitchToWorkerThread = mustSwitchToWorkerThread;
this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution;
this.isCallerRunningOnWorkerThread = isCallerRunningOnWorkerThread;
}

@Override
Expand All @@ -810,19 +814,20 @@ public void subscribe(MultiSubscriber<? super String> subscriber) {

private void createTokenStream(UnicastProcessor<String> processor) {
Context ctxt = null;
if (mustSwitchToWorkerThread) {
if (switchToWorkerThreadForToolExecution || isCallerRunningOnWorkerThread) {
// we create or retrieve the current context, to use `executeBlocking` when required.
ctxt = VertxContext.getOrCreateDuplicatedContext();
}

var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications,
toolsExecutors, contents, context, memoryId, ctxt, mustSwitchToWorkerThread);
toolsExecutors, contents, context, memoryId, ctxt, switchToWorkerThreadForToolExecution,
isCallerRunningOnWorkerThread);
TokenStream tokenStream = stream
.onNext(processor::onNext)
.onComplete(message -> processor.onComplete())
.onError(processor::onError);
// This is equivalent to "run subscription on worker thread"
if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) {
if (switchToWorkerThreadForToolExecution && Context.isOnEventLoopThread()) {
ctxt.executeBlocking(new Callable<Void>() {
@Override
public Void call() {
Expand Down
Loading

0 comments on commit 742b816

Please sign in to comment.