From 148c6d525e74bedd6425523e145187718ab22fd4 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Thu, 12 Dec 2024 18:40:28 +0200 Subject: [PATCH] Allow Rest Client and AI Service to be used as tools --- core/deployment/pom.xml | 5 + .../deployment/AiServicesProcessor.java | 62 ++++- .../DeclarativeAiServiceBuildItem.java | 10 +- .../langchain4j/deployment/DotNames.java | 3 + .../langchain4j/deployment/ToolProcessor.java | 26 ++- .../devui/LangChain4jDevUIProcessor.java | 2 +- .../items/ToolQualifierProvider.java | 30 +++ .../runtime/AiServicesRecorder.java | 18 +- .../DeclarativeAiServiceCreateInfo.java | 6 +- .../openai/openai-vanilla/deployment/pom.xml | 5 + .../aiservices/RestClientToolTest.java | 219 ++++++++++++++++++ 11 files changed, 360 insertions(+), 26 deletions(-) create mode 100644 core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolQualifierProvider.java create mode 100644 model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RestClientToolTest.java diff --git a/core/deployment/pom.xml b/core/deployment/pom.xml index ab47f8411..4333b498d 100644 --- a/core/deployment/pom.xml +++ b/core/deployment/pom.xml @@ -46,6 +46,11 @@ true + + org.eclipse.microprofile.rest.client + microprofile-rest-client-api + + io.quarkus quarkus-vertx-http-dev-ui-tests diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index b8fed660a..f595405a8 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -43,8 +43,10 @@ import jakarta.annotation.PreDestroy; import jakarta.enterprise.context.Dependent; import jakarta.enterprise.inject.spi.DeploymentException; +import jakarta.enterprise.util.AnnotationLiteral; import jakarta.inject.Inject; +import org.eclipse.microprofile.rest.client.inject.RestClient; import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.AnnotationTarget; import org.jboss.jandex.AnnotationValue; @@ -77,6 +79,7 @@ import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem; +import io.quarkiverse.langchain4j.deployment.items.ToolQualifierProvider; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; import io.quarkiverse.langchain4j.runtime.AiServicesRecorder; @@ -262,11 +265,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, chatModelNames.add(chatModelName); } - List toolDotNames = Collections.emptyList(); + List toolClassInfos = Collections.emptyList(); AnnotationValue toolsInstance = instance.value("tools"); if (toolsInstance != null) { - toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name) - .collect(Collectors.toList()); + toolClassInfos = Arrays.stream(toolsInstance.asClassArray()).map(t -> { + var ci = index.getClassByName(t.name()); + if (ci == null) { + throw new IllegalArgumentException("Cannot find class " + t.name() + + " in index. Please make sure it's a valid CDI bean known to Quarkus"); + } + return ci; + }) + .toList(); } // the default value depends on whether tools exists or not - if they do, then we require a ChatMemoryProvider bean @@ -397,7 +407,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, declarativeAiServiceClassInfo, chatLanguageModelSupplierClassDotName, streamingChatLanguageModelSupplierClassDotName, - toolDotNames, + toolClassInfos, chatMemoryProviderSupplierClassDotName, retrieverClassDotName, retrievalAugmentorSupplierClassName, @@ -476,11 +486,27 @@ private boolean isImageOrImageResultResult(Type returnType) { return false; } + @BuildStep + public void toolQualifiers(BuildProducer producer) { + producer.produce(new ToolQualifierProvider.BuildItem(new ToolQualifierProvider() { + @Override + public boolean supports(ClassInfo classInfo) { + return classInfo.hasAnnotation(DotNames.REGISTER_REST_CLIENT); + } + + @Override + public AnnotationLiteral qualifier(ClassInfo classInfo) { + return new RestClient.RestClientLiteral(); + } + })); + } + @BuildStep @Record(ExecutionTime.STATIC_INIT) public void handleDeclarativeServices(AiServicesRecorder recorder, List declarativeAiServiceItems, List selectedChatModelProvider, + List toolQualifierProviderItems, BuildProducer syntheticBeanProducer, BuildProducer unremovableProducer) { @@ -507,7 +533,19 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, ? bi.getStreamingChatLanguageModelSupplierClassDotName().toString() : null); - List toolClassNames = bi.getToolDotNames().stream().map(DotName::toString).collect(Collectors.toList()); + List toolQualifierProviders = toolQualifierProviderItems.stream().map( + ToolQualifierProvider.BuildItem::getProvider).toList(); + Map> toolToQualifierMap = new HashMap<>(); + for (ClassInfo ci : bi.getToolClassInfos()) { + AnnotationLiteral qualifier = null; + for (ToolQualifierProvider provider : toolQualifierProviders) { + if (provider.supports(ci)) { + qualifier = provider.qualifier(ci); + break; + } + } + toolToQualifierMap.put(ci.name().toString(), qualifier); + } String toolProviderSupplierClassName = (bi.getToolProviderClassDotName() != null ? bi.getToolProviderClassDotName().toString() @@ -597,7 +635,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, serviceClassName, chatLanguageModelSupplierClassName, streamingChatLanguageModelSupplierClassName, - toolClassNames, + toolToQualifierMap, toolProviderSupplierClassName, chatMemoryProviderSupplierClassName, retrieverClassName, retrievalAugmentorSupplierClassName, @@ -639,12 +677,16 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, needsChatModelBean = true; } - if (!toolClassNames.isEmpty()) { - for (String toolClassName : toolClassNames) { - DotName dotName = DotName.createSimple(toolClassName); + for (var entry : toolToQualifierMap.entrySet()) { + DotName dotName = DotName.createSimple(entry.getKey()); + AnnotationLiteral qualifier = entry.getValue(); + if (qualifier == null) { configurator.addInjectionPoint(ClassType.create(dotName)); - allToolNames.add(dotName); + } else { + configurator.addInjectionPoint(ClassType.create(dotName), + AnnotationInstance.builder(qualifier.annotationType()).build()); } + allToolNames.add(dotName); } if (LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) { diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java index 6fd1fc997..03449e2f4 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java @@ -16,7 +16,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { private final ClassInfo serviceClassInfo; private final DotName chatLanguageModelSupplierClassDotName; private final DotName streamingChatLanguageModelSupplierClassDotName; - private final List toolDotNames; + private final List toolClassInfos; private final DotName toolProviderClassDotName; private final DotName chatMemoryProviderSupplierClassDotName; @@ -37,7 +37,7 @@ public DeclarativeAiServiceBuildItem( ClassInfo serviceClassInfo, DotName chatLanguageModelSupplierClassDotName, DotName streamingChatLanguageModelSupplierClassDotName, - List toolDotNames, + List toolClassInfos, DotName chatMemoryProviderSupplierClassDotName, DotName retrieverClassDotName, DotName retrievalAugmentorSupplierClassDotName, @@ -55,7 +55,7 @@ public DeclarativeAiServiceBuildItem( this.serviceClassInfo = serviceClassInfo; this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName; this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName; - this.toolDotNames = toolDotNames; + this.toolClassInfos = toolClassInfos; this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName; this.retrieverClassDotName = retrieverClassDotName; this.retrievalAugmentorSupplierClassDotName = retrievalAugmentorSupplierClassDotName; @@ -84,8 +84,8 @@ public DotName getStreamingChatLanguageModelSupplierClassDotName() { return streamingChatLanguageModelSupplierClassDotName; } - public List getToolDotNames() { - return toolDotNames; + public List getToolClassInfos() { + return toolClassInfos; } public DotName getChatMemoryProviderSupplierClassDotName() { diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java index 49b09380d..af4591915 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java @@ -8,6 +8,7 @@ import jakarta.enterprise.inject.Instance; +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; import org.jboss.jandex.DotName; import dev.langchain4j.agent.tool.Tool; @@ -62,6 +63,8 @@ public class DotNames { public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class); public static final DotName TOOL = DotName.createSimple(Tool.class); + public static final DotName REGISTER_REST_CLIENT = DotName.createSimple(RegisterRestClient.class); + public static final DotName OUTPUT_GUARDRAIL_ACCUMULATOR = DotName.createSimple(OutputGuardrailAccumulator.class); /** diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index dcc9217d2..7964fbfa5 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -92,6 +92,7 @@ public class ToolProcessor { private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class); public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, Object.class, Object.class); + private static final ResultHandle[] EMPTY_RESULT_HANDLE_ARRAY = new ResultHandle[0]; private static final Logger log = Logger.getLogger(ToolProcessor.class); @@ -136,7 +137,19 @@ public void handleTools( MethodInfo methodInfo = instance.target().asMethod(); ClassInfo classInfo = methodInfo.declaringClass(); - if (classInfo.isInterface() || Modifier.isAbstract(classInfo.flags())) { + boolean causeValidationError = false; + if (classInfo.isInterface()) { + + if (classInfo.hasAnnotation(LangChain4jDotNames.REGISTER_AI_SERVICES) || classInfo.hasAnnotation( + DotNames.REGISTER_REST_CLIENT)) { + // we allow tools on method of these interfaces because we know they will be beans + } else { + causeValidationError = true; + } + } else if (Modifier.isAbstract(classInfo.flags())) { + causeValidationError = true; + } + if (causeValidationError) { validation.produce( new ValidationPhaseBuildItem.ValidationErrorBuildItem(new IllegalStateException( "@Tool is only supported on non-abstract classes, all other usages are ignored. Offending method is '" @@ -409,16 +422,21 @@ private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOu MethodDescriptor.ofMethod(implClassName, "invoke", Object.class, Object.class, Object[].class)); ResultHandle result; + ResultHandle[] targetMethodHandles = EMPTY_RESULT_HANDLE_ARRAY; if (methodInfo.parametersCount() > 0) { List argumentHandles = new ArrayList<>(methodInfo.parametersCount()); for (int i = 0; i < methodInfo.parametersCount(); i++) { argumentHandles.add(invokeMc.readArrayValue(invokeMc.getMethodParam(1), i)); } - ResultHandle[] targetMethodHandles = argumentHandles.toArray(new ResultHandle[0]); - result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0), + targetMethodHandles = argumentHandles.toArray(EMPTY_RESULT_HANDLE_ARRAY); + } + + if (methodInfo.declaringClass().isInterface()) { + result = invokeMc.invokeInterfaceMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0), targetMethodHandles); } else { - result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0)); + result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0), + targetMethodHandles); } boolean toolReturnsVoid = methodInfo.returnType().kind() == Type.Kind.VOID; diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/devui/LangChain4jDevUIProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/devui/LangChain4jDevUIProcessor.java index 5898fe684..dc130ceaf 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/devui/LangChain4jDevUIProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/devui/LangChain4jDevUIProcessor.java @@ -78,7 +78,7 @@ private void addEmbeddingStorePage(CardPageBuildItem card) { private void addAiServicesPage(CardPageBuildItem card, List aiServices) { List infos = new ArrayList<>(); for (DeclarativeAiServiceBuildItem aiService : aiServices) { - List tools = aiService.getToolDotNames().stream().map(dotName -> dotName.toString()).toList(); + List tools = aiService.getToolClassInfos().stream().map(ci -> ci.name().toString()).toList(); infos.add(new AiServiceInfo(aiService.getServiceClassInfo().name().toString(), tools)); } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolQualifierProvider.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolQualifierProvider.java new file mode 100644 index 000000000..0f200322b --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolQualifierProvider.java @@ -0,0 +1,30 @@ +package io.quarkiverse.langchain4j.deployment.items; + +import jakarta.enterprise.util.AnnotationLiteral; + +import org.jboss.jandex.ClassInfo; + +import io.quarkus.builder.item.MultiBuildItem; + +/** + * Used to determine if a class containing a tool should be used along with a CDI qualifier + */ +public interface ToolQualifierProvider { + + boolean supports(ClassInfo classInfo); + + AnnotationLiteral qualifier(ClassInfo classInfo); + + final class BuildItem extends MultiBuildItem { + + private final ToolQualifierProvider provider; + + public BuildItem(ToolQualifierProvider provider) { + this.provider = provider; + } + + public ToolQualifierProvider getProvider() { + return provider; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index b99ca38e3..5307640be 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -11,6 +11,7 @@ import java.util.function.Supplier; import jakarta.enterprise.inject.Instance; +import jakarta.enterprise.util.AnnotationLiteral; import jakarta.enterprise.util.TypeLiteral; import dev.langchain4j.data.segment.TextSegment; @@ -148,12 +149,21 @@ public T apply(SyntheticCreationalContext creationalContext) { } } - List toolsClasses = info.toolsClassNames(); + Map> toolsClasses = info.toolsClassInfo(); if ((toolsClasses != null) && !toolsClasses.isEmpty()) { List tools = new ArrayList<>(toolsClasses.size()); - for (String toolClass : toolsClasses) { - Object tool = creationalContext.getInjectedReference( - Thread.currentThread().getContextClassLoader().loadClass(toolClass)); + for (var entry : toolsClasses.entrySet()) { + AnnotationLiteral qualifier = entry.getValue(); + Object tool; + if (qualifier != null) { + tool = creationalContext.getInjectedReference( + Thread.currentThread().getContextClassLoader().loadClass(entry.getKey()), + qualifier); + } else { + tool = creationalContext.getInjectedReference( + Thread.currentThread().getContextClassLoader().loadClass(entry.getKey())); + } + tools.add(tool); } quarkusAiServices.tools(tools); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java index 38c82b604..249687897 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java @@ -1,12 +1,14 @@ package io.quarkiverse.langchain4j.runtime.aiservice; -import java.util.List; +import java.util.Map; + +import jakarta.enterprise.util.AnnotationLiteral; public record DeclarativeAiServiceCreateInfo( String serviceClassName, String languageModelSupplierClassName, String streamingChatLanguageModelSupplierClassName, - List toolsClassNames, + Map> toolsClassInfo, String toolProviderSupplier, String chatMemoryProviderSupplierClassName, String retrieverClassName, diff --git a/model-providers/openai/openai-vanilla/deployment/pom.xml b/model-providers/openai/openai-vanilla/deployment/pom.xml index b3ef13f17..5dbb1a4e3 100644 --- a/model-providers/openai/openai-vanilla/deployment/pom.xml +++ b/model-providers/openai/openai-vanilla/deployment/pom.xml @@ -46,6 +46,11 @@ quarkus-smallrye-fault-tolerance test + + io.quarkus + quarkus-rest + test + io.smallrye.certs smallrye-certificate-generator-junit5 diff --git a/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RestClientToolTest.java b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RestClientToolTest.java new file mode 100644 index 000000000..311a4b9c2 --- /dev/null +++ b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RestClientToolTest.java @@ -0,0 +1,219 @@ +package org.acme.examples.aiservices; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.client.ClientRequestContext; +import jakarta.ws.rs.client.ClientResponseContext; +import jakarta.ws.rs.client.ClientResponseFilter; +import jakarta.ws.rs.ext.Provider; + +import org.eclipse.microprofile.config.inject.ConfigProperty; +import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.stubbing.Scenario; + +import dev.langchain4j.agent.tool.Tool; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.openai.testing.internal.OpenAiBaseTest; +import io.quarkiverse.langchain4j.testing.internal.WiremockAware; +import io.quarkus.test.QuarkusUnitTest; + +public class RestClientToolTest extends OpenAiBaseTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", + WiremockAware.wiremockUrlForConfig("/v1")) + .overrideConfigKey("quarkus.rest-client.rest-calculator.url", "http://localhost:${quarkus.http.test-port:8081}"); + + private static final String scenario = "tools"; + private static final String secondState = "second"; + + @BeforeEach + void setUp() { + wiremock().resetMappings(); + wiremock().resetRequests(); + } + + @Inject + Bot bot; + + @Test + @ActivateRequestContext + void should_execute_tool_then_answer() throws IOException { + var firstResponse = """ + { + "id": "chatcmpl-8D88Dag1gAKnOPP9Ed4bos7vSpaNz", + "object": "chat.completion", + "created": 1698140213, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "function_call": { + "name": "squareRoot", + "arguments": "{\\n \\"number\\": 485906798473894056\\n}" + } + }, + "finish_reason": "function_call" + } + ], + "usage": { + "prompt_tokens": 65, + "completion_tokens": 20, + "total_tokens": 85 + } + } + """; + + var secondResponse = """ + { + "id": "chatcmpl-8D88FIAUWSpwLaShFr0w8G1SWuVdl", + "object": "chat.completion", + "created": 1698140215, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The square root of 485,906,798,473,894,056 in scientific notation is approximately 6.97070153193991E8." + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 102, + "completion_tokens": 33, + "total_tokens": 135 + } + } + """; + + wiremock().register( + post(urlEqualTo("/v1/chat/completions")) + .withHeader("Authorization", equalTo("Bearer whatever")) + .inScenario(scenario) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody(firstResponse))); + wiremock().register( + post(urlEqualTo("/v1/chat/completions")) + .withHeader("Authorization", equalTo("Bearer whatever")) + .inScenario(scenario) + .whenScenarioStateIs(secondState) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody(secondResponse))); + + wiremock().setSingleScenarioState(scenario, Scenario.STARTED); + + String userMessage = "What is the square root of 485906798473894056 in scientific notation?"; + + String answer = bot.chat(userMessage); + + assertThat(answer).isEqualTo( + "The square root of 485,906,798,473,894,056 in scientific notation is approximately 6.97070153193991E8."); + + assertThat(wiremock().getServeEvents()).hasSize(2); + + Map firstApiRequest = getRequestAsMap(getRequestBody(wiremock().getServeEvents().get(1))); + assertSingleRequestMessage(firstApiRequest, + "What is the square root of 485906798473894056 in scientific notation?"); + assertSingleFunction(firstApiRequest, "squareRoot"); + Map secondApiRequest = getRequestAsMap(getRequestBody(wiremock().getServeEvents().get(0))); + assertMultipleRequestMessage(secondApiRequest, + List.of( + new MessageContent("user", + "What is the square root of 485906798473894056 in scientific notation?"), + new MessageContent("assistant", null), + new MessageContent("function", "6.97070153193991E8"))); + } + + @RegisterAiService(tools = RestCalculator.class) + interface Bot { + + String chat(String message); + } + + @Singleton + public static class CalculatorAfter implements Runnable { + + private final Integer wiremockPort; + + public CalculatorAfter(@ConfigProperty(name = "quarkus.wiremock.devservices.port") Integer wiremockPort) { + this.wiremockPort = wiremockPort; + } + + @Override + public void run() { + WireMock wireMock = new WireMock(wiremockPort); + wireMock.setSingleScenarioState(scenario, secondState); + } + } + + @Path("calculator") + @RegisterRestClient(configKey = "rest-calculator") + @RegisterProvider(RestCalculator.ResponseFilter.class) + interface RestCalculator { + + @POST + @Tool("calculates the square root of the provided number") + @Consumes("text/plain") + @Produces("text/plain") + double squareRoot(double number); + + @Provider + class ResponseFilter implements ClientResponseFilter { + + @Inject + CalculatorAfter after; + + @Override + public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext) { + after.run(); + } + } + } + + @Path("calculator") + public static class CalculatorResource { + + @POST + @Consumes("text/plain") + @Produces("text/plain") + public double squareRoot(double number) { + return Math.sqrt(number); + } + } + +}