Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure to avoid generating the schema if the response-schema property is set to false #1161

Merged
merged 1 commit into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodC
audit.initialMessages(systemMessage, userMessage);
}

//TODO: does it make sense to use the retrievalAugmentor here? What good would be for us telling the LLM to use this or that information to create an image?
// TODO: does it make sense to use the retrievalAugmentor here? What good would it do to tell the LLM to use this or that information when creating an image?
AugmentationResult augmentationResult = null;

// TODO: we can only support input guardrails for now as it is tied to AiMessage
Expand Down Expand Up @@ -644,14 +644,16 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
.formatted(ResponseSchemaUtil.placeholder(), createInfo.getInterfaceName()));
}

// No response schema placeholder found in the @SystemMessage and @UserMessage, concat it to the UserMessage.
if (!createInfo.getResponseSchemaInfo().isInSystemMessage() && !hasResponseSchema && !supportsJsonSchema) {
templateText = templateText.concat(ResponseSchemaUtil.placeholder());
if (createInfo.getResponseSchemaInfo().enabled()) {
// No response schema placeholder found in the @SystemMessage and @UserMessage, concat it to the UserMessage.
if (!createInfo.getResponseSchemaInfo().isInSystemMessage() && !hasResponseSchema && !supportsJsonSchema) {
templateText = templateText.concat(ResponseSchemaUtil.placeholder());
}

templateVariables.put(ResponseSchemaUtil.templateParam(),
createInfo.getResponseSchemaInfo().outputFormatInstructions());
}

// we do not need to apply the instructions as they have already been added to the template text at build time
templateVariables.put(ResponseSchemaUtil.templateParam(),
createInfo.getResponseSchemaInfo().outputFormatInstructions());
Prompt prompt = PromptTemplate.from(templateText).apply(templateVariables);
return createUserMessage(userName, imageContent, prompt.text());

Expand All @@ -667,7 +669,8 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic

String text = toString(argValue);
return createUserMessage(userName, imageContent,
text.concat(supportsJsonSchema ? "" : createInfo.getResponseSchemaInfo().outputFormatInstructions()));
text.concat(supportsJsonSchema || !createInfo.getResponseSchemaInfo().enabled() ? ""
: createInfo.getResponseSchemaInfo().outputFormatInstructions()));
} else {
throw new IllegalStateException("Unable to construct UserMessage for class '" + context.aiServiceClass.getName()
+ "'. Please contact the maintainers");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package io.quarkiverse.langchain4j.watsonx.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.Date;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

Expand All @@ -11,38 +15,112 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.test.QuarkusUnitTest;

public class ResponseSchemaOffTest {
public class ResponseSchemaOffTest extends WireMockAbstract {

@RegisterExtension
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER)
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER)
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY)
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID)
.overrideConfigKey("quarkus.langchain4j.watsonx.mode", "generation")
.overrideConfigKey("quarkus.langchain4j.response-schema", "false")
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));

@RegisterAiService
/**
 * Setup hook overridden from the parent class; registers the IAM stub the tests in this class rely on.
 * NOTE(review): presumably invoked before every test by the parent class — confirm there.
 */
@Override
void handlerBeforeEach() {
// NOTE(review): 200 is presumably the HTTP status returned by the stubbed IAM endpoint — confirm against mockIAMBuilder.
mockServers.mockIAMBuilder(200)
// Grant type is read from the default watsonx configuration rather than hard-coded.
.grantType(langchain4jWatsonConfig.defaultConfig().iam().grantType())
// NOTE(review): the Date argument is presumably the token expiration — confirm against the builder.
.response(WireMockUtil.BEARER_TOKEN, new Date())
.build();
}

@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
@Singleton
interface OnMethodAIService {
String poem(@UserMessage String message, @V("topic") String topic);
String poem1(@UserMessage String message, @V("topic") String topic);

Poem poem2(@UserMessage String message);

@UserMessage("{message}")
Poem poem3(String message);

@SystemMessage("SystemMessage")
@UserMessage("{message}")
Poem poem4(String message);

public record Poem(String text) {
};
}

@Inject
OnMethodAIService onMethodAIService;

