diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index de5f7ba46..b8fed660a 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -1295,13 +1295,15 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( String responseAugmenterClassName = AiServicesMethodBuildItem.gatherResponseAugmenter(method); // Detect if tools execution may block the caller thread. - boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames); + boolean switchToWorkerThreadForToolExecution = detectIfToolExecutionRequiresAWorkerThread(method, tools, + methodToolClassNames); return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo, userMessageInfo, memoryIdParamPosition, requiresModeration, returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)), - metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, switchToWorkerThread, - inputGuardrails, outputGuardrails, accumulatorClassName, responseAugmenterClassName); + metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, + switchToWorkerThreadForToolExecution, inputGuardrails, outputGuardrails, accumulatorClassName, + responseAugmenterClassName); } private Optional jsonSchemaFrom(java.lang.reflect.Type returnType) { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/streaming/BlockingMemoryStoreOnStreamedResponseTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/streaming/BlockingMemoryStoreOnStreamedResponseTest.java new file mode 100644 index 000000000..778cf6b4a --- /dev/null +++ 
b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/streaming/BlockingMemoryStoreOnStreamedResponseTest.java @@ -0,0 +1,117 @@ +package io.quarkiverse.langchain4j.test.streaming; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.common.vertx.VertxContext; +import io.smallrye.mutiny.Multi; +import io.vertx.core.Context; +import io.vertx.core.Vertx; + +public class BlockingMemoryStoreOnStreamedResponseTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(StreamTestUtils.class)); + + @Inject + MyAiService service; + + @RepeatedTest(100) // Verify that the order is preserved. + @ActivateRequestContext + void testFromWorkerThread() { + // We are on a worker thread. + List list = service.hi("123", "Say hello").collect().asList().await().indefinitely(); + // The order must be preserved here: the assertion that follows checks the exact sequence. 
+ assertThat(list).containsExactly("Hi!", " ", "World!"); + + list = service.hi("123", "Second message").collect().asList().await().indefinitely(); + assertThat(list).containsExactly("OK!"); + } + + @BeforeEach + void cleanup() { + StreamTestUtils.FakeMemoryStore.DC_DATA = null; + } + + @RepeatedTest(10) + void testFromDuplicatedContextThread() throws InterruptedException { + Context context = VertxContext.getOrCreateDuplicatedContext(vertx); + CountDownLatch latch = new CountDownLatch(1); + context.executeBlocking(v -> { + try { + Arc.container().requestContext().activate(); + var value = UUID.randomUUID().toString(); + StreamTestUtils.FakeMemoryStore.DC_DATA = value; + Vertx.currentContext().putLocal("DC_DATA", value); + List list = service.hi("123", "Say hello").collect().asList().await().indefinitely(); + assertThat(list).containsExactly("Hi!", " ", "World!"); + Arc.container().requestContext().deactivate(); + + Arc.container().requestContext().activate(); + + list = service.hi("123", "Second message").collect().asList().await().indefinitely(); + assertThat(list).containsExactly("OK!"); + latch.countDown(); + + } finally { + Arc.container().requestContext().deactivate(); + Vertx.currentContext().removeLocal("DC_DATA"); + + } + }, false); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + } + + @Inject + Vertx vertx; + + @Test + void testFromEventLoopThread() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + Context context = vertx.getOrCreateContext(); + context.runOnContext(v -> { + Arc.container().requestContext().activate(); + try { + service.hi("123", "Say hello").collect().asList() + .subscribe().asCompletionStage(); + } catch (Throwable e) { + assertThat(e) + .isNotNull() + .hasMessageContaining("Expected to be able to block"); + } finally { + Arc.container().requestContext().deactivate(); + latch.countDown(); + } + }); + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + } + + @RegisterAiService(streamingChatLanguageModelSupplier = 
StreamTestUtils.FakeStreamedChatModelSupplier.class, chatMemoryProviderSupplier = StreamTestUtils.FakeMemoryProviderSupplier.class) + public interface MyAiService { + + Multi hi(@MemoryId String id, @UserMessage String query); + + } + +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/streaming/StreamTestUtils.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/streaming/StreamTestUtils.java new file mode 100644 index 000000000..cbe0e7c59 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/streaming/StreamTestUtils.java @@ -0,0 +1,126 @@ +package io.quarkiverse.langchain4j.test.streaming; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.Supplier; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.store.memory.chat.ChatMemoryStore; +import io.quarkus.arc.Arc; +import io.smallrye.mutiny.infrastructure.Infrastructure; +import io.vertx.core.Vertx; + +/** + * Utility class for streaming tests. 
+ */ +public class StreamTestUtils { + + public static class FakeMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new MessageWindowChatMemory.Builder() + .id(memoryId) + .maxMessages(10) + .chatMemoryStore(new FakeMemoryStore()) + .build(); + } + }; + } + } + + public static class FakeStreamedChatModelSupplier implements Supplier { + + @Override + public StreamingChatLanguageModel get() { + return new FakeStreamedChatModel(); + } + } + + public static class FakeStreamedChatModel implements StreamingChatLanguageModel { + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + Vertx vertx = Arc.container().select(Vertx.class).get(); + var ctxt = vertx.getOrCreateContext(); + + if (messages.size() > 1) { + var last = (UserMessage) messages.get(messages.size() - 1); + if (last.singleText().equalsIgnoreCase("Second message")) { + if (messages.size() < 3) { + ctxt.runOnContext(x -> handler.onError(new IllegalStateException("Error - no memory"))); + return; + } else { + ctxt.runOnContext(x -> { + handler.onNext("OK!"); + handler.onComplete(Response.from(AiMessage.from(""))); + }); + return; + } + } + } + + ctxt.runOnContext(x1 -> { + handler.onNext("Hi!"); + ctxt.runOnContext(x2 -> { + handler.onNext(" "); + ctxt.runOnContext(x3 -> { + handler.onNext("World!"); + ctxt.runOnContext(x -> handler.onComplete(Response.from(AiMessage.from("")))); + }); + }); + }); + } + } + + public static class FakeMemoryStore implements ChatMemoryStore { + + public static String DC_DATA; + + private static final Map> memories = new ConcurrentHashMap<>(); + + private void checkDuplicatedContext() { + if (DC_DATA != null) { + if (!DC_DATA.equals(Vertx.currentContext().getLocal("DC_DATA"))) { + throw new AssertionError("Expected to be in the same context"); + } + } + } + + @Override + public List getMessages(Object 
memoryId) { + if (!Infrastructure.canCallerThreadBeBlocked()) { + throw new AssertionError("Expected to be able to block"); + } + checkDuplicatedContext(); + return memories.computeIfAbsent(memoryId, x -> new CopyOnWriteArrayList<>()); + } + + @Override + public void updateMessages(Object memoryId, List messages) { + if (!Infrastructure.canCallerThreadBeBlocked()) { + throw new AssertionError("Expected to be able to block"); + } + memories.put(memoryId, messages); + } + + @Override + public void deleteMessages(Object memoryId) { + if (!Infrastructure.canCallerThreadBeBlocked()) { + throw new AssertionError("Expected to be able to block"); + } + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java index 4321739e3..689c58637 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java @@ -60,7 +60,7 @@ public final class AiServiceMethodCreateInfo { private OutputTokenAccumulator accumulator; private final LazyValue guardrailsMaxRetry; - private final boolean switchToWorkerThread; + private final boolean switchToWorkerThreadForToolExecution; @RecordableConstructor public AiServiceMethodCreateInfo(String interfaceName, String methodName, @@ -74,7 +74,7 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName, Optional spanInfo, ResponseSchemaInfo responseSchemaInfo, List toolClassNames, - boolean switchToWorkerThread, + boolean switchToWorkerThreadForToolExecution, List inputGuardrailsClassNames, List outputGuardrailsClassNames, String outputTokenAccumulatorClassName, @@ -108,7 +108,7 @@ public Integer get() { .orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT); } }); - this.switchToWorkerThread = switchToWorkerThread; + 
this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution; this.responseAugmenterClassName = responseAugmenterClassName; } @@ -238,8 +238,8 @@ public String getUserMessageTemplate() { return userMessageTemplateOpt.orElse(""); } - public boolean isSwitchToWorkerThread() { - return switchToWorkerThread; + public boolean isSwitchToWorkerThreadForToolExecution() { + return switchToWorkerThreadForToolExecution; } public void setResponseAugmenter(Class> augmenter) { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 1c4d31e1c..f66ce91c1 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -152,6 +152,7 @@ public Object implement(Input input) { private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Object[] methodArgs, QuarkusAiServiceContext context, Audit audit) { + boolean isRunningOnWorkerThread = !Context.isOnEventLoopThread(); Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null); Optional systemMessage = prepareSystemMessage(methodCreateInfo, methodArgs, context.hasChatMemory() ? 
context.chatMemory(memoryId).messages() : Collections.emptyList()); @@ -227,7 +228,7 @@ public Flow.Publisher apply(AugmentationResult ar) { List messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed); var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications, finalToolExecutors, ar.contents(), context, memoryId, - methodCreateInfo.isSwitchToWorkerThread()); + methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread); return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo, new ResponseAugmenterParams((UserMessage) augmentedUserMessage, memory, ar, methodCreateInfo.getUserMessageTemplate(), @@ -278,7 +279,7 @@ private List messagesToSend(UserMessage augmentedUserMessage, if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) { var stream = new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors, (augmentationResult != null ? augmentationResult.contents() : null), context, memoryId, - methodCreateInfo.isSwitchToWorkerThread()); + methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread); return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo, new ResponseAugmenterParams(actualUserMessage, chatMemory, actualAugmentationResult, methodCreateInfo.getUserMessageTemplate(), @@ -287,7 +288,7 @@ private List messagesToSend(UserMessage augmentedUserMessage, return new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors, (augmentationResult != null ? 
augmentationResult.contents() : null), context, memoryId, - methodCreateInfo.isSwitchToWorkerThread()) + methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread) .plug(s -> GuardrailsSupport.accumulate(s, methodCreateInfo)) .map(chunk -> { OutputGuardrailResult result; @@ -785,11 +786,13 @@ private static class TokenStreamMulti extends AbstractMulti implements M private final List contents; private final QuarkusAiServiceContext context; private final Object memoryId; - private final boolean mustSwitchToWorkerThread; + private final boolean switchToWorkerThreadForToolExecution; + private final boolean isCallerRunningOnWorkerThread; public TokenStreamMulti(List messagesToSend, List toolSpecifications, Map toolExecutors, - List contents, QuarkusAiServiceContext context, Object memoryId, boolean mustSwitchToWorkerThread) { + List contents, QuarkusAiServiceContext context, Object memoryId, + boolean switchToWorkerThreadForToolExecution, boolean isCallerRunningOnWorkerThread) { // We need to pass and store the parameters to the constructor because we need to re-create a stream on every subscription. this.messagesToSend = messagesToSend; this.toolSpecifications = toolSpecifications; @@ -797,7 +800,8 @@ public TokenStreamMulti(List messagesToSend, List subscriber) { private void createTokenStream(UnicastProcessor processor) { Context ctxt = null; - if (mustSwitchToWorkerThread) { + if (switchToWorkerThreadForToolExecution || isCallerRunningOnWorkerThread) { // we create or retrieve the current context, to use `executeBlocking` when required. 
ctxt = VertxContext.getOrCreateDuplicatedContext(); } var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications, - toolsExecutors, contents, context, memoryId, ctxt, mustSwitchToWorkerThread); + toolsExecutors, contents, context, memoryId, ctxt, switchToWorkerThreadForToolExecution, + isCallerRunningOnWorkerThread); TokenStream tokenStream = stream .onNext(processor::onNext) .onComplete(message -> processor.onComplete()) .onError(processor::onError); // This is equivalent to "run subscription on worker thread" - if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) { + if (switchToWorkerThreadForToolExecution && Context.isOnEventLoopThread()) { ctxt.executeBlocking(new Callable() { @Override public Void call() { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java index fc520aa77..60d9aa18e 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java @@ -7,6 +7,8 @@ import java.util.List; import java.util.Map; import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.function.Consumer; import org.jboss.logging.Logger; @@ -22,7 +24,6 @@ import dev.langchain4j.service.AiServiceContext; import dev.langchain4j.service.tool.ToolExecution; import dev.langchain4j.service.tool.ToolExecutor; -import io.smallrye.mutiny.infrastructure.Infrastructure; import io.vertx.core.Context; /** @@ -49,6 +50,8 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon private final Map toolExecutors; private final Context executionContext; private final boolean 
mustSwitchToWorkerThread; + private final boolean switchToWorkerForEmission; + private final ExecutorService executor; QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId, @@ -59,7 +62,10 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon List temporaryMemory, TokenUsage tokenUsage, List toolSpecifications, - Map toolExecutors, boolean mustSwitchToWorkerThread, Context cxtx) { + Map toolExecutors, + boolean mustSwitchToWorkerThread, + boolean switchToWorkerForEmission, + Context cxtx) { this.context = ensureNotNull(context, "context"); this.memoryId = ensureNotNull(memoryId, "memoryId"); @@ -76,36 +82,82 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon this.mustSwitchToWorkerThread = mustSwitchToWorkerThread; this.executionContext = cxtx; + this.switchToWorkerForEmission = switchToWorkerForEmission; + if (executionContext == null) { + // We do not have a context, but we still need to make sure we are not blocking the event loop and order + // is respected. 
+ executor = Executors.newSingleThreadExecutor(); + } else { + executor = null; + } + } + + public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId, Consumer tokenHandler, + Consumer toolExecuteHandler, Consumer> completionHandler, + Consumer errorHandler, List temporaryMemory, TokenUsage sum, + List toolSpecifications, Map toolExecutors, + boolean mustSwitchToWorkerThread, boolean switchToWorkerForEmission, Context executionContext, + ExecutorService executor) { + this.context = context; + this.memoryId = memoryId; + this.tokenHandler = tokenHandler; + this.toolExecuteHandler = toolExecuteHandler; + this.completionHandler = completionHandler; + this.errorHandler = errorHandler; + this.temporaryMemory = temporaryMemory; + this.tokenUsage = sum; + this.toolSpecifications = toolSpecifications; + this.toolExecutors = toolExecutors; + this.mustSwitchToWorkerThread = mustSwitchToWorkerThread; + this.switchToWorkerForEmission = switchToWorkerForEmission; + this.executionContext = executionContext; + this.executor = executor; } @Override public void onNext(String token) { - tokenHandler.accept(token); + execute(new Runnable() { + @Override + public void run() { + tokenHandler.accept(token); + } + }); + } private void executeTools(Runnable runnable) { if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) { - if (executionContext != null) { - executionContext.executeBlocking(new Callable() { - @Override - public Object call() { - runnable.run(); - return null; - } - }); - } else { - // We do not have a context, switching to worker thread. 
- Infrastructure.getDefaultWorkerPool().execute(runnable); - } + executeOnWorkerThread(runnable); } else { runnable.run(); } } + private void execute(Runnable runnable) { + if (switchToWorkerForEmission && Context.isOnEventLoopThread()) { + executeOnWorkerThread(runnable); + } else { + runnable.run(); + } + } + + private void executeOnWorkerThread(Runnable runnable) { + if (executionContext != null) { + executionContext.executeBlocking(new Callable() { + @Override + public Object call() throws Exception { + runnable.run(); + return null; + } + }, true); + } else { + executor.execute(runnable); + } + } + @Override public void onComplete(Response response) { AiMessage aiMessage = response.content(); - addToMemory(aiMessage); if (aiMessage.hasToolExecutionRequests()) { // Tools execution may block the caller thread. When the caller thread is the event loop thread, and @@ -113,6 +165,7 @@ public void onComplete(Response response) { executeTools(new Runnable() { @Override public void run() { + addToMemory(aiMessage); for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { String toolName = toolExecutionRequest.name(); ToolExecutor toolExecutor = toolExecutors.get(toolName); @@ -143,19 +196,36 @@ public void run() { TokenUsage.sum(tokenUsage, response.tokenUsage()), toolSpecifications, toolExecutors, - mustSwitchToWorkerThread, executionContext)); + mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor)); } }); } else { if (completionHandler != null) { - completionHandler.accept(Response.from( - aiMessage, - TokenUsage.sum(tokenUsage, response.tokenUsage()), - response.finishReason())); + Runnable runnable = new Runnable() { + @Override + public void run() { + try { + addToMemory(aiMessage); + completionHandler.accept(Response.from( + aiMessage, + TokenUsage.sum(tokenUsage, response.tokenUsage()), + response.finishReason())); + } finally { + shutdown(); // Terminal event, we can shutdown the executor + } + } + }; + 
execute(runnable); } } } + private void shutdown() { + if (executor != null) { + executor.shutdown(); + } + } + private void addToMemory(ChatMessage chatMessage) { if (context.hasChatMemory()) { context.chatMemory(memoryId).add(chatMessage); @@ -173,12 +243,19 @@ private List messagesToSend(Object memoryId) { @Override public void onError(Throwable error) { if (errorHandler != null) { - try { - errorHandler.accept(error); - } catch (Exception e) { - log.error("While handling the following error...", error); - log.error("...the following error happened", e); - } + execute(new Runnable() { + @Override + public void run() { + try { + errorHandler.accept(error); + } catch (Exception e) { + log.error("While handling the following error...", error); + log.error("...the following error happened", e); + } finally { + shutdown(); // Terminal event, we can shutdown the executor + } + } + }); } else { log.warn("Ignored error", error); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java index fa8939938..57a9b8276 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java @@ -39,7 +39,8 @@ public class QuarkusAiServiceTokenStream implements TokenStream { private final AiServiceContext context; private final Object memoryId; private final Context cxtx; - private final boolean mustSwitchToWorkerThread; + private final boolean switchToWorkerThreadForToolExecution; + private final boolean switchToWorkerForEmission; private Consumer tokenHandler; private Consumer> contentsHandler; @@ -59,7 +60,8 @@ public QuarkusAiServiceTokenStream(List messages, Map toolExecutors, List retrievedContents, AiServiceContext context, - Object memoryId, Context 
ctxt, boolean mustSwitchToWorkerThread) { + Object memoryId, Context ctxt, boolean switchToWorkerThreadForToolExecution, + boolean switchToWorkerForEmission) { this.messages = ensureNotEmpty(messages, "messages"); this.toolSpecifications = copyIfNotNull(toolSpecifications); this.toolExecutors = copyIfNotNull(toolExecutors); @@ -68,7 +70,8 @@ public QuarkusAiServiceTokenStream(List messages, this.memoryId = ensureNotNull(memoryId, "memoryId"); ensureNotNull(context.streamingChatModel, "streamingChatModel"); this.cxtx = ctxt; // If set, it means we need to handle the context propagation. - this.mustSwitchToWorkerThread = mustSwitchToWorkerThread; // If true, we need to switch to a worker thread to execute tools. + this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution; // If true, we need to switch to a worker thread to execute tools. + this.switchToWorkerForEmission = switchToWorkerForEmission; } @Override @@ -127,7 +130,8 @@ public void start() { new TokenUsage(), toolSpecifications, toolExecutors, - mustSwitchToWorkerThread, + switchToWorkerThreadForToolExecution, + switchToWorkerForEmission, cxtx); if (contentsHandler != null && retrievedContents != null) {