From 148c6d525e74bedd6425523e145187718ab22fd4 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Thu, 12 Dec 2024 18:40:28 +0200 Subject: [PATCH 1/2] Allow Rest Client and AI Service to be used as tools --- core/deployment/pom.xml | 5 + .../deployment/AiServicesProcessor.java | 62 ++++- .../DeclarativeAiServiceBuildItem.java | 10 +- .../langchain4j/deployment/DotNames.java | 3 + .../langchain4j/deployment/ToolProcessor.java | 26 ++- .../devui/LangChain4jDevUIProcessor.java | 2 +- .../items/ToolQualifierProvider.java | 30 +++ .../runtime/AiServicesRecorder.java | 18 +- .../DeclarativeAiServiceCreateInfo.java | 6 +- .../openai/openai-vanilla/deployment/pom.xml | 5 + .../aiservices/RestClientToolTest.java | 219 ++++++++++++++++++ 11 files changed, 360 insertions(+), 26 deletions(-) create mode 100644 core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolQualifierProvider.java create mode 100644 model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RestClientToolTest.java diff --git a/core/deployment/pom.xml b/core/deployment/pom.xml index ab47f8411..4333b498d 100644 --- a/core/deployment/pom.xml +++ b/core/deployment/pom.xml @@ -46,6 +46,11 @@ true + + org.eclipse.microprofile.rest.client + microprofile-rest-client-api + + io.quarkus quarkus-vertx-http-dev-ui-tests diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index b8fed660a..f595405a8 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -43,8 +43,10 @@ import jakarta.annotation.PreDestroy; import jakarta.enterprise.context.Dependent; import jakarta.enterprise.inject.spi.DeploymentException; +import jakarta.enterprise.util.AnnotationLiteral; import jakarta.inject.Inject; +import org.eclipse.microprofile.rest.client.inject.RestClient; import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.AnnotationTarget; import org.jboss.jandex.AnnotationValue; @@ -77,6 +79,7 @@ import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem; +import io.quarkiverse.langchain4j.deployment.items.ToolQualifierProvider; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; import io.quarkiverse.langchain4j.runtime.AiServicesRecorder; @@ -262,11 +265,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, chatModelNames.add(chatModelName); } - List toolDotNames = Collections.emptyList(); + List toolClassInfos = Collections.emptyList(); AnnotationValue toolsInstance = instance.value("tools"); if (toolsInstance != null) { - toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name) - .collect(Collectors.toList()); + toolClassInfos = Arrays.stream(toolsInstance.asClassArray()).map(t -> { + var ci = index.getClassByName(t.name()); + if (ci == null) { + throw new IllegalArgumentException("Cannot find class " + t.name() + + " in index. Please make sure it's a valid CDI bean known to Quarkus"); + } + return ci; + }) + .toList(); } // the default value depends on whether tools exists or not - if they do, then we require a ChatMemoryProvider bean @@ -397,7 +407,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, declarativeAiServiceClassInfo, chatLanguageModelSupplierClassDotName, streamingChatLanguageModelSupplierClassDotName, - toolDotNames, + toolClassInfos, chatMemoryProviderSupplierClassDotName, retrieverClassDotName, retrievalAugmentorSupplierClassName, @@ -476,11 +486,27 @@ private boolean isImageOrImageResultResult(Type returnType) { return false; } + @BuildStep + public void toolQualifiers(BuildProducer producer) { + producer.produce(new ToolQualifierProvider.BuildItem(new ToolQualifierProvider() { + @Override + public boolean supports(ClassInfo classInfo) { + return classInfo.hasAnnotation(DotNames.REGISTER_REST_CLIENT); + } + + @Override + public AnnotationLiteral qualifier(ClassInfo classInfo) { + return new RestClient.RestClientLiteral(); + } + })); + } + @BuildStep @Record(ExecutionTime.STATIC_INIT) public void handleDeclarativeServices(AiServicesRecorder recorder, List declarativeAiServiceItems, List selectedChatModelProvider, + List toolQualifierProviderItems, BuildProducer syntheticBeanProducer, BuildProducer unremovableProducer) { @@ -507,7 +533,19 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, ? bi.getStreamingChatLanguageModelSupplierClassDotName().toString() : null); - List toolClassNames = bi.getToolDotNames().stream().map(DotName::toString).collect(Collectors.toList()); + List toolQualifierProviders = toolQualifierProviderItems.stream().map( + ToolQualifierProvider.BuildItem::getProvider).toList(); + Map> toolToQualifierMap = new HashMap<>(); + for (ClassInfo ci : bi.getToolClassInfos()) { + AnnotationLiteral qualifier = null; + for (ToolQualifierProvider provider : toolQualifierProviders) { + if (provider.supports(ci)) { + qualifier = provider.qualifier(ci); + break; + } + } + toolToQualifierMap.put(ci.name().toString(), qualifier); + } String toolProviderSupplierClassName = (bi.getToolProviderClassDotName() != null ? bi.getToolProviderClassDotName().toString() @@ -597,7 +635,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, serviceClassName, chatLanguageModelSupplierClassName, streamingChatLanguageModelSupplierClassName, - toolClassNames, + toolToQualifierMap, toolProviderSupplierClassName, chatMemoryProviderSupplierClassName, retrieverClassName, retrievalAugmentorSupplierClassName, @@ -639,12 +677,16 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, needsChatModelBean = true; } - if (!toolClassNames.isEmpty()) { - for (String toolClassName : toolClassNames) { - DotName dotName = DotName.createSimple(toolClassName); + for (var entry : toolToQualifierMap.entrySet()) { + DotName dotName = DotName.createSimple(entry.getKey()); + AnnotationLiteral qualifier = entry.getValue(); + if (qualifier == null) { configurator.addInjectionPoint(ClassType.create(dotName)); - allToolNames.add(dotName); + } else { + configurator.addInjectionPoint(ClassType.create(dotName), + AnnotationInstance.builder(qualifier.annotationType()).build()); } + allToolNames.add(dotName); } if (LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) { diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java index 6fd1fc997..03449e2f4 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java @@ -16,7 +16,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { private final ClassInfo serviceClassInfo; private final DotName chatLanguageModelSupplierClassDotName; private final DotName streamingChatLanguageModelSupplierClassDotName; - private final List toolDotNames; + private final List toolClassInfos; private final DotName toolProviderClassDotName; private final DotName chatMemoryProviderSupplierClassDotName; @@ -37,7 +37,7 @@ public DeclarativeAiServiceBuildItem( ClassInfo serviceClassInfo, DotName chatLanguageModelSupplierClassDotName, DotName streamingChatLanguageModelSupplierClassDotName, - List toolDotNames, + List toolClassInfos, DotName chatMemoryProviderSupplierClassDotName, DotName retrieverClassDotName, DotName retrievalAugmentorSupplierClassDotName, @@ -55,7 +55,7 @@ public DeclarativeAiServiceBuildItem( this.serviceClassInfo = serviceClassInfo; this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName; this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName; - this.toolDotNames = toolDotNames; + this.toolClassInfos = toolClassInfos; this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName; this.retrieverClassDotName = retrieverClassDotName; this.retrievalAugmentorSupplierClassDotName = retrievalAugmentorSupplierClassDotName; @@ -84,8 +84,8 @@ public DotName getStreamingChatLanguageModelSupplierClassDotName() { return streamingChatLanguageModelSupplierClassDotName; } - public List getToolDotNames() { - return toolDotNames; + public List getToolClassInfos() { + return toolClassInfos; } public DotName getChatMemoryProviderSupplierClassDotName() { diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java index 49b09380d..af4591915 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java @@ -8,6 +8,7 @@ import jakarta.enterprise.inject.Instance; +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; import org.jboss.jandex.DotName; import dev.langchain4j.agent.tool.Tool; @@ -62,6 +63,8 @@ public class DotNames { public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class); public static final DotName TOOL = DotName.createSimple(Tool.class); + public static final DotName REGISTER_REST_CLIENT = DotName.createSimple(RegisterRestClient.class); + public static final DotName OUTPUT_GUARDRAIL_ACCUMULATOR = DotName.createSimple(OutputGuardrailAccumulator.class); /** diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index dcc9217d2..7964fbfa5 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -92,6 +92,7 @@ public class ToolProcessor { private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class); public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, Object.class, Object.class); + private static final ResultHandle[] EMPTY_RESULT_HANDLE_ARRAY = new ResultHandle[0]; private static final Logger log = Logger.getLogger(ToolProcessor.class); @@ -136,7 +137,19 @@ public void handleTools( MethodInfo methodInfo = instance.target().asMethod(); ClassInfo classInfo = methodInfo.declaringClass(); - if (classInfo.isInterface() || Modifier.isAbstract(classInfo.flags())) { + boolean causeValidationError = false; + if (classInfo.isInterface()) { + + if (classInfo.hasAnnotation(LangChain4jDotNames.REGISTER_AI_SERVICES) || classInfo.hasAnnotation( + DotNames.REGISTER_REST_CLIENT)) { + // we allow tools on method of these interfaces because we know they will be beans + } else { + causeValidationError = true; + } + } else if (Modifier.isAbstract(classInfo.flags())) { + causeValidationError = true; + } + if (causeValidationError) { validation.produce( new ValidationPhaseBuildItem.ValidationErrorBuildItem(new IllegalStateException( "@Tool is only supported on non-abstract classes, all other usages are ignored. Offending method is '" @@ -409,16 +422,21 @@ private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOu MethodDescriptor.ofMethod(implClassName, "invoke", Object.class, Object.class, Object[].class)); ResultHandle result; + ResultHandle[] targetMethodHandles = EMPTY_RESULT_HANDLE_ARRAY; if (methodInfo.parametersCount() > 0) { List argumentHandles = new ArrayList<>(methodInfo.parametersCount()); for (int i = 0; i < methodInfo.parametersCount(); i++) { argumentHandles.add(invokeMc.readArrayValue(invokeMc.getMethodParam(1), i)); } - ResultHandle[] targetMethodHandles = argumentHandles.toArray(new ResultHandle[0]); - result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0), + targetMethodHandles = argumentHandles.toArray(EMPTY_RESULT_HANDLE_ARRAY); + } + + if (methodInfo.declaringClass().isInterface()) { + result = invokeMc.invokeInterfaceMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0), targetMethodHandles); } else { - result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0)); + result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0), + targetMethodHandles); } boolean toolReturnsVoid = methodInfo.returnType().kind() == Type.Kind.VOID; diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/devui/LangChain4jDevUIProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/devui/LangChain4jDevUIProcessor.java index 5898fe684..dc130ceaf 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/devui/LangChain4jDevUIProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/devui/LangChain4jDevUIProcessor.java @@ -78,7 +78,7 @@ private void addEmbeddingStorePage(CardPageBuildItem card) { private void addAiServicesPage(CardPageBuildItem card, List aiServices) { List infos = new ArrayList<>(); for (DeclarativeAiServiceBuildItem aiService : aiServices) { - List tools = aiService.getToolDotNames().stream().map(dotName -> dotName.toString()).toList(); + List tools = aiService.getToolClassInfos().stream().map(ci -> ci.name().toString()).toList(); infos.add(new AiServiceInfo(aiService.getServiceClassInfo().name().toString(), tools)); } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolQualifierProvider.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolQualifierProvider.java new file mode 100644 index 000000000..0f200322b --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolQualifierProvider.java @@ -0,0 +1,30 @@ +package io.quarkiverse.langchain4j.deployment.items; + +import jakarta.enterprise.util.AnnotationLiteral; + +import org.jboss.jandex.ClassInfo; + +import io.quarkus.builder.item.MultiBuildItem; + +/** + * Used to determine if a class containing a tool should be used along with a CDI qualifier + */ +public interface ToolQualifierProvider { + + boolean supports(ClassInfo classInfo); + + AnnotationLiteral qualifier(ClassInfo classInfo); + + final class BuildItem extends MultiBuildItem { + + private final ToolQualifierProvider provider; + + public BuildItem(ToolQualifierProvider provider) { + this.provider = provider; + } + + public ToolQualifierProvider getProvider() { + return provider; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index b99ca38e3..5307640be 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -11,6 +11,7 @@ import java.util.function.Supplier; import jakarta.enterprise.inject.Instance; +import jakarta.enterprise.util.AnnotationLiteral; import jakarta.enterprise.util.TypeLiteral; import dev.langchain4j.data.segment.TextSegment; @@ -148,12 +149,21 @@ public T apply(SyntheticCreationalContext creationalContext) { } } - List toolsClasses = info.toolsClassNames(); + Map> toolsClasses = info.toolsClassInfo(); if ((toolsClasses != null) && !toolsClasses.isEmpty()) { List tools = new ArrayList<>(toolsClasses.size()); - for (String toolClass : toolsClasses) { - Object tool = creationalContext.getInjectedReference( - Thread.currentThread().getContextClassLoader().loadClass(toolClass)); + for (var entry : toolsClasses.entrySet()) { + AnnotationLiteral qualifier = entry.getValue(); + Object tool; + if (qualifier != null) { + tool = creationalContext.getInjectedReference( + Thread.currentThread().getContextClassLoader().loadClass(entry.getKey()), + qualifier); + } else { + tool = creationalContext.getInjectedReference( + Thread.currentThread().getContextClassLoader().loadClass(entry.getKey())); + } + tools.add(tool); } quarkusAiServices.tools(tools); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java index 38c82b604..249687897 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java @@ -1,12 +1,14 @@ package io.quarkiverse.langchain4j.runtime.aiservice; -import java.util.List; +import java.util.Map; + +import jakarta.enterprise.util.AnnotationLiteral; public record DeclarativeAiServiceCreateInfo( String serviceClassName, String languageModelSupplierClassName, String streamingChatLanguageModelSupplierClassName, - List toolsClassNames, + Map> toolsClassInfo, String toolProviderSupplier, String chatMemoryProviderSupplierClassName, String retrieverClassName, diff --git a/model-providers/openai/openai-vanilla/deployment/pom.xml b/model-providers/openai/openai-vanilla/deployment/pom.xml index b3ef13f17..5dbb1a4e3 100644 --- a/model-providers/openai/openai-vanilla/deployment/pom.xml +++ b/model-providers/openai/openai-vanilla/deployment/pom.xml @@ -46,6 +46,11 @@ quarkus-smallrye-fault-tolerance test + + io.quarkus + quarkus-rest + test + io.smallrye.certs smallrye-certificate-generator-junit5 diff --git a/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RestClientToolTest.java b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RestClientToolTest.java new file mode 100644 index 000000000..311a4b9c2 --- /dev/null +++ b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RestClientToolTest.java @@ -0,0 +1,219 @@ +package org.acme.examples.aiservices; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.client.ClientRequestContext; +import jakarta.ws.rs.client.ClientResponseContext; +import jakarta.ws.rs.client.ClientResponseFilter; +import jakarta.ws.rs.ext.Provider; + +import org.eclipse.microprofile.config.inject.ConfigProperty; +import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.stubbing.Scenario; + +import dev.langchain4j.agent.tool.Tool; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.openai.testing.internal.OpenAiBaseTest; +import io.quarkiverse.langchain4j.testing.internal.WiremockAware; +import io.quarkus.test.QuarkusUnitTest; + +public class RestClientToolTest extends OpenAiBaseTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", + WiremockAware.wiremockUrlForConfig("/v1")) + .overrideConfigKey("quarkus.rest-client.rest-calculator.url", "http://localhost:${quarkus.http.test-port:8081}"); + + private static final String scenario = "tools"; + private static final String secondState = "second"; + + @BeforeEach + void setUp() { + wiremock().resetMappings(); + wiremock().resetRequests(); + } + + @Inject + Bot bot; + + @Test + @ActivateRequestContext + void should_execute_tool_then_answer() throws IOException { + var firstResponse = """ + { + "id": "chatcmpl-8D88Dag1gAKnOPP9Ed4bos7vSpaNz", + "object": "chat.completion", + "created": 1698140213, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "function_call": { + "name": "squareRoot", + "arguments": "{\\n \\"number\\": 485906798473894056\\n}" + } + }, + "finish_reason": "function_call" + } + ], + "usage": { + "prompt_tokens": 65, + "completion_tokens": 20, + "total_tokens": 85 + } + } + """; + + var secondResponse = """ + { + "id": "chatcmpl-8D88FIAUWSpwLaShFr0w8G1SWuVdl", + "object": "chat.completion", + "created": 1698140215, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The square root of 485,906,798,473,894,056 in scientific notation is approximately 6.97070153193991E8." + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 102, + "completion_tokens": 33, + "total_tokens": 135 + } + } + """; + + wiremock().register( + post(urlEqualTo("/v1/chat/completions")) + .withHeader("Authorization", equalTo("Bearer whatever")) + .inScenario(scenario) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody(firstResponse))); + wiremock().register( + post(urlEqualTo("/v1/chat/completions")) + .withHeader("Authorization", equalTo("Bearer whatever")) + .inScenario(scenario) + .whenScenarioStateIs(secondState) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody(secondResponse))); + + wiremock().setSingleScenarioState(scenario, Scenario.STARTED); + + String userMessage = "What is the square root of 485906798473894056 in scientific notation?"; + + String answer = bot.chat(userMessage); + + assertThat(answer).isEqualTo( + "The square root of 485,906,798,473,894,056 in scientific notation is approximately 6.97070153193991E8."); + + assertThat(wiremock().getServeEvents()).hasSize(2); + + Map firstApiRequest = getRequestAsMap(getRequestBody(wiremock().getServeEvents().get(1))); + assertSingleRequestMessage(firstApiRequest, + "What is the square root of 485906798473894056 in scientific notation?"); + assertSingleFunction(firstApiRequest, "squareRoot"); + Map secondApiRequest = getRequestAsMap(getRequestBody(wiremock().getServeEvents().get(0))); + assertMultipleRequestMessage(secondApiRequest, + List.of( + new MessageContent("user", + "What is the square root of 485906798473894056 in scientific notation?"), + new MessageContent("assistant", null), + new MessageContent("function", "6.97070153193991E8"))); + } + + @RegisterAiService(tools = RestCalculator.class) + interface Bot { + + String chat(String message); + } + + @Singleton + public static class CalculatorAfter implements Runnable { + + private final Integer wiremockPort; + + public CalculatorAfter(@ConfigProperty(name = "quarkus.wiremock.devservices.port") Integer wiremockPort) { + this.wiremockPort = wiremockPort; + } + + @Override + public void run() { + WireMock wireMock = new WireMock(wiremockPort); + wireMock.setSingleScenarioState(scenario, secondState); + } + } + + @Path("calculator") + @RegisterRestClient(configKey = "rest-calculator") + @RegisterProvider(RestCalculator.ResponseFilter.class) + interface RestCalculator { + + @POST + @Tool("calculates the square root of the provided number") + @Consumes("text/plain") + @Produces("text/plain") + double squareRoot(double number); + + @Provider + class ResponseFilter implements ClientResponseFilter { + + @Inject + CalculatorAfter after; + + @Override + public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext) { + after.run(); + } + } + } + + @Path("calculator") + public static class CalculatorResource { + + @POST + @Consumes("text/plain") + @Produces("text/plain") + public double squareRoot(double number) { + return Math.sqrt(number); + } + } + +} From 989ae312e046f8a97a7e06e55878e6a321f1cf27 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Thu, 12 Dec 2024 17:13:27 +0200 Subject: [PATCH 2/2] Add weather agent sample --- samples/pom.xml | 1 + samples/weather-agent/README.md | 63 ++++++++ samples/weather-agent/pom.xml | 135 ++++++++++++++++++ .../weather/agent/CityExtractorAgent.java | 21 +++ .../weather/agent/WeatherForecastAgent.java | 24 ++++ .../weather/agent/WeatherResource.java | 27 ++++ .../weather/agent/geo/GeoCodingService.java | 22 +++ .../weather/agent/geo/GeoResult.java | 4 + .../weather/agent/geo/GeoResults.java | 11 ++ .../weather/agent/weather/Daily.java | 29 ++++ .../weather/agent/weather/DailyUnits.java | 7 + .../agent/weather/DailyWeatherData.java | 23 +++ .../agent/weather/WeatherForecast.java | 6 + .../agent/weather/WeatherForecastService.java | 27 ++++ .../weather/agent/weather/WmoCode.java | 31 ++++ .../src/main/resources/application.properties | 5 + 16 files changed, 436 insertions(+) create mode 100644 samples/weather-agent/README.md create mode 100644 samples/weather-agent/pom.xml create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/CityExtractorAgent.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/WeatherForecastAgent.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/WeatherResource.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoCodingService.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoResult.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoResults.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/Daily.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/DailyUnits.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/DailyWeatherData.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WeatherForecast.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WeatherForecastService.java create mode 100644 samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WmoCode.java create mode 100644 samples/weather-agent/src/main/resources/application.properties diff --git a/samples/pom.xml b/samples/pom.xml index c2acc9c41..b04dd2cc7 100644 --- a/samples/pom.xml +++ b/samples/pom.xml @@ -22,6 +22,7 @@ secure-poem-multiple-models secure-sql-chatbot sql-chatbot + weather-agent diff --git a/samples/weather-agent/README.md b/samples/weather-agent/README.md new file mode 100644 index 000000000..407545337 --- /dev/null +++ b/samples/weather-agent/README.md @@ -0,0 +1,63 @@ +# Chatbot example + +This example demonstrates how to create an AI agent using Quarkus LangChain4j. + +## Running the example + +A prerequisite to running this example is to provide your OpenAI API key. + +``` +export QUARKUS_LANGCHAIN4J_OPENAI_API_KEY= +``` + +Then, simply run the project in Dev mode: + +``` +mvn quarkus:dev +``` + +## Using the example + +Execute: + +``` +curl http://localhost:8080/weather?city=Athens +``` + +and you should get a response a like so: + +``` +The weather in Athens today is mostly cloudy, with a maximum temperature of 15.6°C and a minimum of 7.4°C. There is no expected precipitation and wind speeds can reach up to 8.1 km/h +``` + +## Using other model providers + +### Compatible OpenAI serving infrastructure + +Add `quarkus.langchain4j.openai.base-url=http://yourerver` to `application.properties`. + +In this case, `quarkus.langchain4j.openai.api-key` is generally not needed. + +### Ollama + + +Replace: + +```xml + + io.quarkiverse.langchain4j + quarkus-langchain4j-openai + ${quarkus-langchain4j.version} + +``` + +with + +```xml + + io.quarkiverse.langchain4j + quarkus-langchain4j-ollama + ${quarkus-langchain4j.version} + +``` + diff --git a/samples/weather-agent/pom.xml b/samples/weather-agent/pom.xml new file mode 100644 index 000000000..4ebd9b3a6 --- /dev/null +++ b/samples/weather-agent/pom.xml @@ -0,0 +1,135 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-sample-weather-agent + Quarkus LangChain4j - Sample - Weather Agent + 1.0-SNAPSHOT + + + 3.13.0 + true + 17 + UTF-8 + UTF-8 + quarkus-bom + io.quarkus + 3.15.1 + true + 3.2.5 + 999-SNAPSHOT + + + + + + ${quarkus.platform.group-id} + ${quarkus.platform.artifact-id} + ${quarkus.platform.version} + pom + import + + + + + + + io.quarkus + quarkus-rest-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-openai + ${quarkus-langchain4j.version} + + + io.quarkus + quarkus-cache + + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-openai-deployment + ${quarkus-langchain4j.version} + test + pom + + + * + * + + + + + + + + io.quarkus + quarkus-maven-plugin + ${quarkus.platform.version} + + + + build + + + + + + maven-compiler-plugin + ${compiler-plugin.version} + + + maven-surefire-plugin + 3.5.1 + + + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + native + + + native + + + + + + maven-failsafe-plugin + 3.5.1 + + + + integration-test + verify + + + + ${project.build.directory}/${project.build.finalName}-runner + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + native + + + + + + diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/CityExtractorAgent.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/CityExtractorAgent.java new file mode 100644 index 000000000..ee911a066 --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/CityExtractorAgent.java @@ -0,0 +1,21 @@ +package io.quarkiverse.langchain4j.weather.agent; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) +public interface CityExtractorAgent { + + @UserMessage(""" + You are given one question and you have to extract city name from it + Only reply the city name if it exists or reply 'unknown_city' if there is no city name in question + + Here is the question: {question} + """) + @Tool("Extracts the city from a question") + String extractCity(String question); + +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/WeatherForecastAgent.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/WeatherForecastAgent.java new file mode 100644 index 000000000..e050e9372 --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/WeatherForecastAgent.java @@ -0,0 +1,24 @@ +package io.quarkiverse.langchain4j.weather.agent; + +import dev.langchain4j.service.SystemMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.weather.agent.geo.GeoCodingService; +import io.quarkiverse.langchain4j.weather.agent.weather.WeatherForecastService; + +@RegisterAiService(tools = { CityExtractorAgent.class, WeatherForecastService.class, GeoCodingService.class}) +public interface WeatherForecastAgent { + + @SystemMessage(""" + You are a meteorologist, and you need to answer questions asked by the user about weather using at most 3 lines. + + The weather information is a JSON object and has the following fields: + + maxTemperature is the maximum temperature of the day in Celsius degrees + minTemperature is the minimum temperature of the day in Celsius degrees + precipitation is the amount of water in mm + windSpeed is the speed of wind in kilometers per hour + weather is the overall weather. + """) + String chat(String query); + +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/WeatherResource.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/WeatherResource.java new file mode 100644 index 000000000..95ece59af --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/WeatherResource.java @@ -0,0 +1,27 @@ +package io.quarkiverse.langchain4j.weather.agent; + +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.MediaType; +import org.jboss.resteasy.reactive.RestQuery; + + +@Path("/weather") +public class WeatherResource { + + private final WeatherForecastAgent agent; + + public WeatherResource(WeatherForecastAgent agent) { + this.agent = agent; + } + + @GET + @Produces(MediaType.TEXT_PLAIN) + public String getWeather(@RestQuery @DefaultValue("Manilla") String city) { + return agent.chat(String.format("What is the weather in %s ?", city)); + } + + +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoCodingService.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoCodingService.java new file mode 100644 index 000000000..6e23ce893 --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoCodingService.java @@ -0,0 +1,22 @@ +package io.quarkiverse.langchain4j.weather.agent.geo; + +import dev.langchain4j.agent.tool.Tool; +import io.quarkus.cache.CacheResult; +import io.quarkus.rest.client.reactive.ClientQueryParam; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; +import org.jboss.resteasy.reactive.RestQuery; + +@RegisterRestClient(configKey = "geocoding") +@Path("/v1") +public interface GeoCodingService { + + @GET + @Path("/search") + @CacheResult(cacheName = "geo-results") + @ClientQueryParam(name = "count", value = "1") + @Tool("Finds the latitude and longitude of a given city") + GeoResults search(@RestQuery String name); + +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoResult.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoResult.java new file mode 100644 index 000000000..404eb810f --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoResult.java @@ -0,0 +1,4 @@ +package io.quarkiverse.langchain4j.weather.agent.geo; + +public record GeoResult(double latitude, double longitude) { +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoResults.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoResults.java new file mode 100644 index 000000000..03322a5c6 --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/geo/GeoResults.java @@ -0,0 +1,11 @@ +package io.quarkiverse.langchain4j.weather.agent.geo; + +import java.util.List; + +public record GeoResults(List results) { + + public GeoResult getFirst() { + return results.get(0); + } + +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/Daily.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/Daily.java new file mode 100644 index 000000000..cf16c72e3 --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/Daily.java @@ -0,0 +1,29 @@ +package io.quarkiverse.langchain4j.weather.agent.weather; + +import java.util.Arrays; + +public record Daily(double[] temperature_2m_max, + double[] temperature_2m_min, + double[] precipitation_sum, + double[] wind_speed_10m_max, + int[] weather_code) { + + public DailyWeatherData getFirstDay() { + return new DailyWeatherData(temperature_2m_max[0], + temperature_2m_min[0], + precipitation_sum[0], + wind_speed_10m_max[0], + weather_code[0]); + } + + @Override + public String toString() { + return "Daily{" + "temperature_2m_max=" + Arrays.toString(temperature_2m_max) + + ", temperature_2m_min=" + Arrays.toString(temperature_2m_min) + + ", precipitation_sum=" + Arrays.toString(precipitation_sum) + + ", wind_speed_10m_max=" + Arrays.toString(wind_speed_10m_max) + + ", weather_code=" + Arrays.toString(weather_code) + + '}'; + } + +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/DailyUnits.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/DailyUnits.java new file mode 100644 index 000000000..f6d9799f0 --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/DailyUnits.java @@ -0,0 +1,7 @@ +package io.quarkiverse.langchain4j.weather.agent.weather; + +public record DailyUnits(String time, + String temperature_2m_max, + String precipitation_sum, + String wind_speed_10m_max) { +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/DailyWeatherData.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/DailyWeatherData.java new file mode 100644 index 000000000..25f46a4e6 --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/DailyWeatherData.java @@ -0,0 +1,23 @@ +package io.quarkiverse.langchain4j.weather.agent.weather; + +import io.vertx.core.json.JsonObject; + +public record DailyWeatherData(double temperature_2m_max, + double temperature_2m_min, + double precipitation_sum, + double wind_speed_10m_max, + int weather_code) { + + + public JsonObject toJson() { + JsonObject json = new JsonObject(); + json.put("maxTemperature", temperature_2m_max()); + json.put("minTemperature", temperature_2m_min()); + json.put("precipitation", precipitation_sum()); + json.put("windSpeed", wind_speed_10m_max()); + json.put("weather", WmoCode.translate(weather_code())); + + return json; + } + +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WeatherForecast.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WeatherForecast.java new file mode 100644 index 000000000..fbae14887 --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WeatherForecast.java @@ -0,0 +1,6 @@ +package io.quarkiverse.langchain4j.weather.agent.weather; + +public record WeatherForecast(DailyUnits daily_units, Daily daily) { + + +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WeatherForecastService.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WeatherForecastService.java new file mode 100644 index 000000000..4827f18b0 --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WeatherForecastService.java @@ -0,0 +1,27 @@ +package io.quarkiverse.langchain4j.weather.agent.weather; + +import dev.langchain4j.agent.tool.Tool; +import io.quarkus.rest.client.reactive.ClientQueryParam; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; +import org.jboss.resteasy.reactive.RestQuery; + +@RegisterRestClient(configKey = "openmeteo") +@Path("/v1") +public interface WeatherForecastService { + + @GET + @Path("/forecast") + @Tool("Forecasts the weather for the given latitude and longitude") + @ClientQueryParam(name = "forecast_days", value = "1") + @ClientQueryParam(name = "daily", value = { + "temperature_2m_max", + "temperature_2m_min", + "precipitation_sum", + "wind_speed_10m_max", + "weather_code" + }) + WeatherForecast forecast(@RestQuery double latitude, @RestQuery double longitude); + +} diff --git a/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WmoCode.java b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WmoCode.java new file mode 100644 index 000000000..195b0bcb8 --- /dev/null +++ b/samples/weather-agent/src/main/java/io/quarkiverse/langchain4j/weather/agent/weather/WmoCode.java @@ -0,0 +1,31 @@ +package io.quarkiverse.langchain4j.weather.agent.weather; + +import java.util.Arrays; + +public enum WmoCode { + + CLEAR_SKY(0), MAINLY_CLEAR(1), PARTLY_CLOUDY(2), OVERCAST(3), + FOG(45), DEPOSITING_RIME_FOG(46), DRIZZLE_LIGHT(51), DRIZZLE_MEDIUM(53), + DRIZZLE_DENSE(55), FREEZING_DRIZZLE_LIGHT(56), FREEZING_DRIZZLE_DENSE(57), + RAIN_SLIGHT(61), RAIN_MODERATE(63), RAIN_HEAVY(65), FREEZING_RAIN_LIGHT(66), FREEZING_RAIN_HEAVY(67), + SNOW_FALL_SLIGHT(71), SNOW_FALL_MODERATE(73), SNOW_FALL_HEAVY(75), SNOW_GRAINS(77), + RAIN_SHOWERS_SLIGHT(80), RAIN_SHOWERS_MODERATE(81), RAIN_SHOWERS_VIOLENT(82), + SNOW_SHOWERS_SLIGHT(85), SNOW_SHOWERS_HEAVY(86), THUNDERSTORM(95), + THUNDERSTORM_SLIGHT_HAIL(96), THUNDERSTORM_HEAVY_HAIL(99); + + final int code; + + WmoCode(int code) { + this.code = code; + } + + public static WmoCode translate(int code) { + WmoCode[] values = WmoCode.values(); + + return Arrays.stream(values) + .filter(wmoCode -> code == wmoCode.code) + .findFirst() + .orElse(null); + } + +} diff --git a/samples/weather-agent/src/main/resources/application.properties b/samples/weather-agent/src/main/resources/application.properties new file mode 100644 index 000000000..d25dc6921 --- /dev/null +++ b/samples/weather-agent/src/main/resources/application.properties @@ -0,0 +1,5 @@ +quarkus.langchain4j.log-requests=true +quarkus.langchain4j.log-responses=true + +quarkus.rest-client.geocoding.url=https://geocoding-api.open-meteo.com +quarkus.rest-client.openmeteo.url=https://api.open-meteo.com