Skip to content

Commit

Permalink
Merge pull request #1157 from quarkiverse/weather-agent
Browse files Browse the repository at this point in the history
Allow Rest Client and AI Service to be used as tools
  • Loading branch information
geoand authored Dec 13, 2024
2 parents 742b816 + 989ae31 commit 4fdab1d
Show file tree
Hide file tree
Showing 27 changed files with 796 additions and 26 deletions.
5 changes: 5 additions & 0 deletions core/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@
<optional>true</optional> <!-- conditional dependency -->
</dependency>

<dependency>
<groupId>org.eclipse.microprofile.rest.client</groupId>
<artifactId>microprofile-rest-client-api</artifactId>
</dependency>

<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-vertx-http-dev-ui-tests</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
import jakarta.annotation.PreDestroy;
import jakarta.enterprise.context.Dependent;
import jakarta.enterprise.inject.spi.DeploymentException;
import jakarta.enterprise.util.AnnotationLiteral;
import jakarta.inject.Inject;

import org.eclipse.microprofile.rest.client.inject.RestClient;
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.AnnotationValue;
Expand Down Expand Up @@ -77,6 +79,7 @@
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ToolQualifierProvider;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
Expand Down Expand Up @@ -262,11 +265,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
chatModelNames.add(chatModelName);
}

List<DotName> toolDotNames = Collections.emptyList();
List<ClassInfo> toolClassInfos = Collections.emptyList();
AnnotationValue toolsInstance = instance.value("tools");
if (toolsInstance != null) {
toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name)
.collect(Collectors.toList());
toolClassInfos = Arrays.stream(toolsInstance.asClassArray()).map(t -> {
var ci = index.getClassByName(t.name());
if (ci == null) {
throw new IllegalArgumentException("Cannot find class " + t.name()
+ " in index. Please make sure it's a valid CDI bean known to Quarkus");
}
return ci;
})
.toList();
}

// the default value depends on whether tools exists or not - if they do, then we require a ChatMemoryProvider bean
Expand Down Expand Up @@ -397,7 +407,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
declarativeAiServiceClassInfo,
chatLanguageModelSupplierClassDotName,
streamingChatLanguageModelSupplierClassDotName,
toolDotNames,
toolClassInfos,
chatMemoryProviderSupplierClassDotName,
retrieverClassDotName,
retrievalAugmentorSupplierClassName,
Expand Down Expand Up @@ -476,11 +486,27 @@ private boolean isImageOrImageResultResult(Type returnType) {
return false;
}

