From 0c407f7c94188639ec392657bbe1d4eb33e427fc Mon Sep 17 00:00:00 2001 From: Eric Deandrea Date: Fri, 22 Nov 2024 16:43:50 -0500 Subject: [PATCH] Migrate to the JsonSchemaElement API Closes #1054 --- .../langchain4j/jlama/JlamaModel.java | 7 +------ .../langchain4j/ollama/MessageMapper.java | 16 ++-------------- .../watsonx/bean/TextChatMessage.java | 14 ++++++++------ 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/JlamaModel.java b/model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/JlamaModel.java index e2a71b74d..e195edf3e 100644 --- a/model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/JlamaModel.java +++ b/model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/JlamaModel.java @@ -127,12 +127,7 @@ static Tool toTool(ToolSpecification toolSpecification) { .name(toolSpecification.name()) .description(toolSpecification.description()); - if (toolSpecification.toolParameters() != null) { - for (Map.Entry> p : toolSpecification.toolParameters().properties().entrySet()) { - builder.addParameter(p.getKey(), p.getValue(), - toolSpecification.toolParameters().required().contains(p.getKey())); - } - } else if (toolSpecification.parameters() != null) { + if (toolSpecification.parameters() != null) { for (Map.Entry p : toolSpecification.parameters().properties().entrySet()) { builder.addParameter(p.getKey(), JsonSchemaElementHelper.toMap(p.getValue()), toolSpecification.parameters().required().contains(p.getKey())); diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/MessageMapper.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/MessageMapper.java index aea73737c..a0b017a9e 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/MessageMapper.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/MessageMapper.java @@ -16,7 +16,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; @@ -142,23 +141,12 @@ static List toTools(Collection toolSpecifications) { } private static Tool toTool(ToolSpecification toolSpecification) { - Tool.Function.Parameters functionParameters; - if (toolSpecification.toolParameters() != null) { - functionParameters = toFunctionParameters(toolSpecification.toolParameters()); - } else { - functionParameters = toFunctionParameters(toolSpecification.parameters()); - } + Tool.Function.Parameters functionParameters = toFunctionParameters(toolSpecification.parameters()); + return new Tool(Tool.Type.FUNCTION, new Tool.Function(toolSpecification.name(), toolSpecification.description(), functionParameters)); } - private static Tool.Function.Parameters toFunctionParameters(ToolParameters toolParameters) { - if (toolParameters == null) { - return Tool.Function.Parameters.empty(); - } - return Tool.Function.Parameters.objectType(toolParameters.properties(), toolParameters.required()); - } - private static Tool.Function.Parameters toFunctionParameters(JsonObjectSchema parameters) { if (parameters == null) { return Tool.Function.Parameters.empty(); diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java index 55e365f1f..091d86f36 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java @@ -16,6 +16,7 @@ import dev.langchain4j.data.message.TextContent; import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper; import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageAssistant; import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageSystem; import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageTool; @@ -174,7 +175,7 @@ public static TextChatMessageTool of(ToolExecutionResultMessage toolExecutionRes /** * Creates a {@link TextChatMessageTool}. * - * @param message the content of the message tool. + * @param content the content of the message tool. * @param toolCallId the unique identifier of the message tool. * @return the created {@link TextChatMessageTool}. */ @@ -219,15 +220,16 @@ public record TextChatParameterFunction(String name, String description, Map