Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tools): Basic tool support for OpenAI models #1447

Merged
merged 11 commits into from
Oct 4, 2024
129 changes: 122 additions & 7 deletions src/lib/server/endpoints/openai/endpointOai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,84 @@ import { z } from "zod";
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
import type {
ChatCompletionCreateParamsStreaming,
ChatCompletionTool,
} from "openai/resources/chat/completions";
import type { FunctionDefinition, FunctionParameters } from "openai/resources/shared";
import { buildPrompt } from "$lib/buildPrompt";
import { env } from "$env/dynamic/private";
import type { Endpoint } from "../endpoints";
import type OpenAI from "openai";
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
import type { MessageFile } from "$lib/types/Message";
import { type Tool } from "$lib/types/Tool";
import type { EndpointMessage } from "../endpoints";
import { v4 as uuidv4 } from "uuid";
/**
 * Converts chat-ui `Tool` definitions into the OpenAI `tools` array format
 * (JSON-schema function definitions) accepted by the chat completions API.
 *
 * @param tools - tools enabled for the request; `undefined` yields an empty array
 * @returns one `ChatCompletionTool` (type "function") per input tool
 * @throws Error if a tool input uses the unsupported `file` type or an unknown type
 */
function createChatCompletionToolsArray(tools: Tool[] | undefined): ChatCompletionTool[] {
	const toolChoices: ChatCompletionTool[] = [];
	if (tools === undefined) {
		return toolChoices;
	}

	for (const t of tools) {
		const requiredProperties: string[] = [];

		const properties: Record<string, unknown> = {};
		for (const parameterDefinition of t.inputs) {
			const parameter: Record<string, unknown> = {};
			switch (parameterDefinition.type) {
				case "str":
					parameter.type = "string";
					break;
				case "float":
				case "int":
					// JSON schema has a single numeric type covering both
					parameter.type = "number";
					break;
				case "bool":
					parameter.type = "boolean";
					break;
				case "file":
					throw new Error("File type is currently not supported");
				default:
					// Report the offending type value, not the whole tool object
					// (interpolating the object would print "[object Object]").
					throw new Error(`Unknown tool IO type: ${parameterDefinition.type}`);
			}

			if ("description" in parameterDefinition) {
				parameter.description = parameterDefinition.description;
			}

			if (parameterDefinition.paramType === "required") {
				requiredProperties.push(parameterDefinition.name);
			}

			properties[parameterDefinition.name] = parameter;
		}

		const functionParameters: FunctionParameters = {
			type: "object",
			// OpenAI rejects an empty "required" array, so only add it when populated
			...(requiredProperties.length > 0 ? { required: requiredProperties } : {}),
			properties,
		};

		const functionDefinition: FunctionDefinition = {
			name: t.name,
			description: t.description,
			parameters: functionParameters,
		};

		const toolDefinition: ChatCompletionTool = {
			type: "function",
			function: functionDefinition,
		};

		toolChoices.push(toolDefinition);
	}

	return toolChoices;
}