@BuildStep
public void toolQualifiers(BuildProducer<ToolQualifierProvider.BuildItem> producer) {
producer.produce(new ToolQualifierProvider.BuildItem(new ToolQualifierProvider() {
@Override
public boolean supports(ClassInfo classInfo) {
return classInfo.hasAnnotation(DotNames.REGISTER_REST_CLIENT);
}

@Override
public AnnotationLiteral<?> qualifier(ClassInfo classInfo) {
return new RestClient.RestClientLiteral();
}
}));
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public void handleDeclarativeServices(AiServicesRecorder recorder,
List<DeclarativeAiServiceBuildItem> declarativeAiServiceItems,
List<SelectedChatModelProviderBuildItem> selectedChatModelProvider,
List<ToolQualifierProvider.BuildItem> toolQualifierProviderItems,
BuildProducer<SyntheticBeanBuildItem> syntheticBeanProducer,
BuildProducer<UnremovableBeanBuildItem> unremovableProducer) {

Expand All @@ -507,7 +533,19 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
? bi.getStreamingChatLanguageModelSupplierClassDotName().toString()
: null);

List<String> toolClassNames = bi.getToolDotNames().stream().map(DotName::toString).collect(Collectors.toList());
List<ToolQualifierProvider> toolQualifierProviders = toolQualifierProviderItems.stream().map(
ToolQualifierProvider.BuildItem::getProvider).toList();
Map<String, AnnotationLiteral<?>> toolToQualifierMap = new HashMap<>();
for (ClassInfo ci : bi.getToolClassInfos()) {
AnnotationLiteral<?> qualifier = null;
for (ToolQualifierProvider provider : toolQualifierProviders) {
if (provider.supports(ci)) {
qualifier = provider.qualifier(ci);
break;
}
}
toolToQualifierMap.put(ci.name().toString(), qualifier);
}

String toolProviderSupplierClassName = (bi.getToolProviderClassDotName() != null
? bi.getToolProviderClassDotName().toString()
Expand Down Expand Up @@ -597,7 +635,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
serviceClassName,
chatLanguageModelSupplierClassName,
streamingChatLanguageModelSupplierClassName,
toolClassNames,
toolToQualifierMap,
toolProviderSupplierClassName,
chatMemoryProviderSupplierClassName, retrieverClassName,
retrievalAugmentorSupplierClassName,
Expand Down Expand Up @@ -639,12 +677,16 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsChatModelBean = true;
}

if (!toolClassNames.isEmpty()) {
for (String toolClassName : toolClassNames) {
DotName dotName = DotName.createSimple(toolClassName);
for (var entry : toolToQualifierMap.entrySet()) {
DotName dotName = DotName.createSimple(entry.getKey());
AnnotationLiteral<?> qualifier = entry.getValue();
if (qualifier == null) {
configurator.addInjectionPoint(ClassType.create(dotName));
allToolNames.add(dotName);
} else {
configurator.addInjectionPoint(ClassType.create(dotName),
AnnotationInstance.builder(qualifier.annotationType()).build());
}
allToolNames.add(dotName);
}

if (LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final ClassInfo serviceClassInfo;
private final DotName chatLanguageModelSupplierClassDotName;
private final DotName streamingChatLanguageModelSupplierClassDotName;
private final List<DotName> toolDotNames;
private final List<ClassInfo> toolClassInfos;
private final DotName toolProviderClassDotName;

private final DotName chatMemoryProviderSupplierClassDotName;
Expand All @@ -37,7 +37,7 @@ public DeclarativeAiServiceBuildItem(
ClassInfo serviceClassInfo,
DotName chatLanguageModelSupplierClassDotName,
DotName streamingChatLanguageModelSupplierClassDotName,
List<DotName> toolDotNames,
List<ClassInfo> toolClassInfos,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverClassDotName,
DotName retrievalAugmentorSupplierClassDotName,
Expand All @@ -55,7 +55,7 @@ public DeclarativeAiServiceBuildItem(
this.serviceClassInfo = serviceClassInfo;
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.toolClassInfos = toolClassInfos;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverClassDotName = retrieverClassDotName;
this.retrievalAugmentorSupplierClassDotName = retrievalAugmentorSupplierClassDotName;
Expand Down Expand Up @@ -84,8 +84,8 @@ public DotName getStreamingChatLanguageModelSupplierClassDotName() {
return streamingChatLanguageModelSupplierClassDotName;
}

public List<DotName> getToolDotNames() {
return toolDotNames;
public List<ClassInfo> getToolClassInfos() {
return toolClassInfos;
}

public DotName getChatMemoryProviderSupplierClassDotName() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import jakarta.enterprise.inject.Instance;

import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.jboss.jandex.DotName;

import dev.langchain4j.agent.tool.Tool;
Expand Down Expand Up @@ -62,6 +63,8 @@ public class DotNames {
public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class);
public static final DotName TOOL = DotName.createSimple(Tool.class);

public static final DotName REGISTER_REST_CLIENT = DotName.createSimple(RegisterRestClient.class);

public static final DotName OUTPUT_GUARDRAIL_ACCUMULATOR = DotName.createSimple(OutputGuardrailAccumulator.class);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public class ToolProcessor {
private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class);
public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, Object.class,
Object.class);
private static final ResultHandle[] EMPTY_RESULT_HANDLE_ARRAY = new ResultHandle[0];

private static final Logger log = Logger.getLogger(ToolProcessor.class);

Expand Down Expand Up @@ -136,7 +137,19 @@ public void handleTools(

MethodInfo methodInfo = instance.target().asMethod();
ClassInfo classInfo = methodInfo.declaringClass();
if (classInfo.isInterface() || Modifier.isAbstract(classInfo.flags())) {
boolean causeValidationError = false;
if (classInfo.isInterface()) {

if (classInfo.hasAnnotation(LangChain4jDotNames.REGISTER_AI_SERVICES) || classInfo.hasAnnotation(
DotNames.REGISTER_REST_CLIENT)) {
// we allow tools on method of these interfaces because we know they will be beans
} else {
causeValidationError = true;
}
} else if (Modifier.isAbstract(classInfo.flags())) {
causeValidationError = true;
}
if (causeValidationError) {
validation.produce(
new ValidationPhaseBuildItem.ValidationErrorBuildItem(new IllegalStateException(
"@Tool is only supported on non-abstract classes, all other usages are ignored. Offending method is '"
Expand Down Expand Up @@ -409,16 +422,21 @@ private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOu
MethodDescriptor.ofMethod(implClassName, "invoke", Object.class, Object.class, Object[].class));

ResultHandle result;
ResultHandle[] targetMethodHandles = EMPTY_RESULT_HANDLE_ARRAY;
if (methodInfo.parametersCount() > 0) {
List<ResultHandle> argumentHandles = new ArrayList<>(methodInfo.parametersCount());
for (int i = 0; i < methodInfo.parametersCount(); i++) {
argumentHandles.add(invokeMc.readArrayValue(invokeMc.getMethodParam(1), i));
}
ResultHandle[] targetMethodHandles = argumentHandles.toArray(new ResultHandle[0]);
result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0),
targetMethodHandles = argumentHandles.toArray(EMPTY_RESULT_HANDLE_ARRAY);
}

if (methodInfo.declaringClass().isInterface()) {
result = invokeMc.invokeInterfaceMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0),
targetMethodHandles);
} else {
result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0));
result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0),
targetMethodHandles);
}

boolean toolReturnsVoid = methodInfo.returnType().kind() == Type.Kind.VOID;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ private void addEmbeddingStorePage(CardPageBuildItem card) {
private void addAiServicesPage(CardPageBuildItem card, List<DeclarativeAiServiceBuildItem> aiServices) {
List<AiServiceInfo> infos = new ArrayList<>();
for (DeclarativeAiServiceBuildItem aiService : aiServices) {
List<String> tools = aiService.getToolDotNames().stream().map(dotName -> dotName.toString()).toList();
List<String> tools = aiService.getToolClassInfos().stream().map(ci -> ci.name().toString()).toList();
infos.add(new AiServiceInfo(aiService.getServiceClassInfo().name().toString(), tools));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package io.quarkiverse.langchain4j.deployment.items;

import jakarta.enterprise.util.AnnotationLiteral;

import org.jboss.jandex.ClassInfo;

import io.quarkus.builder.item.MultiBuildItem;

/**
* Used to determine if a class containing a tool should be used along with a CDI qualifier
*/
public interface ToolQualifierProvider {

boolean supports(ClassInfo classInfo);

AnnotationLiteral<?> qualifier(ClassInfo classInfo);

final class BuildItem extends MultiBuildItem {

private final ToolQualifierProvider provider;

public BuildItem(ToolQualifierProvider provider) {
this.provider = provider;
}

public ToolQualifierProvider getProvider() {
return provider;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.function.Supplier;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.util.AnnotationLiteral;
import jakarta.enterprise.util.TypeLiteral;

import dev.langchain4j.data.segment.TextSegment;
Expand Down Expand Up @@ -148,12 +149,21 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
}
}

List<String> toolsClasses = info.toolsClassNames();
Map<String, AnnotationLiteral<?>> toolsClasses = info.toolsClassInfo();
if ((toolsClasses != null) && !toolsClasses.isEmpty()) {
List<Object> tools = new ArrayList<>(toolsClasses.size());
for (String toolClass : toolsClasses) {
Object tool = creationalContext.getInjectedReference(
Thread.currentThread().getContextClassLoader().loadClass(toolClass));
for (var entry : toolsClasses.entrySet()) {
AnnotationLiteral<?> qualifier = entry.getValue();
Object tool;
if (qualifier != null) {
tool = creationalContext.getInjectedReference(
Thread.currentThread().getContextClassLoader().loadClass(entry.getKey()),
qualifier);
} else {
tool = creationalContext.getInjectedReference(
Thread.currentThread().getContextClassLoader().loadClass(entry.getKey()));
}

tools.add(tool);
}
quarkusAiServices.tools(tools);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import java.util.List;
import java.util.Map;

import jakarta.enterprise.util.AnnotationLiteral;

public record DeclarativeAiServiceCreateInfo(
String serviceClassName,
String languageModelSupplierClassName,
String streamingChatLanguageModelSupplierClassName,
List<String> toolsClassNames,
Map<String, AnnotationLiteral<?>> toolsClassInfo,
String toolProviderSupplier,
String chatMemoryProviderSupplierClassName,
String retrieverClassName,
Expand Down
5 changes: 5 additions & 0 deletions model-providers/openai/openai-vanilla/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@
<artifactId>quarkus-smallrye-fault-tolerance</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-rest</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.smallrye.certs</groupId>
<artifactId>smallrye-certificate-generator-junit5</artifactId>
Expand Down
Loading

0 comments on commit 4fdab1d

Please sign in to comment.