Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Rest Client and AI Service to be used as tools #1157

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@
<optional>true</optional> <!-- conditional dependency -->
</dependency>

<dependency>
<groupId>org.eclipse.microprofile.rest.client</groupId>
<artifactId>microprofile-rest-client-api</artifactId>
</dependency>

<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-vertx-http-dev-ui-tests</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
import jakarta.annotation.PreDestroy;
import jakarta.enterprise.context.Dependent;
import jakarta.enterprise.inject.spi.DeploymentException;
import jakarta.enterprise.util.AnnotationLiteral;
import jakarta.inject.Inject;

import org.eclipse.microprofile.rest.client.inject.RestClient;
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.AnnotationValue;
Expand Down Expand Up @@ -77,6 +79,7 @@
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ToolQualifierProvider;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
Expand Down Expand Up @@ -262,11 +265,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
chatModelNames.add(chatModelName);
}

List<DotName> toolDotNames = Collections.emptyList();
List<ClassInfo> toolClassInfos = Collections.emptyList();
AnnotationValue toolsInstance = instance.value("tools");
if (toolsInstance != null) {
toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name)
.collect(Collectors.toList());
toolClassInfos = Arrays.stream(toolsInstance.asClassArray()).map(t -> {
var ci = index.getClassByName(t.name());
if (ci == null) {
throw new IllegalArgumentException("Cannot find class " + t.name()
+ " in index. Please make sure it's a valid CDI bean known to Quarkus");
}
return ci;
})
.toList();
}

// the default value depends on whether tools exists or not - if they do, then we require a ChatMemoryProvider bean
Expand Down Expand Up @@ -397,7 +407,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
declarativeAiServiceClassInfo,
chatLanguageModelSupplierClassDotName,
streamingChatLanguageModelSupplierClassDotName,
toolDotNames,
toolClassInfos,
chatMemoryProviderSupplierClassDotName,
retrieverClassDotName,
retrievalAugmentorSupplierClassName,
Expand Down Expand Up @@ -476,11 +486,27 @@ private boolean isImageOrImageResultResult(Type returnType) {
return false;
}

