Skip to content

Commit

Permalink
Added support for mapping to complex types in AI service methods. Adde…
Browse files Browse the repository at this point in the history
…d data from @Description annotations to the JSON schema, so the LLM better understands how to output data.
  • Loading branch information
Tarjei400 committed Nov 4, 2024
1 parent 4c72522 commit bd88ad9
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import dev.langchain4j.agent.tool.ToolMemoryId;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.model.output.structured.Description;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.prompt.Mappable;
import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker;
Expand Down Expand Up @@ -498,6 +499,13 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
fieldDescription.put(fieldProperty.key(), fieldProperty.value());
}

if (field.hasAnnotation(Description.class)) {
AnnotationInstance descriptionAnnotation = field.annotation(Description.class);
if (descriptionAnnotation != null && descriptionAnnotation.value() != null) {
String[] descriptionValue = descriptionAnnotation.value().asStringArray();
fieldDescription.put("description", String.join(",", descriptionValue));
}
}
properties.put(fieldName, fieldDescription);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand All @@ -15,6 +17,7 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.json.JsonReadFeature;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
Expand All @@ -27,11 +30,11 @@
public class QuarkusJsonCodecFactory implements JsonCodecFactory {

@Override
public Json.JsonCodec create() {
public Codec create() {
return new Codec();
}

private static class Codec implements Json.JsonCodec {
public static class Codec implements Json.JsonCodec {

private static final Pattern sanitizePattern = Pattern.compile("(?s)\\{.*\\}|\\[.*\\]");

Expand Down Expand Up @@ -60,6 +63,26 @@ public <T> T fromJson(String json, Class<T> type) {
}
}

/**
 * Deserializes {@code json} into an arbitrary, possibly generic, target type.
 *
 * @param json the raw text to deserialize
 * @param type the target type; may be a plain class or a parameterized type such as {@code List<Foo>}
 * @param <T> the type the caller expects back
 * @return the deserialized value
 * @throws UncheckedIOException if Jackson cannot process the text and the enum fallback does not apply
 */
@SuppressWarnings({ "unchecked", "rawtypes" })
public <T> T fromJson(String json, Type type) {
    try {
        JavaType javaType = ObjectMapperHolder.MAPPER.getTypeFactory().constructType(type);
        // Sanitize against the raw class of the *target* type. type.getClass() would be the class of the
        // Type object itself (e.g. Class.class), which never matches the String shortcut in sanitize().
        String sanitizedJson = sanitize(json, javaType.getRawClass());
        return ObjectMapperHolder.MAPPER.readValue(sanitizedJson, javaType);
    } catch (JsonProcessingException e) {
        if (e instanceof JsonParseException && isEnumType(type)) {
            // this is the case where LangChain4j simply passes the string value of the enum to Json.fromJson()
            // and Jackson does not handle it.
            // isEnumType() only accepts Class instances, so the type is cast to Class directly;
            // casting to ParameterizedType here would always fail with a ClassCastException.
            Class<? extends Enum> enumClass = (Class<? extends Enum>) type;
            return (T) Enum.valueOf(enumClass, json);
        }
        throw new UncheckedIOException(e);
    }
}

/**
 * Tells whether {@code type} is a plain class that is an enum.
 *
 * @param type the type to inspect
 * @return {@code true} only when {@code type} is a {@link Class} whose {@code isEnum()} is true
 */
private boolean isEnumType(Type type) {
    if (!(type instanceof Class<?>)) {
        return false;
    }
    Class<?> candidate = (Class<?>) type;
    return candidate.isEnum();
}

private <T> String sanitize(String original, Class<T> type) {
if (String.class.equals(type)) {
return original;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,163 @@
package io.quarkiverse.langchain4j.runtime;

import static dev.langchain4j.service.TypeUtils.getRawClass;
import java.lang.reflect.*;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import java.lang.reflect.Type;
import com.fasterxml.jackson.databind.ObjectMapper;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.Result;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.TypeUtils;
//import dev.langchain4j.service.output.OutputParser;
import dev.langchain4j.service.output.ServiceOutputParser;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.smallrye.mutiny.Multi;

public class QuarkusServiceOutputParser extends ServiceOutputParser {
private static final Pattern JSON_BLOCK_PATTERN = Pattern.compile("(?s)\\{.*\\}|\\[.*\\]");

@Override
public String outputFormatInstructions(Type returnType) {
Class<?> rawClass = getRawClass(returnType);
if (Multi.class.equals(rawClass)) {
// when Multi is used as the return type, Multi<String> is the only supported type, thus we don't need want any formatting instructions
return "";

if (rawClass != String.class && rawClass != AiMessage.class && rawClass != TokenStream.class
&& rawClass != Response.class && !Multi.class.equals(rawClass)) {
try {
var schema = this.toJsonSchema(returnType);
return "You must answer strictly with json according to the following json schema format: " + schema;
} catch (Exception e) {
return "";
}
}

return "";
}

/**
 * Converts the model response into the value expected by the AI service method.
 *
 * @param response the response returned by the model
 * @param returnType the declared return type; a {@code Result} wrapper is unwrapped to its first type argument
 * @return the response itself, its {@link AiMessage}, its text, or the text deserialized as JSON,
 *         depending on the raw return class
 */
public Object parse(Response<AiMessage> response, Type returnType) {
    var codec = new QuarkusJsonCodecFactory().create();

    Type effectiveType = returnType;
    if (TypeUtils.typeHasRawClass(effectiveType, Result.class)) {
        effectiveType = TypeUtils.resolveFirstGenericParameterClass(effectiveType);
    }

    Class<?> rawReturnClass = TypeUtils.getRawClass(effectiveType);
    if (rawReturnClass == Response.class) {
        return response;
    }

    AiMessage aiMessage = response.content();
    if (rawReturnClass == AiMessage.class) {
        return aiMessage;
    }

    String text = aiMessage.text();
    if (rawReturnClass == String.class) {
        return text;
    }

    try {
        // First attempt: treat the whole answer as JSON.
        return codec.fromJson(text, effectiveType);
    } catch (Exception firstAttemptFailure) {
        // Second attempt: keep only the outermost {...} or [...] block of the answer.
        return codec.fromJson(this.extractJsonBlock(text), effectiveType);
    }
}

/**
 * Returns the first match of {@code JSON_BLOCK_PATTERN} in {@code text}, or {@code text} unchanged
 * when nothing matches.
 */
private String extractJsonBlock(String text) {
    Matcher jsonMatcher = JSON_BLOCK_PATTERN.matcher(text);
    if (jsonMatcher.find()) {
        return jsonMatcher.group();
    }
    return text;
}

/**
 * Renders a JSON schema describing {@code type} as a JSON string.
 * <p>
 * Strings and characters map to {@code string}, booleans to {@code boolean}, floating-point and decimal
 * numbers to {@code number}, every other number or primitive to {@code integer}, collections and arrays
 * to {@code array} (with an {@code items} schema), enums to a {@code string} with an {@code enum} list,
 * and anything else to an {@code object} whose properties are the declared instance fields of the class.
 * Fields annotated with {@link Description} get a {@code description} entry.
 * <p>
 * NOTE(review): nested types are resolved recursively without cycle detection, so a self-referential
 * class will recurse until the stack overflows.
 *
 * @param type the type to describe
 * @return the schema serialized as JSON
 * @throws Exception if the type is unsupported or the schema cannot be serialized
 */
public String toJsonSchema(Type type) throws Exception {
    if (type instanceof WildcardType wildcardType) {
        // A wildcard is described by its bound.
        Type boundType = wildcardType.getUpperBounds().length > 0 ? wildcardType.getUpperBounds()[0]
                : wildcardType.getLowerBounds()[0];
        return toJsonSchema(boundType);
    }

    Map<String, Object> schema = new LinkedHashMap<>();
    Class<?> rawClass = getRawClass(type);

    if (rawClass == String.class || rawClass == Character.class || rawClass == char.class) {
        schema.put("type", "string");
    } else if (rawClass == Boolean.class || rawClass == boolean.class) {
        schema.put("type", "boolean");
    } else if (Number.class.isAssignableFrom(rawClass) || rawClass.isPrimitive()) {
        schema.put("type", isFloatingPointType(rawClass) ? "number" : "integer");
    } else if (Collection.class.isAssignableFrom(rawClass) || rawClass.isArray()) {
        schema.put("type", "array");

        Type elementType = getElementType(type);
        Map<String, Object> itemsSchema = toJsonSchemaMap(elementType);
        schema.put("items", itemsSchema);
    } else if (rawClass.isEnum()) {
        schema.put("type", "string");
        schema.put("enum", getEnumConstants(rawClass));
    } else {
        schema.put("type", "object");
        // Insertion-ordered so properties appear in the order the fields are returned.
        Map<String, Object> properties = new LinkedHashMap<>();

        for (Field field : rawClass.getDeclaredFields()) {
            // Static and compiler-generated fields are not part of the serialized object.
            if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            // No setAccessible(): only metadata (generic type, annotations) is read, never the value.
            Map<String, Object> fieldSchema = toJsonSchemaMap(field.getGenericType());
            Description description = field.getAnnotation(Description.class);
            if (description != null) {
                // The annotation value is a String[]; a schema description must be a single string.
                fieldSchema.put("description", String.join(",", description.value()));
            }
            properties.put(field.getName(), fieldSchema);
        }
        schema.put("properties", properties);
    }

    ObjectMapper mapper = new ObjectMapper();
    return mapper.writeValueAsString(schema); // Convert the schema map to a JSON string
}

/** Tells whether {@code rawClass} holds fractional values and must be described as a JSON "number". */
private static boolean isFloatingPointType(Class<?> rawClass) {
    return rawClass == double.class || rawClass == float.class
            || rawClass == Double.class || rawClass == Float.class
            || rawClass == java.math.BigDecimal.class;
}

/**
 * Resolves the erased class behind a reflective {@link Type}.
 *
 * @param type a class, parameterized type, generic array type or wildcard
 * @return the raw class; for generic arrays the matching array class, for wildcards the raw class of
 *         the first upper bound (or first lower bound when there is no upper bound)
 * @throws IllegalArgumentException for any other kind of type
 */
private Class<?> getRawClass(Type type) {
    if (type instanceof Class<?> plainClass) {
        return plainClass;
    }
    if (type instanceof ParameterizedType parameterized) {
        return (Class<?>) parameterized.getRawType();
    }
    if (type instanceof GenericArrayType genericArray) {
        // Build a zero-length array of the component's raw class just to obtain its array class.
        Class<?> componentClass = getRawClass(genericArray.getGenericComponentType());
        return Array.newInstance(componentClass, 0).getClass();
    }
    if (type instanceof WildcardType wildcard) {
        Type[] upperBounds = wildcard.getUpperBounds();
        Type bound = upperBounds.length > 0 ? upperBounds[0] : wildcard.getLowerBounds()[0];
        return getRawClass(bound);
    }
    throw new IllegalArgumentException("Unsupported type: " + type);
}

/**
 * Determines the element type of a collection-like or array type.
 *
 * @param type the container type
 * @return the first type argument of a parameterized type, the component type of an array,
 *         or {@code Object.class} when neither applies
 */
private Type getElementType(Type type) {
    if (type instanceof ParameterizedType parameterized) {
        return parameterized.getActualTypeArguments()[0];
    }
    if (type instanceof GenericArrayType genericArray) {
        return genericArray.getGenericComponentType();
    }
    if (type instanceof Class<?> plainClass && plainClass.isArray()) {
        return plainClass.getComponentType();
    }
    // Element type cannot be determined (e.g. a raw collection class).
    return Object.class;
}

/**
 * Produces the schema of {@code type} as a map by rendering it with {@code toJsonSchema} and
 * parsing the resulting JSON back.
 *
 * @param type the type to describe
 * @return the schema as a mutable map structure
 * @throws Exception if the schema cannot be generated or parsed
 */
@SuppressWarnings("unchecked")
private Map<String, Object> toJsonSchemaMap(Type type) throws Exception {
    return new ObjectMapper().readValue(toJsonSchema(type), Map.class);
}

/**
 * Lists the constant names of an enum class, in declaration order.
 * <p>
 * {@code name()} is used rather than {@code toString()} because the values are parsed back with
 * {@code Enum.valueOf}, which only accepts the declared constant name.
 *
 * @param enumClass an enum class
 * @return the names of its constants
 */
private List<String> getEnumConstants(Class<?> enumClass) {
    List<String> constants = new ArrayList<>();
    for (Object constant : enumClass.getEnumConstants()) {
        constants.add(((Enum<?>) constant).name());
    }
    return constants;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.output.structured.Description;
import io.quarkiverse.langchain4j.RegisterAiService;

@Path("assistant-with-tool")
Expand All @@ -28,8 +29,13 @@ public AssistantWithToolsResource(Assistant assistant) {
}

public static class TestData {
@Description("Foo description for structured output")
String foo;

@Description("Foo description for structured output")
Integer bar;

@Description("Foo description for structured output")
Double baz;

TestData(String foo, Integer bar, Double baz) {
Expand All @@ -44,10 +50,18 @@ public String get(@RestQuery String message) {
return assistant.chat(message);
}

/**
 * Forwards {@code message} to {@code assistant.chats} and returns the resulting list.
 */
@GET
@Path("/many")
public List<TestData> getMany(@RestQuery String message) {
    List<TestData> answers = assistant.chats(message);
    return answers;
}

/**
 * AI service registered with the {@code Calculator} tool and a bean-provided chat memory supplier.
 */
@RegisterAiService(
        tools = Calculator.class,
        chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class)
public interface Assistant {

    /** Sends {@code userMessage} and returns the answer as a string. */
    String chat(String userMessage);

    /** Sends {@code userMessage} and returns the answer as a list of {@code TestData}. */
    List<TestData> chats(String userMessage);
}

@Singleton
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.acme.example.openai.aiservices;

import java.util.ArrayList;
import java.util.List;

import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;

import org.jboss.resteasy.reactive.RestQuery;

import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;

@Path("collection-entity-mapping")
/**
 * REST resource that exposes an AI service whose methods return collections,
 * both of plain strings and of mapped {@link TestData} entities.
 */
public class EntityMappedResource {

    private final EntityMappedDescriber describer;

    public EntityMappedResource(EntityMappedDescriber describer) {
        this.describer = describer;
    }

    /**
     * Sample entity used as a structured input and output type.
     * NOTE(review): all three fields carry the same description text; looks like a placeholder, confirm.
     */
    public static class TestData {
        @Description("Foo description for structured output")
        String foo;

        @Description("Foo description for structured output")
        Integer bar;

        @Description("Foo description for structured output")
        Double baz;

        TestData(String foo, Integer bar, Double baz) {
            this.foo = foo;
            this.bar = bar;
            this.baz = baz;
        }
    }

    /** Passes {@code message} to the describer and returns its list of strings. */
    @POST
    public List<String> generate(@RestQuery String message) {
        return describer.describe(message);
    }

    /** Wraps {@code message} in a single {@link TestData} entry and passes the list to the describer. */
    @POST
    @Path("generateMapped")
    public List<TestData> generateMapped(@RestQuery String message) {
        TestData seed = new TestData(message, 100, 100.0);
        List<TestData> inputs = new ArrayList<>();
        inputs.add(seed);
        return describer.describeMapped(inputs);
    }

    /** AI service declaring one method per collection return shape. */
    @RegisterAiService
    public interface EntityMappedDescriber {

        @UserMessage("This is a describer returning a collection of strings")
        List<String> describe(String url);

        @UserMessage("This is a describer returning a collection of mapped entities")
        List<TestData> describeMapped(List<TestData> inputs);
    }
}
Loading

0 comments on commit bd88ad9

Please sign in to comment.