feat: add support for function_call option (#10)
* feat: add support for function_call option

fixes #9

* chore: add missing semicolon
b0o authored Sep 12, 2023
1 parent 6d8b36d commit 7b52857
Showing 2 changed files with 72 additions and 0 deletions.
src/index.ts: 10 additions & 0 deletions

@@ -4,6 +4,7 @@ import { FunctionDef, formatFunctionDefinitions } from "./functions";
 
 type Message = OpenAI.Chat.CreateChatCompletionRequestMessage;
 type Function = OpenAI.Chat.CompletionCreateParams.Function;
+type FunctionCall = OpenAI.Chat.CompletionCreateParams.FunctionCallOption;
 
 let encoder: Tiktoken | undefined;
 
@@ -17,9 +18,11 @@ let encoder: Tiktoken | undefined;
 export function promptTokensEstimate({
   messages,
   functions,
+  function_call,
 }: {
   messages: Message[];
   functions?: Function[];
+  function_call?: 'none' | 'auto' | FunctionCall;
 }): number {
   // It appears that if functions are present, the first system message is padded with a trailing newline. This
   // was inferred by trying lots of combinations of messages and functions and seeing what the token counts were.
@@ -49,6 +52,13 @@ export function promptTokensEstimate({
     tokens -= 4;
   }
 
+  // If function_call is 'none', add one token.
+  // If it's a FunctionCall object, add 4 + the number of tokens in the function name.
+  // If it's undefined or 'auto', don't add anything.
+  if (function_call && function_call !== 'auto') {
+    tokens += function_call === 'none' ? 1 : stringTokens(function_call.name) + 4;
+  }
+
   return tokens;
 }
 
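With this change, callers can pass the same function_call value they send to the chat completions API and have it reflected in the estimate. A minimal usage sketch (the relative import mirrors the test file below; a published package entry point may differ, and the expected counts are the ones asserted in the tests):

import { promptTokensEstimate } from "../src";

const messages = [{ role: "user" as const, content: "hello" }];
const functions = [
  { name: "foo", parameters: { type: "object", properties: {} } },
];

// Omitted or "auto": no extra tokens (31 for this prompt).
promptTokensEstimate({ messages, functions, function_call: "auto" });

// "none": one extra token (32).
promptTokensEstimate({ messages, functions, function_call: "none" });

// Forcing a specific function: 4 extra tokens plus the tokens in the name (36).
promptTokensEstimate({ messages, functions, function_call: { name: "foo" } });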
tests/token-counts.test.ts: 62 additions & 0 deletions

@@ -3,9 +3,11 @@ import { promptTokensEstimate } from "../src";
 
 type Message = OpenAI.Chat.CreateChatCompletionRequestMessage;
 type Function = OpenAI.Chat.CompletionCreateParams.Function;
+type FunctionCall = OpenAI.Chat.CompletionCreateParams.FunctionCallOption;
 type Example = {
   messages: Message[];
   functions?: Function[];
+  function_call?: "none" | "auto" | FunctionCall;
   tokens: number;
   validate?: boolean;
 };
@@ -109,6 +111,39 @@ const TEST_CASES: Example[] = [
     ],
     tokens: 31,
   },
+  {
+    messages: [{ role: "user", content: "hello" }],
+    functions: [
+      {
+        name: "foo",
+        parameters: { type: "object", properties: {} },
+      },
+    ],
+    function_call: "none",
+    tokens: 32,
+  },
+  {
+    messages: [{ role: "user", content: "hello" }],
+    functions: [
+      {
+        name: "foo",
+        parameters: { type: "object", properties: {} },
+      },
+    ],
+    function_call: "auto",
+    tokens: 31,
+  },
+  {
+    messages: [{ role: "user", content: "hello" }],
+    functions: [
+      {
+        name: "foo",
+        parameters: { type: "object", properties: {} },
+      },
+    ],
+    function_call: { name: "foo" },
+    tokens: 36,
+  },
   {
     messages: [{ role: "user", content: "hello" }],
     functions: [
@@ -263,6 +298,31 @@ const TEST_CASES: Example[] = [
     ],
     tokens: 40,
   },
+  {
+    messages: [
+      { role: "system", content: "Hello:" },
+      { role: "system", content: "Hello" },
+      { role: "user", content: "Hi there" },
+    ],
+    functions: [
+      { name: "do_stuff", parameters: { type: "object", properties: {} } },
+      { name: "do_other_stuff", parameters: { type: "object", properties: {} } },
+    ],
+    tokens: 49,
+  },
+  {
+    messages: [
+      { role: "system", content: "Hello:" },
+      { role: "system", content: "Hello" },
+      { role: "user", content: "Hi there" },
+    ],
+    functions: [
+      { name: "do_stuff", parameters: { type: "object", properties: {} } },
+      { name: "do_other_stuff", parameters: { type: "object", properties: {} } },
+    ],
+    function_call: { name: "do_stuff" },
+    tokens: 55,
+  },
   {
     messages: [{ role: "user", content: "hello" }],
     functions: [
@@ -394,6 +454,7 @@ describe.each(TEST_CASES)("token counts (%j)", (example) => {
       model: "gpt-3.5-turbo",
       messages: example.messages,
       functions: example.functions as any,
+      function_call: example.function_call,
       max_tokens: 10,
     });
     expect(response.usage?.prompt_tokens).toBe(example.tokens);
@@ -406,6 +467,7 @@ describe.each(TEST_CASES)("token counts (%j)", (example) => {
       promptTokensEstimate({
         messages: example.messages,
         functions: example.functions,
+        function_call: example.function_call,
       }),
     ).toBe(example.tokens);
   });
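For the two-function example above, pinning function_call: { name: "do_stuff" } raises the expected prompt tokens from 49 to 55, consistent with the rule added in src/index.ts: a 4-token overhead plus the 2 tokens the name "do_stuff" encodes to (49 + 4 + 2 = 55).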