-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix Blocking Memory Store Usage in Streamed Mode
This commit addresses issues with using the blocking memory store in streamed responses. * Ensures the execution captures whether the caller is running on a worker thread. * Switches to worker threads for every emission and completion event when the caller is using a worker thread. * Relies on executeBlocking to propagate the context automatically when possible. Note: * The blocking memory store cannot be used when invoked on the event loop. It now requires that the caller must be on a worker thread.
- Loading branch information
1 parent
5ec8e05
commit 0690a53
Showing
7 changed files
with
379 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
117 changes: 117 additions & 0 deletions
117
.../io/quarkiverse/langchain4j/test/streaming/BlockingMemoryStoreOnStreamedResponseTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package io.quarkiverse.langchain4j.test.streaming; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
import java.util.List; | ||
import java.util.UUID; | ||
import java.util.concurrent.CountDownLatch; | ||
import java.util.concurrent.TimeUnit; | ||
|
||
import jakarta.enterprise.context.control.ActivateRequestContext; | ||
import jakarta.inject.Inject; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.RepeatedTest; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.service.MemoryId; | ||
import dev.langchain4j.service.UserMessage; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkus.arc.Arc; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
import io.smallrye.common.vertx.VertxContext; | ||
import io.smallrye.mutiny.Multi; | ||
import io.vertx.core.Context; | ||
import io.vertx.core.Vertx; | ||
|
||
public class BlockingMemoryStoreOnStreamedResponseTest { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) | ||
.addClasses(StreamTestUtils.class)); | ||
|
||
@Inject | ||
MyAiService service; | ||
|
||
@RepeatedTest(100) // Verify that the order is preserved. | ||
@ActivateRequestContext | ||
void testFromWorkerThread() { | ||
// We are on a worker thread. | ||
List<String> list = service.hi("123", "Say hello").collect().asList().await().indefinitely(); | ||
// We cannot guarantee the order, as we do not have a context. | ||
assertThat(list).containsExactly("Hi!", " ", "World!"); | ||
|
||
list = service.hi("123", "Second message").collect().asList().await().indefinitely(); | ||
assertThat(list).containsExactly("OK!"); | ||
} | ||
|
||
@BeforeEach | ||
void cleanup() { | ||
StreamTestUtils.FakeMemoryStore.DC_DATA = null; | ||
} | ||
|
||
@RepeatedTest(10) | ||
void testFromDuplicatedContextThread() throws InterruptedException { | ||
Context context = VertxContext.getOrCreateDuplicatedContext(vertx); | ||
CountDownLatch latch = new CountDownLatch(1); | ||
context.executeBlocking(v -> { | ||
try { | ||
Arc.container().requestContext().activate(); | ||
var value = UUID.randomUUID().toString(); | ||
StreamTestUtils.FakeMemoryStore.DC_DATA = value; | ||
Vertx.currentContext().putLocal("DC_DATA", value); | ||
List<String> list = service.hi("123", "Say hello").collect().asList().await().indefinitely(); | ||
assertThat(list).containsExactly("Hi!", " ", "World!"); | ||
Arc.container().requestContext().deactivate(); | ||
|
||
Arc.container().requestContext().activate(); | ||
|
||
list = service.hi("123", "Second message").collect().asList().await().indefinitely(); | ||
assertThat(list).containsExactly("OK!"); | ||
latch.countDown(); | ||
|
||
} finally { | ||
Arc.container().requestContext().deactivate(); | ||
Vertx.currentContext().removeLocal("DC_DATA"); | ||
|
||
} | ||
}, false); | ||
assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); | ||
} | ||
|
||
@Inject | ||
Vertx vertx; | ||
|
||
@Test | ||
void testFromEventLoopThread() throws InterruptedException { | ||
CountDownLatch latch = new CountDownLatch(1); | ||
Context context = vertx.getOrCreateContext(); | ||
context.runOnContext(v -> { | ||
Arc.container().requestContext().activate(); | ||
try { | ||
service.hi("123", "Say hello").collect().asList() | ||
.subscribe().asCompletionStage(); | ||
} catch (Exception e) { | ||
assertThat(e) | ||
.isNotNull() | ||
.hasMessageContaining("Expected to be able to block"); | ||
} finally { | ||
Arc.container().requestContext().deactivate(); | ||
latch.countDown(); | ||
} | ||
}); | ||
latch.await(); | ||
} | ||
|
||
@RegisterAiService(streamingChatLanguageModelSupplier = StreamTestUtils.FakeStreamedChatModelSupplier.class, chatMemoryProviderSupplier = StreamTestUtils.FakeMemoryProviderSupplier.class) | ||
public interface MyAiService { | ||
|
||
Multi<String> hi(@MemoryId String id, @UserMessage String query); | ||
|
||
} | ||
|
||
} |
126 changes: 126 additions & 0 deletions
126
core/deployment/src/test/java/io/quarkiverse/langchain4j/test/streaming/StreamTestUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package io.quarkiverse.langchain4j.test.streaming; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
import java.util.concurrent.CopyOnWriteArrayList; | ||
import java.util.function.Supplier; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.memory.ChatMemory; | ||
import dev.langchain4j.memory.chat.ChatMemoryProvider; | ||
import dev.langchain4j.memory.chat.MessageWindowChatMemory; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import dev.langchain4j.store.memory.chat.ChatMemoryStore; | ||
import io.quarkus.arc.Arc; | ||
import io.smallrye.mutiny.infrastructure.Infrastructure; | ||
import io.vertx.core.Vertx; | ||
|
||
/** | ||
* Utility class for streaming tests. | ||
*/ | ||
public class StreamTestUtils { | ||
|
||
public static class FakeMemoryProviderSupplier implements Supplier<ChatMemoryProvider> { | ||
@Override | ||
public ChatMemoryProvider get() { | ||
return new ChatMemoryProvider() { | ||
@Override | ||
public ChatMemory get(Object memoryId) { | ||
return new MessageWindowChatMemory.Builder() | ||
.id(memoryId) | ||
.maxMessages(10) | ||
.chatMemoryStore(new FakeMemoryStore()) | ||
.build(); | ||
} | ||
}; | ||
} | ||
} | ||
|
||
public static class FakeStreamedChatModelSupplier implements Supplier<StreamingChatLanguageModel> { | ||
|
||
@Override | ||
public StreamingChatLanguageModel get() { | ||
return new FakeStreamedChatModel(); | ||
} | ||
} | ||
|
||
public static class FakeStreamedChatModel implements StreamingChatLanguageModel { | ||
|
||
@Override | ||
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) { | ||
Vertx vertx = Arc.container().select(Vertx.class).get(); | ||
var ctxt = vertx.getOrCreateContext(); | ||
|
||
if (messages.size() > 1) { | ||
var last = (UserMessage) messages.get(messages.size() - 1); | ||
if (last.singleText().equalsIgnoreCase("Second message")) { | ||
if (messages.size() < 3) { | ||
ctxt.runOnContext(x -> handler.onError(new IllegalStateException("Error - no memory"))); | ||
return; | ||
} else { | ||
ctxt.runOnContext(x -> { | ||
handler.onNext("OK!"); | ||
handler.onComplete(Response.from(AiMessage.from(""))); | ||
}); | ||
return; | ||
} | ||
} | ||
} | ||
|
||
ctxt.runOnContext(x1 -> { | ||
handler.onNext("Hi!"); | ||
ctxt.runOnContext(x2 -> { | ||
handler.onNext(" "); | ||
ctxt.runOnContext(x3 -> { | ||
handler.onNext("World!"); | ||
ctxt.runOnContext(x -> handler.onComplete(Response.from(AiMessage.from("")))); | ||
}); | ||
}); | ||
}); | ||
} | ||
} | ||
|
||
public static class FakeMemoryStore implements ChatMemoryStore { | ||
|
||
public static String DC_DATA; | ||
|
||
private static final Map<Object, List<ChatMessage>> memories = new ConcurrentHashMap<>(); | ||
|
||
private void checkDuplicatedContext() { | ||
if (DC_DATA != null) { | ||
if (!DC_DATA.equals(Vertx.currentContext().getLocal("DC_DATA"))) { | ||
throw new AssertionError("Expected to be in the same context"); | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public List<ChatMessage> getMessages(Object memoryId) { | ||
if (!Infrastructure.canCallerThreadBeBlocked()) { | ||
throw new AssertionError("Expected to be able to block"); | ||
} | ||
checkDuplicatedContext(); | ||
return memories.computeIfAbsent(memoryId, x -> new CopyOnWriteArrayList<>()); | ||
} | ||
|
||
@Override | ||
public void updateMessages(Object memoryId, List<ChatMessage> messages) { | ||
if (!Infrastructure.canCallerThreadBeBlocked()) { | ||
throw new AssertionError("Expected to be able to block"); | ||
} | ||
memories.put(memoryId, messages); | ||
} | ||
|
||
@Override | ||
public void deleteMessages(Object memoryId) { | ||
if (!Infrastructure.canCallerThreadBeBlocked()) { | ||
throw new AssertionError("Expected to be able to block"); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.