From 69979becd2af20d5e843fe71370747fd59c32ac8 Mon Sep 17 00:00:00 2001 From: Andrea Di Maio Date: Sat, 14 Dec 2024 12:40:02 +0100 Subject: [PATCH] Make sure to avoid generating the schema if the quarkus.langchain4j.response-schema property is set to false --- .../AiServiceMethodImplementationSupport.java | 19 ++-- .../deployment/ResponseSchemaOffTest.java | 88 +++++++++++++++++-- .../watsonx/deployment/WireMockUtil.java | 11 ++- 3 files changed, 103 insertions(+), 15 deletions(-) diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index f66ce91c1..4c5836466 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -464,7 +464,7 @@ private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodC audit.initialMessages(systemMessage, userMessage); } - //TODO: does it make sense to use the retrievalAugmentor here? What good would be for us telling the LLM to use this or that information to create an image? + // TODO: does it make sense to use the retrievalAugmentor here? What good would be for us telling the LLM to use this or that information to create an image? AugmentationResult augmentationResult = null; // TODO: we can only support input guardrails for now as it is tied to AiMessage @@ -644,14 +644,16 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic .formatted(ResponseSchemaUtil.placeholder(), createInfo.getInterfaceName())); } - // No response schema placeholder found in the @SystemMessage and @UserMessage, concat it to the UserMessage. - if (!createInfo.getResponseSchemaInfo().isInSystemMessage() && !hasResponseSchema && !supportsJsonSchema) { - templateText = templateText.concat(ResponseSchemaUtil.placeholder()); + if (createInfo.getResponseSchemaInfo().enabled()) { + // No response schema placeholder found in the @SystemMessage and @UserMessage, concat it to the UserMessage. + if (!createInfo.getResponseSchemaInfo().isInSystemMessage() && !hasResponseSchema && !supportsJsonSchema) { + templateText = templateText.concat(ResponseSchemaUtil.placeholder()); + } + + templateVariables.put(ResponseSchemaUtil.templateParam(), + createInfo.getResponseSchemaInfo().outputFormatInstructions()); } - // we do not need to apply the instructions as they have already been added to the template text at build time - templateVariables.put(ResponseSchemaUtil.templateParam(), - createInfo.getResponseSchemaInfo().outputFormatInstructions()); Prompt prompt = PromptTemplate.from(templateText).apply(templateVariables); return createUserMessage(userName, imageContent, prompt.text()); @@ -667,7 +669,8 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic String text = toString(argValue); return createUserMessage(userName, imageContent, - text.concat(supportsJsonSchema ? "" : createInfo.getResponseSchemaInfo().outputFormatInstructions())); + text.concat(supportsJsonSchema || !createInfo.getResponseSchemaInfo().enabled() ? "" + : createInfo.getResponseSchemaInfo().outputFormatInstructions())); } else { throw new IllegalStateException("Unable to construct UserMessage for class '" + context.aiServiceClass.getName() + "'. Please contact the maintainers"); diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOffTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOffTest.java index 7e9e11aa3..280d7e009 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOffTest.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ResponseSchemaOffTest.java @@ -1,8 +1,12 @@ package io.quarkiverse.langchain4j.watsonx.deployment; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import java.util.Date; + import jakarta.inject.Inject; import jakarta.inject.Singleton; @@ -11,12 +15,13 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import dev.langchain4j.service.SystemMessage; import dev.langchain4j.service.UserMessage; import dev.langchain4j.service.V; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkus.test.QuarkusUnitTest; -public class ResponseSchemaOffTest { +public class ResponseSchemaOffTest extends WireMockAbstract { @RegisterExtension static QuarkusUnitTest unitTest = new QuarkusUnitTest() @@ -24,25 +29,98 @@ public class ResponseSchemaOffTest { .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY) .overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID) + .overrideConfigKey("quarkus.langchain4j.watsonx.mode", "generation") .overrideConfigKey("quarkus.langchain4j.response-schema", "false") .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); - @RegisterAiService + @Override + void handlerBeforeEach() { + mockServers.mockIAMBuilder(200) + .grantType(langchain4jWatsonConfig.defaultConfig().iam().grantType()) + .response(WireMockUtil.BEARER_TOKEN, new Date()) + .build(); + } + + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) @Singleton interface OnMethodAIService { - String poem(@UserMessage String message, @V("topic") String topic); + String poem1(@UserMessage String message, @V("topic") String topic); + + Poem poem2(@UserMessage String message); + + @UserMessage("{message}") + Poem poem3(String message); + + @SystemMessage("SystemMessage") + @UserMessage("{message}") + Poem poem4(String message); + + public record Poem(String text) { + }; } @Inject OnMethodAIService onMethodAIService; + static String POEM_RESPONSE = """ + { + "model_id": "mistralai/mistral-large", + "created_at": "2024-01-21T17:06:14.052Z", + "results": [ + { + "generated_text": "{ \\\"text\\\": \\\"Poem\\\" }", + "generated_token_count": 5, + "input_token_count": 50, + "stop_reason": "eos_token", + "seed": 2123876088 + } + ] + } + """; + @Test - void on_method_ai_service() throws Exception { + void test_poem_1() throws Exception { var ex = assertThrows(RuntimeException.class, - () -> onMethodAIService.poem("{response_schema} Generate a poem about {topic}", "dog")); + () -> onMethodAIService.poem1("{response_schema} Generate a poem about {topic}", "dog")); assertEquals( "The {response_schema} placeholder cannot be used if the property quarkus.langchain4j.response-schema is set to false. Found in: io.quarkiverse.langchain4j.watsonx.deployment.ResponseSchemaOffTest$OnMethodAIService", ex.getMessage()); + + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) + .body(matchingJsonPath("$.input", equalTo("Generate a poem about dog"))) + .response(WireMockUtil.RESPONSE_WATSONX_GENERATION_API) + .build(); + + assertEquals("AI Response", onMethodAIService.poem1("Generate a poem about {topic}", "dog")); } + @Test + void test_poem_2() { + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) + .body(matchingJsonPath("$.input", equalTo("Generate a poem about dog"))) + .response(POEM_RESPONSE) + .build(); + + assertEquals("Poem", onMethodAIService.poem2("Generate a poem about dog").text); + } + + @Test + void test_poem_3() { + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) + .body(matchingJsonPath("$.input", equalTo("Generate a poem about dog"))) + .response(POEM_RESPONSE) + .build(); + + assertEquals("Poem", onMethodAIService.poem3("Generate a poem about dog").text); + } + + @Test + void test_poem_4() { + mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200) + .body(matchingJsonPath("$.input", equalTo("SystemMessage\nGenerate a poem about dog"))) + .response(POEM_RESPONSE) + .build(); + + assertEquals("Poem", onMethodAIService.poem4("Generate a poem about dog").text); + } } diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/WireMockUtil.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/WireMockUtil.java index 477b5ea1e..32cfb1f52 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/WireMockUtil.java +++ b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/WireMockUtil.java @@ -15,6 +15,8 @@ import com.github.tomakehurst.wiremock.WireMockServer; import com.github.tomakehurst.wiremock.client.MappingBuilder; +import com.github.tomakehurst.wiremock.matching.StringValuePattern; +import com.github.tomakehurst.wiremock.stubbing.StubMapping; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.StreamingResponseHandler; @@ -282,6 +284,11 @@ public WatsonxBuilder body(String body) { return this; } + public WatsonxBuilder body(StringValuePattern stringValuePattern) { + builder.withRequestBody(stringValuePattern); + return this; + } + public WatsonxBuilder token(String token) { this.token = token; return this; @@ -297,8 +304,8 @@ public WatsonxBuilder response(String response) { return this; } - public void build() { - watsonServer.stubFor( + public StubMapping build() { + return watsonServer.stubFor( builder .withHeader("Authorization", equalTo("Bearer %s".formatted(token))) .willReturn(aResponse()