Skip to content

Commit

Permalink
Migrate to the JsonSchemaElement API
Browse files Browse the repository at this point in the history
Closes #1054
  • Loading branch information
edeandrea committed Nov 25, 2024
1 parent 2103f07 commit 23bf2a8
Show file tree
Hide file tree
Showing 14 changed files with 364 additions and 102 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonArraySchemaObjectSubstitution
implements ObjectSubstitution<JsonArraySchema, JsonArraySchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonArraySchema obj) {
return new Serialized(obj.description(), obj.items());
}

@Override
public JsonArraySchema deserialize(Serialized obj) {
return JsonArraySchema.builder()
.description(obj.description)
.items(obj.items)
.build();
}

public record Serialized(String description, JsonSchemaElement items) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonBooleanSchemaObjectSubstitution
implements ObjectSubstitution<JsonBooleanSchema, JsonBooleanSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonBooleanSchema obj) {
return new Serialized(obj.description());
}

@Override
public JsonBooleanSchema deserialize(Serialized obj) {
return JsonBooleanSchema.builder()
.description(obj.description)
.build();
}

public record Serialized(String description) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package io.quarkiverse.langchain4j.runtime.tool;

import java.util.List;

import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonEnumSchemaObjectSubstitution
implements ObjectSubstitution<JsonEnumSchema, JsonEnumSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonEnumSchema obj) {
return new Serialized(obj.description(), obj.enumValues());
}

@Override
public JsonEnumSchema deserialize(Serialized obj) {
return JsonEnumSchema.builder()
.description(obj.description)
.enumValues(obj.enumValues)
.build();
}

public record Serialized(String description, List<String> enumValues) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonIntegerSchemaObjectSubstitution
implements ObjectSubstitution<JsonIntegerSchema, JsonIntegerSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonIntegerSchema obj) {
return new Serialized(obj.description());
}

@Override
public JsonIntegerSchema deserialize(Serialized obj) {
return JsonIntegerSchema.builder()
.description(obj.description)
.build();
}

public record Serialized(String description) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonNumberSchemaObjectSubstitution
implements ObjectSubstitution<JsonNumberSchema, JsonNumberSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonNumberSchema obj) {
return new Serialized(obj.description());
}

@Override
public JsonNumberSchema deserialize(Serialized obj) {
return JsonNumberSchema.builder()
.description(obj.description)
.build();
}

public record Serialized(String description) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.quarkiverse.langchain4j.runtime.tool;

import java.util.List;
import java.util.Map;

import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonObjectSchemaObjectSubstitution
implements ObjectSubstitution<JsonObjectSchema, JsonObjectSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonObjectSchema obj) {
return new Serialized(obj.description(), obj.properties(), obj.required(), obj.additionalProperties(),
obj.definitions());
}

@Override
public JsonObjectSchema deserialize(Serialized obj) {
return JsonObjectSchema.builder()
.description(obj.description)
.properties(obj.properties)
.required(obj.required)
.additionalProperties(obj.additionalProperties)
.definitions(obj.definitions)
.build();
}

