Skip to content

Commit

Permalink
Merge pull request #1161 from andreadimaio/main
Browse files Browse the repository at this point in the history
Make sure to avoid generating the schema if the response-schema property is set to false
  • Loading branch information
geoand authored Dec 14, 2024
2 parents 37bd953 + 69979be commit 6ec8ea5
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodC
audit.initialMessages(systemMessage, userMessage);
}

//TODO: does it make sense to use the retrievalAugmentor here? What good would be for us telling the LLM to use this or that information to create an image?
// TODO: does it make sense to use the retrievalAugmentor here? What good would be for us telling the LLM to use this or that information to create an image?
AugmentationResult augmentationResult = null;

// TODO: we can only support input guardrails for now as it is tied to AiMessage
Expand Down Expand Up @@ -644,14 +644,16 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
.formatted(ResponseSchemaUtil.placeholder(), createInfo.getInterfaceName()));
}

// No response schema placeholder found in the @SystemMessage and @UserMessage, concat it to the UserMessage.
if (!createInfo.getResponseSchemaInfo().isInSystemMessage() && !hasResponseSchema && !supportsJsonSchema) {
templateText = templateText.concat(ResponseSchemaUtil.placeholder());
if (createInfo.getResponseSchemaInfo().enabled()) {
// No response schema placeholder found in the @SystemMessage and @UserMessage, concat it to the UserMessage.
if (!createInfo.getResponseSchemaInfo().isInSystemMessage() && !hasResponseSchema && !supportsJsonSchema) {
templateText = templateText.concat(ResponseSchemaUtil.placeholder());
}

templateVariables.put(ResponseSchemaUtil.templateParam(),
createInfo.getResponseSchemaInfo().outputFormatInstructions());
}

// we do not need to apply the instructions as they have already been added to the template text at build time
templateVariables.put(ResponseSchemaUtil.templateParam(),
createInfo.getResponseSchemaInfo().outputFormatInstructions());
Prompt prompt = PromptTemplate.from(templateText).apply(templateVariables);
return createUserMessage(userName, imageContent, prompt.text());

Expand All @@ -667,7 +669,8 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic

String text = toString(argValue);
return createUserMessage(userName, imageContent,
text.concat(supportsJsonSchema ? "" : createInfo.getResponseSchemaInfo().outputFormatInstructions()));
text.concat(supportsJsonSchema || !createInfo.getResponseSchemaInfo().enabled() ? ""
: createInfo.getResponseSchemaInfo().outputFormatInstructions()));
} else {
throw new IllegalStateException("Unable to construct UserMessage for class '" + context.aiServiceClass.getName()
+ "'. Please contact the maintainers");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package io.quarkiverse.langchain4j.watsonx.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.Date;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

Expand All @@ -11,38 +15,112 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.test.QuarkusUnitTest;

public class ResponseSchemaOffTest {
public class ResponseSchemaOffTest extends WireMockAbstract {

@RegisterExtension
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER)
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER)
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY)
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID)
.overrideConfigKey("quarkus.langchain4j.watsonx.mode", "generation")
.overrideConfigKey("quarkus.langchain4j.response-schema", "false")
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));

@RegisterAiService
@Override
void handlerBeforeEach() {
mockServers.mockIAMBuilder(200)
.grantType(langchain4jWatsonConfig.defaultConfig().iam().grantType())
.response(WireMockUtil.BEARER_TOKEN, new Date())
.build();
}

@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
@Singleton
interface OnMethodAIService {
String poem(@UserMessage String message, @V("topic") String topic);
String poem1(@UserMessage String message, @V("topic") String topic);

Poem poem2(@UserMessage String message);

@UserMessage("{message}")
Poem poem3(String message);

@SystemMessage("SystemMessage")
@UserMessage("{message}")
Poem poem4(String message);

public record Poem(String text) {
};
}

@Inject
OnMethodAIService onMethodAIService;

static String POEM_RESPONSE = """
{
"model_id": "mistralai/mistral-large",
"created_at": "2024-01-21T17:06:14.052Z",
"results": [
{
"generated_text": "{ \\\"text\\\": \\\"Poem\\\" }",
"generated_token_count": 5,
"input_token_count": 50,
"stop_reason": "eos_token",
"seed": 2123876088
}
]
}
""";

@Test
void on_method_ai_service() throws Exception {
void test_poem_1() throws Exception {
var ex = assertThrows(RuntimeException.class,
() -> onMethodAIService.poem("{response_schema} Generate a poem about {topic}", "dog"));
() -> onMethodAIService.poem1("{response_schema} Generate a poem about {topic}", "dog"));
assertEquals(
"The {response_schema} placeholder cannot be used if the property quarkus.langchain4j.response-schema is set to false. Found in: io.quarkiverse.langchain4j.watsonx.deployment.ResponseSchemaOffTest$OnMethodAIService",
ex.getMessage());

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
.body(matchingJsonPath("$.input", equalTo("Generate a poem about dog")))
.response(WireMockUtil.RESPONSE_WATSONX_GENERATION_API)
.build();

assertEquals("AI Response", onMethodAIService.poem1("Generate a poem about {topic}", "dog"));
}

@Test
void test_poem_2() {
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
.body(matchingJsonPath("$.input", equalTo("Generate a poem about dog")))
.response(POEM_RESPONSE)
.build();

assertEquals("Poem", onMethodAIService.poem2("Generate a poem about dog").text);
}

@Test
void test_poem_3() {
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
.body(matchingJsonPath("$.input", equalTo("Generate a poem about dog")))
.response(POEM_RESPONSE)
.build();

assertEquals("Poem", onMethodAIService.poem3("Generate a poem about dog").text);
}

@Test
void test_poem_4() {
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
.body(matchingJsonPath("$.input", equalTo("SystemMessage\nGenerate a poem about dog")))
.response(POEM_RESPONSE)
.build();

assertEquals("Poem", onMethodAIService.poem4("Generate a poem about dog").text);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.MappingBuilder;
import com.github.tomakehurst.wiremock.matching.StringValuePattern;
import com.github.tomakehurst.wiremock.stubbing.StubMapping;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
Expand Down Expand Up @@ -282,6 +284,11 @@ public WatsonxBuilder body(String body) {
return this;
}

public WatsonxBuilder body(StringValuePattern stringValuePattern) {
builder.withRequestBody(stringValuePattern);
return this;
}

public WatsonxBuilder token(String token) {
this.token = token;
return this;
Expand All @@ -297,8 +304,8 @@ public WatsonxBuilder response(String response) {
return this;
}

public void build() {
watsonServer.stubFor(
public StubMapping build() {
return watsonServer.stubFor(
builder
.withHeader("Authorization", equalTo("Bearer %s".formatted(token)))
.willReturn(aResponse()
Expand Down

0 comments on commit 6ec8ea5

Please sign in to comment.