-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Implement OpenAI token counting using tiktoken
#3447
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d7f0b87
80a61f1
cc8cbf0
c1be8c1
1332cd8
cb5da87
6396f5d
46cd331
86a0b89
bacf788
acf86b0
6d2d4dd
9943173
75f29fa
6deaea2
88a132d
7db7fb5
275f16a
6a58dfb
c16ab50
e94755a
6b07449
4aa4c99
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,6 +55,7 @@ | |
| from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent | ||
|
|
||
| try: | ||
| import tiktoken | ||
| from openai import NOT_GIVEN, APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream, Omit, omit | ||
| from openai.types import AllModels, chat, responses | ||
| from openai.types.chat import ( | ||
|
|
@@ -1063,6 +1064,35 @@ def _inline_text_file_part(text: str, *, media_type: str, identifier: str) -> Ch | |
| ) | ||
| return ChatCompletionContentPartTextParam(text=text, type='text') | ||
|
|
||
| async def count_tokens( | ||
| self, | ||
| messages: list[ModelMessage], | ||
| model_settings: ModelSettings | None, | ||
| model_request_parameters: ModelRequestParameters, | ||
| ) -> usage.RequestUsage: | ||
| """Count the number of tokens in the given messages.""" | ||
| if self.system != 'openai': | ||
| raise NotImplementedError('Token counting is only supported for OpenAI system.') | ||
|
|
||
| try: | ||
| encoding = await _utils.run_in_executor(tiktoken.encoding_for_model, self.model_name) | ||
| except KeyError as e: | ||
| raise ValueError( | ||
| f'The model {self.model_name!r} is not supported by tiktoken', | ||
| ) from e | ||
|
|
||
| model_settings, model_request_parameters = self.prepare_request(model_settings, model_request_parameters) | ||
| openai_messages = await self._map_messages(messages, model_request_parameters) | ||
| message_token_count = await _num_tokens_from_messages(openai_messages, self.model_name, encoding) | ||
|
|
||
| # Count tokens for tools/functions | ||
| tools = self._get_tools(model_request_parameters) | ||
| tools_token_count = await _num_tokens_for_tools(tools, self.model_name, encoding) | ||
|
|
||
| return usage.RequestUsage( | ||
| input_tokens=message_token_count + tools_token_count, | ||
| ) | ||
|
|
||
|
|
||
| @deprecated( | ||
| '`OpenAIModel` was renamed to `OpenAIChatModel` to clearly distinguish it from `OpenAIResponsesModel` which ' | ||
|
|
@@ -1908,6 +1938,36 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa | |
| assert_never(item) | ||
| return responses.EasyInputMessageParam(role='user', content=content) | ||
|
|
||
| async def count_tokens( | ||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self, | ||
| messages: list[ModelMessage], | ||
| model_settings: ModelSettings | None, | ||
| model_request_parameters: ModelRequestParameters, | ||
| ) -> usage.RequestUsage: | ||
| """Count the number of tokens in the given messages.""" | ||
| if self.system != 'openai': | ||
| raise NotImplementedError('Token counting is only supported for OpenAI system.') | ||
|
|
||
| try: | ||
| encoding = await _utils.run_in_executor(tiktoken.encoding_for_model, self.model_name) | ||
| except KeyError as e: | ||
| raise ValueError( | ||
| f'The model {self.model_name!r} is not supported by tiktoken', | ||
| ) from e | ||
|
|
||
| model_settings, model_request_parameters = self.prepare_request(model_settings, model_request_parameters) | ||
| _, openai_messages = await self._map_messages( | ||
| messages, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters | ||
| ) | ||
| message_token_count = await _num_tokens_from_messages(openai_messages, self.model_name, encoding) | ||
|
|
||
| # Count tokens for tools/functions | ||
| tools = self._get_tools(model_request_parameters) | ||
| tools_token_count = await _num_tokens_for_tools(tools, self.model_name, encoding) | ||
| return usage.RequestUsage( | ||
| input_tokens=message_token_count + tools_token_count, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class OpenAIStreamedResponse(StreamedResponse): | ||
|
|
@@ -2680,3 +2740,153 @@ def _map_mcp_call( | |
| provider_name=provider_name, | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| async def _num_tokens_from_messages( # noqa: C901 | ||
| messages: list[chat.ChatCompletionMessageParam] | list[responses.ResponseInputItemParam], | ||
| model: OpenAIModelName, | ||
| encoding: tiktoken.Encoding, | ||
| ) -> int: | ||
| """Return the number of tokens used by a list of messages.""" | ||
| if 'gpt-5' in model: | ||
| tokens_per_message = 3 | ||
| tokens_per_name = 1 | ||
| final_primer = 2 # "reverse engineered" based on test cases | ||
| else: | ||
| # Adapted from https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking at the cookbook again, I think we should also try to implement support for counting the tokens of tool definitions: |
||
| tokens_per_message = 3 | ||
| tokens_per_name = 1 | ||
| final_primer = 3 # every reply is primed with <|start|>assistant<|message|> | ||
|
|
||
| num_tokens = 0 | ||
| for message in messages: | ||
| num_tokens += tokens_per_message | ||
| for key, value in message.items(): | ||
| if (key == 'content' or key == 'role') and isinstance(value, str): | ||
| num_tokens += len(encoding.encode(value)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this (or the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The methods which download the encoding file are
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| elif key == 'tool_calls' and isinstance(value, list): | ||
| # Chat Completions API: list of ChatCompletionToolCallParam | ||
| # Responses API: list of ResponseFunctionToolCallParam | ||
| for tool_call in value: # pyright: ignore[reportUnknownVariableType] | ||
| if isinstance(tool_call, dict): | ||
| # Both ChatCompletionToolCallParam and ResponseFunctionToolCallParam have 'function' field | ||
| num_tokens += 4 | ||
| function = tool_call.get('function', {}) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] | ||
| if function and isinstance(function, dict): | ||
| # Both have 'name' and 'arguments' fields | ||
| num_tokens += 1 | ||
| if 'name' in function and isinstance(function['name'], str): | ||
| num_tokens += len(encoding.encode(function['name'])) | ||
| if 'arguments' in function and isinstance(function['arguments'], str): | ||
| num_tokens += len(encoding.encode(function['arguments'])) | ||
| elif key == 'name': | ||
| num_tokens += tokens_per_name | ||
| elif key == 'content' and isinstance(value, list): | ||
| # Handle list content (multimodal messages) | ||
| # Chat Completions API: list of ChatCompletionContentPartParam | ||
| # Responses API: list of ResponseInputContentParam | ||
| for content_part in value: # pyright: ignore[reportUnknownVariableType] | ||
| if isinstance(content_part, dict): | ||
| # ChatCompletionContentPartTextParam has 'text' field | ||
| # ResponseInputTextParam has 'text' field | ||
| if 'text' in content_part and isinstance(content_part['text'], str): | ||
| num_tokens += len(encoding.encode(content_part['text'])) | ||
| # Note: Images, audio, files are not tokenized as text | ||
| # They have their own token costs handled by the API | ||
|
|
||
| num_tokens += final_primer | ||
| return num_tokens | ||
|
|
||
|
|
||
| async def _num_tokens_for_tools( | ||
| tools: list[chat.ChatCompletionToolParam] | list[responses.FunctionToolParam], | ||
| model: OpenAIModelName, | ||
| encoding: tiktoken.Encoding, | ||
| ) -> int: | ||
| """Return the number of tokens used by a list of tools. | ||
|
|
||
| Based on the OpenAI token counting approach for function calling. | ||
| Supports both Chat Completions API tools (ChatCompletionToolParam) and | ||
| Responses API tools (FunctionToolParam). | ||
| """ | ||
| # Initialize function settings to 0 | ||
| func_init = 0 | ||
| prop_init = 0 | ||
| prop_key = 0 | ||
| enum_init = 0 | ||
| enum_item = 0 | ||
| func_end = 0 | ||
|
|
||
| if 'gpt-4o' in model or 'gpt-4o-mini' in model: | ||
| # Set function settings for gpt-4o models | ||
| func_init = 7 | ||
| prop_init = 3 | ||
| prop_key = 3 | ||
| enum_init = -3 | ||
| enum_item = 3 | ||
| func_end = 12 | ||
| elif 'gpt-3.5-turbo' in model or 'gpt-4' in model: | ||
| # Set function settings for gpt-3.5-turbo and gpt-4 models | ||
| func_init = 10 | ||
| prop_init = 3 | ||
| prop_key = 3 | ||
| enum_init = -3 | ||
| enum_item = 3 | ||
| func_end = 12 | ||
| else: | ||
| # Default to gpt-4o settings for unknown models | ||
| func_init = 7 | ||
| prop_init = 3 | ||
| prop_key = 3 | ||
| enum_init = -3 | ||
| enum_item = 3 | ||
| func_end = 12 | ||
|
|
||
| func_token_count = 0 | ||
| if len(tools) > 0: | ||
| for tool in tools: | ||
| func_token_count += func_init # Add tokens for start of each function | ||
|
|
||
| # Handle both ChatCompletionToolParam and FunctionToolParam structures | ||
| # ChatCompletionToolParam: {'type': 'function', 'function': {'name': ..., 'description': ..., 'parameters': ...}} | ||
| # FunctionToolParam: {'type': 'function', 'name': ..., 'description': ..., 'parameters': ...} | ||
| if 'function' in tool: | ||
| # ChatCompletionToolParam format | ||
| function = tool['function'] | ||
| f_name = str(function.get('name', '')) | ||
| f_desc = str(function.get('description', '') or '') | ||
| parameters = function.get('parameters') | ||
| else: | ||
| # FunctionToolParam format (Responses API) | ||
| f_name = str(tool.get('name', '')) | ||
| f_desc = str(tool.get('description', '') or '') | ||
| parameters = tool.get('parameters') | ||
|
|
||
| if f_desc.endswith('.'): | ||
| f_desc = f_desc[:-1] | ||
| line = f'{f_name}:{f_desc}' | ||
| func_token_count += len(encoding.encode(line)) # Add tokens for function name and description | ||
|
|
||
| if parameters and isinstance(parameters, dict): | ||
| properties_raw = parameters.get('properties', {}) | ||
| if properties_raw and isinstance(properties_raw, dict) and len(properties_raw) > 0: # pyright: ignore[reportUnknownArgumentType] | ||
| func_token_count += prop_init # Add tokens for start of properties | ||
| for key, prop_value in properties_raw.items(): # pyright: ignore[reportUnknownVariableType] | ||
| if not isinstance(prop_value, dict): | ||
| continue | ||
| func_token_count += prop_key # Add tokens for each property | ||
| p_name = str(key) # pyright: ignore[reportUnknownArgumentType] | ||
| p_type = str(prop_value.get('type', '') or '') # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType] | ||
| p_desc = str(prop_value.get('description', '') or '') # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType] | ||
| if 'enum' in prop_value: | ||
| func_token_count += enum_init # Add tokens if property has enum list | ||
| for item in prop_value['enum']: # pyright: ignore[reportUnknownVariableType] | ||
| func_token_count += enum_item | ||
| func_token_count += len(encoding.encode(str(item))) # pyright: ignore[reportUnknownArgumentType] | ||
| if p_desc.endswith('.'): | ||
| p_desc = p_desc[:-1] | ||
| line = f'{p_name}:{p_type}:{p_desc}' | ||
| func_token_count += len(encoding.encode(line)) | ||
| func_token_count += func_end | ||
|
|
||
| return func_token_count | ||
Uh oh!
There was an error while loading. Please reload this page.