public record Serialized(String description, Map<String, JsonSchemaElement> properties, List<String> required,
Boolean additionalProperties, Map<String, JsonSchemaElement> definitions) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonReferenceSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public class JsonReferenceSchemaObjectSubstitution
implements ObjectSubstitution<JsonReferenceSchema, JsonReferenceSchemaObjectSubstitution.Serialized> {
public Serialized serialize(JsonReferenceSchema obj) {
return new Serialized(obj.reference());
}

public JsonReferenceSchema deserialize(Serialized obj) {
return JsonReferenceSchema.builder()
.reference(obj.reference)
.build();
}

public record Serialized(String reference) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.langchain4j.runtime.tool;

import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

public final class JsonStringSchemaObjectSubstitution
implements ObjectSubstitution<JsonStringSchema, JsonStringSchemaObjectSubstitution.Serialized> {
@Override
public Serialized serialize(JsonStringSchema obj) {
return new Serialized(obj.description());
}

@Override
public JsonStringSchema deserialize(Serialized obj) {
return JsonStringSchema.builder()
.description(obj.description)
.build();
}

public record Serialized(String description) {
@RecordableConstructor
public Serialized {
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@
import io.quarkus.runtime.ObjectSubstitution;
import io.quarkus.runtime.annotations.RecordableConstructor;

/**
* @deprecated
* @see JsonArraySchemaObjectSubstitution
* @see JsonBooleanSchemaObjectSubstitution
* @see JsonEnumSchemaObjectSubstitution
* @see JsonIntegerSchemaObjectSubstitution
* @see JsonNumberSchemaObjectSubstitution
* @see JsonObjectSchemaObjectSubstitution
* @see JsonReferenceSchemaObjectSubstitution
* @see JsonStringSchemaObjectSubstitution
*/
@Deprecated(forRemoval = true)
public class ToolParametersObjectSubstitution
implements ObjectSubstitution<ToolParameters, ToolParametersObjectSubstitution.Serialized> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

import org.jboss.resteasy.reactive.RestQuery;

import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.output.structured.Description;
import io.quarkiverse.langchain4j.RegisterAiService;

@Path("assistant-with-tool")
Expand All @@ -27,8 +29,12 @@ public AssistantWithToolsResource(Assistant assistant) {
this.assistant = assistant;
}

@Description("Some test data")
public static class TestData {
@Description("The foo field")
String foo;

@Description("The bar field")
Integer bar;
Double baz;

Expand All @@ -54,8 +60,8 @@ public interface Assistant {
public static class Calculator {

@Tool("Calculates the length of a string")
int stringLength(String s) {
return s.length();
int stringLength(@P(value = "The string to compute the length of", required = false) String s) {
return (s == null) ? 0 : s.length();
}

@Tool("Calculates the sum of two numbers")
Expand All @@ -80,7 +86,7 @@ public TestData evaluateTestObject(List<TestData> data) {
}

@Tool("Calculates all factors of the provided integer.")
List<Integer> getFactors(int x) {
List<Integer> getFactors(@P("The integer to get factor") int x) {
return java.util.stream.IntStream.rangeClosed(1, x)
.filter(i -> x % i == 0)
.boxed()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,7 @@ static Tool toTool(ToolSpecification toolSpecification) {
.name(toolSpecification.name())
.description(toolSpecification.description());

if (toolSpecification.toolParameters() != null) {
for (Map.Entry<String, Map<String, Object>> p : toolSpecification.toolParameters().properties().entrySet()) {
builder.addParameter(p.getKey(), p.getValue(),
toolSpecification.toolParameters().required().contains(p.getKey()));
}
} else if (toolSpecification.parameters() != null) {
if (toolSpecification.parameters() != null) {
for (Map.Entry<String, JsonSchemaElement> p : toolSpecification.parameters().properties().entrySet()) {
builder.addParameter(p.getKey(), JsonSchemaElementHelper.toMap(p.getValue()),
toolSpecification.parameters().required().contains(p.getKey()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.fasterxml.jackson.core.type.TypeReference;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
Expand Down Expand Up @@ -142,23 +141,12 @@ static List<Tool> toTools(Collection<ToolSpecification> toolSpecifications) {
}

private static Tool toTool(ToolSpecification toolSpecification) {
Tool.Function.Parameters functionParameters;
if (toolSpecification.toolParameters() != null) {
functionParameters = toFunctionParameters(toolSpecification.toolParameters());
} else {
functionParameters = toFunctionParameters(toolSpecification.parameters());
}
Tool.Function.Parameters functionParameters = toFunctionParameters(toolSpecification.parameters());

return new Tool(Tool.Type.FUNCTION, new Tool.Function(toolSpecification.name(), toolSpecification.description(),
functionParameters));
}

private static Tool.Function.Parameters toFunctionParameters(ToolParameters toolParameters) {
if (toolParameters == null) {
return Tool.Function.Parameters.empty();
}
return Tool.Function.Parameters.objectType(toolParameters.properties(), toolParameters.required());
}

private static Tool.Function.Parameters toFunctionParameters(JsonObjectSchema parameters) {
if (parameters == null) {
return Tool.Function.Parameters.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper;
import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageAssistant;
import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageSystem;
import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageTool;
Expand Down Expand Up @@ -174,7 +175,7 @@ public static TextChatMessageTool of(ToolExecutionResultMessage toolExecutionRes
/**
* Creates a {@link TextChatMessageTool}.
*
* @param message the content of the message tool.
* @param content the content of the message tool.
* @param toolCallId the unique identifier of the message tool.
* @return the created {@link TextChatMessageTool}.
*/
Expand Down Expand Up @@ -219,15 +220,16 @@ public record TextChatParameterFunction(String name, String description, Map<Str
/**
* Creates a {@link TextChatParameterTool} from a {@link ToolSpecification}.
*
* @param toolExecutionRequest the tool specification to convert
* @param toolSpecification the tool specification to convert
* @return the created {@link TextChatParameterTool}
*/
public static TextChatParameterTool of(ToolSpecification toolSpecification) {
// FIXME: toolSpecification.toolParameters() is deprecated, we might receive a value in parameters() instead
var toolParams = JsonSchemaElementHelper.toMap(toolSpecification.parameters());

var parameters = new TextChatParameterFunction(toolSpecification.name(), toolSpecification.description(), Map.of(
"type", toolSpecification.toolParameters().type(),
"properties", toolSpecification.toolParameters().properties(),
"required", toolSpecification.toolParameters().required()));
"type", toolParams.get("type"),
"properties", toolParams.get("properties"),
"required", toolParams.get("required")));
return new TextChatParameterTool("function", parameters);
}
}
Expand Down

0 comments on commit 23bf2a8

Please sign in to comment.