export const endpointOAIParametersSchema = z.object({
weight: z.number().int().positive().default(1),
Expand Down Expand Up @@ -57,7 +127,6 @@ export async function endpointOai(
extraBody,
} = endpointOAIParametersSchema.parse(input);

/* eslint-disable-next-line no-shadow */
let OpenAI;
try {
OpenAI = (await import("openai")).OpenAI;
Expand All @@ -75,6 +144,11 @@ export async function endpointOai(
const imageProcessor = makeImageProcessor(multimodal.image);

if (completion === "completions") {
if (model.tools) {
throw new Error(
"Tools are not supported for 'completions' mode, switch to 'chat_completions' instead"
);
}
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
const prompt = await buildPrompt({
messages,
Expand Down Expand Up @@ -102,9 +176,9 @@ export async function endpointOai(
return openAICompletionToTextGenerationStream(openAICompletion);
};
} else if (completion === "chat_completions") {
return async ({ messages, preprompt, generateSettings }) => {
return async ({ messages, preprompt, generateSettings, tools, toolResults }) => {
let messagesOpenAI: OpenAI.Chat.Completions.ChatCompletionMessageParam[] =
await prepareMessages(messages, imageProcessor);
await prepareMessages(messages, imageProcessor, !model.tools && model.multimodal);

if (messagesOpenAI?.[0]?.role !== "system") {
messagesOpenAI = [{ role: "system", content: "" }, ...messagesOpenAI];
Expand All @@ -114,7 +188,44 @@ export async function endpointOai(
messagesOpenAI[0].content = preprompt ?? "";
}

if (toolResults && toolResults.length > 0) {
const toolCallRequests: OpenAI.Chat.Completions.ChatCompletionAssistantMessageParam = {
role: "assistant",
content: null,
tool_calls: [],
};

const responses: Array<OpenAI.Chat.Completions.ChatCompletionToolMessageParam> = [];

for (const result of toolResults) {
const id = uuidv4();

const toolCallResult: OpenAI.Chat.Completions.ChatCompletionMessageToolCall = {
type: "function",
function: {
name: result.call.name,
arguments: JSON.stringify(result.call.parameters),
},
id,
};

toolCallRequests.tool_calls?.push(toolCallResult);
const toolCallResponse: OpenAI.Chat.Completions.ChatCompletionToolMessageParam = {
role: "tool",
content: "",
tool_call_id: id,
};
if ("outputs" in result) {
toolCallResponse.content = JSON.stringify(result.outputs);
}
responses.push(toolCallResponse);
}
messagesOpenAI.push(toolCallRequests);
messagesOpenAI.push(...responses);
}

const parameters = { ...model.parameters, ...generateSettings };
const toolCallChoices = createChatCompletionToolsArray(tools);
const body: ChatCompletionCreateParamsStreaming = {
model: model.id ?? model.name,
messages: messagesOpenAI,
Expand All @@ -124,6 +235,7 @@ export async function endpointOai(
temperature: parameters?.temperature,
top_p: parameters?.top_p,
frequency_penalty: parameters?.repetition_penalty,
...(toolCallChoices.length > 0 ? { tools: toolCallChoices, tool_choice: "auto" } : {}),
};

const openChatAICompletion = await openai.chat.completions.create(body, {
Expand All @@ -139,11 +251,12 @@ export async function endpointOai(

async function prepareMessages(
messages: EndpointMessage[],
imageProcessor: ReturnType<typeof makeImageProcessor>
imageProcessor: ReturnType<typeof makeImageProcessor>,
isMultimodal: boolean
): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam[]> {
return Promise.all(
messages.map(async (message) => {
if (message.from === "user") {
if (message.from === "user" && isMultimodal) {
return {
role: message.from,
content: [
Expand All @@ -164,7 +277,9 @@ async function prepareFiles(
imageProcessor: ReturnType<typeof makeImageProcessor>,
files: MessageFile[]
): Promise<OpenAI.Chat.Completions.ChatCompletionContentPartImage[]> {
const processedFiles = await Promise.all(files.map(imageProcessor));
const processedFiles = await Promise.all(
files.filter((file) => file.mime.startsWith("image/")).map(imageProcessor)
);
return processedFiles.map((file) => ({
type: "image_url" as const,
image_url: {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,44 @@
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type OpenAI from "openai";
import type { Stream } from "openai/streaming";
import type { ToolCall } from "$lib/types/Tool";

// Pairs a tool call with the incrementally-streamed JSON string of its arguments.
type ToolCallWithParameters = {
	toolCall: ToolCall;
	parameterJsonString: string;
};

/**
 * Finalizes the tool calls accumulated while streaming: parses each call's
 * buffered JSON argument string into `toolCall.parameters` and wraps the calls
 * in a synthetic text-generation-stream token (empty text, calls attached).
 *
 * @param toolCallsWithParameters - calls collected from the delta stream
 * @param tokenId - id assigned to the emitted token
 * @returns an output object carrying the parsed tool calls and no generated text
 * @throws SyntaxError if an argument buffer is not valid JSON after newline stripping
 */
function prepareToolCalls(toolCallsWithParameters: ToolCallWithParameters[], tokenId: number) {
	const toolCalls: ToolCall[] = [];

	for (const { toolCall, parameterJsonString } of toolCallsWithParameters) {
		// HACK: sometimes gpt4 via azure returns the JSON with literal newlines in it
		// like {\n "foo": "bar" }. Strip ALL newlines with a global regex — the
		// previous single-occurrence replace("\n", "") only removed the first one,
		// leaving any later unescaped newline to break JSON.parse.
		const params = JSON.parse(parameterJsonString.replace(/\n/g, ""));

		// JSON.parse yields a plain object, so copying own enumerable keys is
		// equivalent to the previous for..in loop.
		Object.assign(toolCall.parameters, params);

		toolCalls.push(toolCall);
	}

	return {
		token: {
			id: tokenId,
			text: "",
			logprob: 0,
			special: false,
			toolCalls,
		},
		generated_text: null,
		details: null,
	};
}

/**
* Transform a stream of OpenAI.Chat.ChatCompletion into a stream of TextGenerationStreamOutput
Expand All @@ -10,6 +48,7 @@ export async function* openAIChatToTextGenerationStream(
) {
let generatedText = "";
let tokenId = 0;
const toolCalls: ToolCallWithParameters[] = [];
for await (const completion of completionStream) {
const { choices } = completion;
const content = choices[0]?.delta?.content ?? "";
Expand All @@ -28,5 +67,30 @@ export async function* openAIChatToTextGenerationStream(
details: null,
};
yield output;

const tools = completion.choices[0]?.delta?.tool_calls || [];
for (const tool of tools) {
if (tool.id) {
if (!tool.function?.name) {
throw new Error("Tool call without function name");
}
const toolCallWithParameters: ToolCallWithParameters = {
toolCall: {
name: tool.function.name,
parameters: {},
},
parameterJsonString: "",
};
toolCalls.push(toolCallWithParameters);
}

if (toolCalls.length > 0 && tool.function?.arguments) {
toolCalls[toolCalls.length - 1].parameterJsonString += tool.function.arguments;
}
}

if (choices[0]?.finish_reason === "tool_calls") {
yield prepareToolCalls(toolCalls, tokenId++);
}
}
}
Loading