static String POEM_RESPONSE = """
{
"model_id": "mistralai/mistral-large",
"created_at": "2024-01-21T17:06:14.052Z",
"results": [
{
"generated_text": "{ \\\"text\\\": \\\"Poem\\\" }",
"generated_token_count": 5,
"input_token_count": 50,
"stop_reason": "eos_token",
"seed": 2123876088
}
]
}
""";

@Test
void on_method_ai_service() throws Exception {
void test_poem_1() throws Exception {
var ex = assertThrows(RuntimeException.class,
() -> onMethodAIService.poem("{response_schema} Generate a poem about {topic}", "dog"));
() -> onMethodAIService.poem1("{response_schema} Generate a poem about {topic}", "dog"));
assertEquals(
"The {response_schema} placeholder cannot be used if the property quarkus.langchain4j.response-schema is set to false. Found in: io.quarkiverse.langchain4j.watsonx.deployment.ResponseSchemaOffTest$OnMethodAIService",
ex.getMessage());

mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
.body(matchingJsonPath("$.input", equalTo("Generate a poem about dog")))
.response(WireMockUtil.RESPONSE_WATSONX_GENERATION_API)
.build();

assertEquals("AI Response", onMethodAIService.poem1("Generate a poem about {topic}", "dog"));
}

/**
 * Calls {@code poem2}, whose user message is passed directly as the annotated argument, while
 * {@code quarkus.langchain4j.response-schema} is overridden to {@code false}.
 * The stub only matches when {@code $.input} equals the user text exactly, so the call can only
 * succeed if nothing else was appended to the prompt.
 */
@Test
void test_poem_2() {
    mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
            .body(matchingJsonPath("$.input", equalTo("Generate a poem about dog")))
            .response(POEM_RESPONSE)
            .build();

    // Read the record component through its public accessor instead of the private field.
    assertEquals("Poem", onMethodAIService.poem2("Generate a poem about dog").text());
}

/**
 * Calls {@code poem3}, whose user message comes from the {@code "{message}"} template, while
 * {@code quarkus.langchain4j.response-schema} is overridden to {@code false}.
 * The stub only matches when {@code $.input} equals the rendered user text exactly, so the call
 * can only succeed if nothing else was appended to the template.
 */
@Test
void test_poem_3() {
    mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
            .body(matchingJsonPath("$.input", equalTo("Generate a poem about dog")))
            .response(POEM_RESPONSE)
            .build();

    // Read the record component through its public accessor instead of the private field.
    assertEquals("Poem", onMethodAIService.poem3("Generate a poem about dog").text());
}

/**
 * Calls {@code poem4}, which declares both a system message and a templated user message, while
 * {@code quarkus.langchain4j.response-schema} is overridden to {@code false}.
 * The stub only matches when {@code $.input} is the system message, a newline and the user text,
 * so the call can only succeed if nothing else was added to either message.
 */
@Test
void test_poem_4() {
    mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
            .body(matchingJsonPath("$.input", equalTo("SystemMessage\nGenerate a poem about dog")))
            .response(POEM_RESPONSE)
            .build();

    // Read the record component through its public accessor instead of the private field.
    assertEquals("Poem", onMethodAIService.poem4("Generate a poem about dog").text());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.MappingBuilder;
import com.github.tomakehurst.wiremock.matching.StringValuePattern;
import com.github.tomakehurst.wiremock.stubbing.StubMapping;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
Expand Down Expand Up @@ -282,6 +284,11 @@ public WatsonxBuilder body(String body) {
return this;
}

/**
 * Registers a request-body pattern on the stub being built, as an alternative to the
 * {@code body(String)} overload.
 *
 * @param stringValuePattern WireMock pattern passed straight to {@code builder.withRequestBody}
 * @return this builder, so calls can be chained
 */
public WatsonxBuilder body(StringValuePattern stringValuePattern) {
builder.withRequestBody(stringValuePattern);
return this;
}

public WatsonxBuilder token(String token) {
this.token = token;
return this;
Expand All @@ -297,8 +304,8 @@ public WatsonxBuilder response(String response) {
return this;
}

public void build() {
watsonServer.stubFor(
public StubMapping build() {
return watsonServer.stubFor(
builder
.withHeader("Authorization", equalTo("Bearer %s".formatted(token)))
.willReturn(aResponse()
Expand Down
Loading