Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tool type support #1047

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import dev.langchain4j.agent.tool.ToolMemoryId;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.model.output.structured.Description;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.prompt.Mappable;
import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker;
Expand Down Expand Up @@ -82,6 +83,16 @@ public class ToolProcessor {
Object.class);
private static final Logger log = Logger.getLogger(ToolProcessor.class);

public static final DotName OPTIONAL = DotName.createSimple("java.util.Optional");
public static final DotName OPTIONAL_INT = DotName.createSimple("java.util.OptionalInt");
public static final DotName OPTIONAL_LONG = DotName.createSimple("java.util.OptionalLong");
public static final DotName OPTIONAL_DOUBLE = DotName.createSimple("java.util.OptionalDouble");

private static final DotName DATE = DotName.createSimple("java.util.Date");
private static final DotName LOCAL_DATE = DotName.createSimple("java.time.LocalDate");
private static final DotName LOCAL_DATE_TIME = DotName.createSimple("java.time.LocalDateTime");
private static final DotName OFFSET_DATE_TIME = DotName.createSimple("java.time.OffsetDateTime");

@BuildStep
public void telemetry(Capabilities capabilities, BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER);
Expand Down Expand Up @@ -452,7 +463,15 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
|| DotNames.BIG_DECIMAL.equals(typeName)) {
return removeNulls(NUMBER, description);
}
if (LOCAL_DATE_TIME.equals(typeName) || OFFSET_DATE_TIME.equals(typeName)) {
return removeNulls(JsonSchemaProperty.from("type", "string"), JsonSchemaProperty.from("format", "date-time"),
description);
}

if (DATE.equals(typeName) || LOCAL_DATE.equals(typeName)) {
return removeNulls(JsonSchemaProperty.from("type", "string"), JsonSchemaProperty.from("format", "date"),
description);
}
// TODO something else?
if (type.kind() == Type.Kind.ARRAY || DotNames.LIST.equals(typeName) || DotNames.SET.equals(typeName)) {
ParameterizedType parameterizedType = type.kind() == Type.Kind.PARAMETERIZED_TYPE ? type.asParameterizedType()
Expand Down Expand Up @@ -487,17 +506,35 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
ClassInfo classInfo = index.getClassByName(type.name());

List<String> required = new ArrayList<>();

if (classInfo != null) {
for (FieldInfo field : classInfo.fields()) {
String fieldName = field.name();
Type fieldType = field.type();

boolean isOptional = isJavaOptionalType(fieldType);
if (isOptional) {
fieldType = unwrapOptionalType(fieldType);
}

Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(field.type(), index, null);
Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(fieldType, index, null);
Map<String, Object> fieldDescription = new HashMap<>();

for (JsonSchemaProperty fieldProperty : fieldSchema) {
fieldDescription.put(fieldProperty.key(), fieldProperty.value());
}

if (field.hasAnnotation(Description.class)) {
AnnotationInstance descriptionAnnotation = field.annotation(Description.class);
if (descriptionAnnotation != null && descriptionAnnotation.value() != null) {
String[] descriptionValue = descriptionAnnotation.value().asStringArray();
fieldDescription.put("description", String.join(",", descriptionValue));
}
}
if (!isOptional) {
required.add(fieldName);
}

properties.put(fieldName, fieldDescription);
}
}
Expand All @@ -509,10 +546,39 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
throw new IllegalArgumentException("Unsupported type: " + type);
}

private boolean isJavaOptionalType(Type type) {
DotName typeName = type.name();
return typeName.equals(DotName.createSimple("java.util.Optional"))
|| typeName.equals(DotName.createSimple("java.util.OptionalInt"))
|| typeName.equals(DotName.createSimple("java.util.OptionalLong"))
|| typeName.equals(DotName.createSimple("java.util.OptionalDouble"));
}

private Type unwrapOptionalType(Type optionalType) {
if (optionalType.kind() == Type.Kind.PARAMETERIZED_TYPE) {
ParameterizedType parameterizedType = optionalType.asParameterizedType();
return parameterizedType.arguments().get(0);
}
return optionalType;
}

private boolean isComplexType(Type type) {
return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE;
}

private boolean isOptionalField(FieldInfo field, IndexView index) {
Type fieldType = field.type();
DotName fieldTypeName = fieldType.name();

if (OPTIONAL.equals(fieldTypeName) || OPTIONAL_INT.equals(fieldTypeName) || OPTIONAL_LONG.equals(fieldTypeName)
|| OPTIONAL_DOUBLE.equals(fieldTypeName)) {
return true;
}

return false;

}

private Iterable<JsonSchemaProperty> removeNulls(JsonSchemaProperty... properties) {
return stream(properties)
.filter(Objects::nonNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand All @@ -15,6 +17,7 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.json.JsonReadFeature;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
Expand All @@ -27,11 +30,11 @@
public class QuarkusJsonCodecFactory implements JsonCodecFactory {

@Override
public Json.JsonCodec create() {
public Codec create() {
return new Codec();
}

private static class Codec implements Json.JsonCodec {
public static class Codec implements Json.JsonCodec {

private static final Pattern sanitizePattern = Pattern.compile("(?s)\\{.*\\}|\\[.*\\]");

Expand Down Expand Up @@ -60,6 +63,31 @@ public <T> T fromJson(String json, Class<T> type) {
}
}

public <T> T fromJson(String json, Type type) {
try {
String sanitizedJson = sanitize(json, type.getClass());
JavaType javaType = ObjectMapperHolder.MAPPER.getTypeFactory().constructType(type);
return ObjectMapperHolder.MAPPER.readValue(sanitizedJson, javaType);
} catch (JsonProcessingException e) {
if (e instanceof JsonParseException && isEnumType(type)) {
// this is the case where LangChain4j simply passes the string value of the enum to Json.fromJson()
// and Jackson does not handle it
if (type instanceof ParameterizedType) {
Class<? extends Enum> enumClass = (Class<? extends Enum>) ((ParameterizedType) type).getRawType();
return (T) Enum.valueOf(enumClass, json);
} else {

return (T) Enum.valueOf((Class<? extends Enum>) type, json);
}
}
throw new UncheckedIOException(e);
}
}

private boolean isEnumType(Type type) {
return type instanceof Class<?> && ((Class<?>) type).isEnum();
}

private <T> String sanitize(String original, Class<T> type) {
if (String.class.equals(type)) {
return original;
Expand Down
Loading
Loading