@BuildStep
public void toolQualifiers(BuildProducer<ToolQualifierProvider.BuildItem> producer) {
producer.produce(new ToolQualifierProvider.BuildItem(new ToolQualifierProvider() {
@Override
public boolean supports(ClassInfo classInfo) {
return classInfo.hasAnnotation(DotNames.REGISTER_REST_CLIENT);
}

@Override
public AnnotationLiteral<?> qualifier(ClassInfo classInfo) {
return new RestClient.RestClientLiteral();
}
}));
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public void handleDeclarativeServices(AiServicesRecorder recorder,
List<DeclarativeAiServiceBuildItem> declarativeAiServiceItems,
List<SelectedChatModelProviderBuildItem> selectedChatModelProvider,
List<ToolQualifierProvider.BuildItem> toolQualifierProviderItems,
BuildProducer<SyntheticBeanBuildItem> syntheticBeanProducer,
BuildProducer<UnremovableBeanBuildItem> unremovableProducer) {

Expand All @@ -507,7 +533,19 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
? bi.getStreamingChatLanguageModelSupplierClassDotName().toString()
: null);

List<String> toolClassNames = bi.getToolDotNames().stream().map(DotName::toString).collect(Collectors.toList());
List<ToolQualifierProvider> toolQualifierProviders = toolQualifierProviderItems.stream().map(
ToolQualifierProvider.BuildItem::getProvider).toList();
Map<String, AnnotationLiteral<?>> toolToQualifierMap = new HashMap<>();
for (ClassInfo ci : bi.getToolClassInfos()) {
AnnotationLiteral<?> qualifier = null;
for (ToolQualifierProvider provider : toolQualifierProviders) {
if (provider.supports(ci)) {
qualifier = provider.qualifier(ci);
break;
}
}
toolToQualifierMap.put(ci.name().toString(), qualifier);
}

String toolProviderSupplierClassName = (bi.getToolProviderClassDotName() != null
? bi.getToolProviderClassDotName().toString()
Expand Down Expand Up @@ -597,7 +635,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
serviceClassName,
chatLanguageModelSupplierClassName,
streamingChatLanguageModelSupplierClassName,
toolClassNames,
toolToQualifierMap,
toolProviderSupplierClassName,
chatMemoryProviderSupplierClassName, retrieverClassName,
retrievalAugmentorSupplierClassName,
Expand Down Expand Up @@ -639,12 +677,16 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsChatModelBean = true;
}

if (!toolClassNames.isEmpty()) {
for (String toolClassName : toolClassNames) {
DotName dotName = DotName.createSimple(toolClassName);
for (var entry : toolToQualifierMap.entrySet()) {
DotName dotName = DotName.createSimple(entry.getKey());
AnnotationLiteral<?> qualifier = entry.getValue();
if (qualifier == null) {
configurator.addInjectionPoint(ClassType.create(dotName));
allToolNames.add(dotName);
} else {
configurator.addInjectionPoint(ClassType.create(dotName),
AnnotationInstance.builder(qualifier.annotationType()).build());
}
allToolNames.add(dotName);
}

if (LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final ClassInfo serviceClassInfo;
private final DotName chatLanguageModelSupplierClassDotName;
private final DotName streamingChatLanguageModelSupplierClassDotName;
private final List<DotName> toolDotNames;
private final List<ClassInfo> toolClassInfos;
private final DotName toolProviderClassDotName;

private final DotName chatMemoryProviderSupplierClassDotName;
Expand All @@ -37,7 +37,7 @@ public DeclarativeAiServiceBuildItem(
ClassInfo serviceClassInfo,
DotName chatLanguageModelSupplierClassDotName,
DotName streamingChatLanguageModelSupplierClassDotName,
List<DotName> toolDotNames,
List<ClassInfo> toolClassInfos,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverClassDotName,
DotName retrievalAugmentorSupplierClassDotName,
Expand All @@ -55,7 +55,7 @@ public DeclarativeAiServiceBuildItem(
this.serviceClassInfo = serviceClassInfo;
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.toolClassInfos = toolClassInfos;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverClassDotName = retrieverClassDotName;
this.retrievalAugmentorSupplierClassDotName = retrievalAugmentorSupplierClassDotName;
Expand Down Expand Up @@ -84,8 +84,8 @@ public DotName getStreamingChatLanguageModelSupplierClassDotName() {
return streamingChatLanguageModelSupplierClassDotName;
}

public List<DotName> getToolDotNames() {
return toolDotNames;
public List<ClassInfo> getToolClassInfos() {
return toolClassInfos;
}

public DotName getChatMemoryProviderSupplierClassDotName() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import jakarta.enterprise.inject.Instance;

import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
import org.jboss.jandex.DotName;

import dev.langchain4j.agent.tool.Tool;
Expand Down Expand Up @@ -62,6 +63,8 @@ public class DotNames {
public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class);
public static final DotName TOOL = DotName.createSimple(Tool.class);

public static final DotName REGISTER_REST_CLIENT = DotName.createSimple(RegisterRestClient.class);

public static final DotName OUTPUT_GUARDRAIL_ACCUMULATOR = DotName.createSimple(OutputGuardrailAccumulator.class);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public class ToolProcessor {
private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class);
public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, Object.class,
Object.class);
private static final ResultHandle[] EMPTY_RESULT_HANDLE_ARRAY = new ResultHandle[0];

private static final Logger log = Logger.getLogger(ToolProcessor.class);

Expand Down Expand Up @@ -136,7 +137,19 @@ public void handleTools(

MethodInfo methodInfo = instance.target().asMethod();
ClassInfo classInfo = methodInfo.declaringClass();
if (classInfo.isInterface() || Modifier.isAbstract(classInfo.flags())) {
boolean causeValidationError = false;
if (classInfo.isInterface()) {

if (classInfo.hasAnnotation(LangChain4jDotNames.REGISTER_AI_SERVICES) || classInfo.hasAnnotation(
DotNames.REGISTER_REST_CLIENT)) {
// we allow tools on method of these interfaces because we know they will be beans
} else {
causeValidationError = true;
}
} else if (Modifier.isAbstract(classInfo.flags())) {
causeValidationError = true;
}
if (causeValidationError) {
validation.produce(
new ValidationPhaseBuildItem.ValidationErrorBuildItem(new IllegalStateException(
"@Tool is only supported on non-abstract classes, all other usages are ignored. Offending method is '"
Expand Down Expand Up @@ -409,16 +422,21 @@ private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOu
MethodDescriptor.ofMethod(implClassName, "invoke", Object.class, Object.class, Object[].class));

ResultHandle result;
ResultHandle[] targetMethodHandles = EMPTY_RESULT_HANDLE_ARRAY;
if (methodInfo.parametersCount() > 0) {
List<ResultHandle> argumentHandles = new ArrayList<>(methodInfo.parametersCount());
for (int i = 0; i < methodInfo.parametersCount(); i++) {
argumentHandles.add(invokeMc.readArrayValue(invokeMc.getMethodParam(1), i));
}
ResultHandle[] targetMethodHandles = argumentHandles.toArray(new ResultHandle[0]);
result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0),
targetMethodHandles = argumentHandles.toArray(EMPTY_RESULT_HANDLE_ARRAY);
}

if (methodInfo.declaringClass().isInterface()) {
result = invokeMc.invokeInterfaceMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0),
targetMethodHandles);
} else {
result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0));
result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0),
targetMethodHandles);
}

boolean toolReturnsVoid = methodInfo.returnType().kind() == Type.Kind.VOID;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ private void addEmbeddingStorePage(CardPageBuildItem card) {
private void addAiServicesPage(CardPageBuildItem card, List<DeclarativeAiServiceBuildItem> aiServices) {
List<AiServiceInfo> infos = new ArrayList<>();
for (DeclarativeAiServiceBuildItem aiService : aiServices) {
List<String> tools = aiService.getToolDotNames().stream().map(dotName -> dotName.toString()).toList();
List<String> tools = aiService.getToolClassInfos().stream().map(ci -> ci.name().toString()).toList();
infos.add(new AiServiceInfo(aiService.getServiceClassInfo().name().toString(), tools));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package io.quarkiverse.langchain4j.deployment.items;

import jakarta.enterprise.util.AnnotationLiteral;

import org.jboss.jandex.ClassInfo;

import io.quarkus.builder.item.MultiBuildItem;

/**
* Used to determine if a class containing a tool should be used along with a CDI qualifier
*/
public interface ToolQualifierProvider {

boolean supports(ClassInfo classInfo);

AnnotationLiteral<?> qualifier(ClassInfo classInfo);

final class BuildItem extends MultiBuildItem {

private final ToolQualifierProvider provider;

public BuildItem(ToolQualifierProvider provider) {
this.provider = provider;
}

public ToolQualifierProvider getProvider() {
return provider;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.function.Supplier;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.util.AnnotationLiteral;
import jakarta.enterprise.util.TypeLiteral;

import dev.langchain4j.data.segment.TextSegment;
Expand Down Expand Up @@ -148,12 +149,21 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
}
}

List<String> toolsClasses = info.toolsClassNames();
Map<String, AnnotationLiteral<?>> toolsClasses = info.toolsClassInfo();
if ((toolsClasses != null) && !toolsClasses.isEmpty()) {
List<Object> tools = new ArrayList<>(toolsClasses.size());
for (String toolClass : toolsClasses) {
Object tool = creationalContext.getInjectedReference(
Thread.currentThread().getContextClassLoader().loadClass(toolClass));
for (var entry : toolsClasses.entrySet()) {
AnnotationLiteral<?> qualifier = entry.getValue();
Object tool;
if (qualifier != null) {
tool = creationalContext.getInjectedReference(
Thread.currentThread().getContextClassLoader().loadClass(entry.getKey()),
qualifier);
} else {
tool = creationalContext.getInjectedReference(
Thread.currentThread().getContextClassLoader().loadClass(entry.getKey()));
}

tools.add(tool);
}
quarkusAiServices.tools(tools);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import java.util.List;
import java.util.Map;

import jakarta.enterprise.util.AnnotationLiteral;

public record DeclarativeAiServiceCreateInfo(
String serviceClassName,
String languageModelSupplierClassName,
String streamingChatLanguageModelSupplierClassName,
List<String> toolsClassNames,
Map<String, AnnotationLiteral<?>> toolsClassInfo,
String toolProviderSupplier,
String chatMemoryProviderSupplierClassName,
String retrieverClassName,
Expand Down
5 changes: 5 additions & 0 deletions model-providers/openai/openai-vanilla/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@
<artifactId>quarkus-smallrye-fault-tolerance</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-rest</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.smallrye.certs</groupId>
<artifactId>smallrye-certificate-generator-junit5</artifactId>
Expand Down
Loading
Loading