210 changes: 210 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -55,6 +55,7 @@
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

try:
    import tiktoken
    from openai import NOT_GIVEN, APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream, Omit, omit
    from openai.types import AllModels, chat, responses
    from openai.types.chat import (
@@ -1063,6 +1064,35 @@ def _inline_text_file_part(text: str, *, media_type: str, identifier: str) -> Ch
        )
        return ChatCompletionContentPartTextParam(text=text, type='text')

    async def count_tokens(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> usage.RequestUsage:
        """Count the number of tokens in the given messages."""
        if self.system != 'openai':
            raise NotImplementedError('Token counting is only supported for OpenAI system.')

        try:
            encoding = await _utils.run_in_executor(tiktoken.encoding_for_model, self.model_name)
        except KeyError as e:
            raise ValueError(
                f'The model {self.model_name!r} is not supported by tiktoken',
            ) from e

        model_settings, model_request_parameters = self.prepare_request(model_settings, model_request_parameters)
        openai_messages = await self._map_messages(messages, model_request_parameters)
        message_token_count = await _num_tokens_from_messages(openai_messages, self.model_name, encoding)

        # Count tokens for tools/functions
        tools = self._get_tools(model_request_parameters)
        tools_token_count = await _num_tokens_for_tools(tools, self.model_name, encoding)

        return usage.RequestUsage(
            input_tokens=message_token_count + tools_token_count,
        )
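For illustration, a hedged sketch of calling the new method directly; `ModelRequestParameters()` with defaults and the prompt used here are assumptions for the example, not part of this PR:

```python
# Hedged sketch: calling the new count_tokens method directly.
# ModelRequestParameters() with defaults is an assumption for illustration.
import asyncio

from pydantic_ai.messages import ModelRequest
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.openai import OpenAIChatModel


async def main() -> None:
    model = OpenAIChatModel('gpt-4o')
    counted = await model.count_tokens(
        [ModelRequest.user_text_prompt('What is the capital of France?')],
        None,
        ModelRequestParameters(),
    )
    print(counted.input_tokens)


asyncio.run(main())
```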


@deprecated(
    '`OpenAIModel` was renamed to `OpenAIChatModel` to clearly distinguish it from `OpenAIResponsesModel` which '
@@ -1908,6 +1938,36 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
                assert_never(item)
        return responses.EasyInputMessageParam(role='user', content=content)

    async def count_tokens(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> usage.RequestUsage:
        """Count the number of tokens in the given messages."""
        if self.system != 'openai':
            raise NotImplementedError('Token counting is only supported for OpenAI system.')

        try:
            encoding = await _utils.run_in_executor(tiktoken.encoding_for_model, self.model_name)
        except KeyError as e:
            raise ValueError(
                f'The model {self.model_name!r} is not supported by tiktoken',
            ) from e

        model_settings, model_request_parameters = self.prepare_request(model_settings, model_request_parameters)
        _, openai_messages = await self._map_messages(
            messages, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
        )
        message_token_count = await _num_tokens_from_messages(openai_messages, self.model_name, encoding)

        # Count tokens for tools/functions
        tools = self._get_tools(model_request_parameters)
        tools_token_count = await _num_tokens_for_tools(tools, self.model_name, encoding)
        return usage.RequestUsage(
            input_tokens=message_token_count + tools_token_count,
        )


@dataclass
class OpenAIStreamedResponse(StreamedResponse):
@@ -2680,3 +2740,153 @@ def _map_mcp_call(
            provider_name=provider_name,
        ),
    )


async def _num_tokens_from_messages( # noqa: C901
    messages: list[chat.ChatCompletionMessageParam] | list[responses.ResponseInputItemParam],
    model: OpenAIModelName,
    encoding: tiktoken.Encoding,
) -> int:
    """Return the number of tokens used by a list of messages."""
    if 'gpt-5' in model:
        tokens_per_message = 3
        tokens_per_name = 1
        final_primer = 2 # "reverse engineered" based on test cases
    else:
        # Adapted from https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls
[Review comment — Collaborator]
Looking at the cookbook again, I think we should also try to implement support for counting the tokens of tool definitions:
https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#7-counting-tokens-for-chat-completions-with-tool-calls

        tokens_per_message = 3
        tokens_per_name = 1
        final_primer = 3 # every reply is primed with <|start|>assistant<|message|>

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            if (key == 'content' or key == 'role') and isinstance(value, str):
                num_tokens += len(encoding.encode(value))
[Review comment — Collaborator]
Since this (or the get_encoding call further up?) could download a large file, but tiktoken is sync not async, we should wrap the call that may do a download in _utils.run_in_executor to run it in a thread.

[Review comment — Author]
The methods which download the encoding file are tiktoken.encoding_for_model and tiktoken.get_encoding. So they would be wrapped with _utils.run_in_executor and then awaited? And _num_tokens_from_messages would become async?

[Review comment — Collaborator]
@wirthual Yep exactly. Note that I'm also using tiktoken in #3252, so you can see there how I did the run_in_executor thing.
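As a side note to the thread above, a minimal sketch of the pattern being discussed, assuming `_utils.run_in_executor` behaves like `asyncio.to_thread` (used here as a stand-in):

```python
# Minimal sketch of the run_in_executor pattern from the thread above.
# asyncio.to_thread stands in for pydantic_ai._utils.run_in_executor.
import asyncio

import tiktoken


async def load_encoding(model_name: str) -> tiktoken.Encoding:
    # encoding_for_model may download an encoding file on first use,
    # so run the sync call in a worker thread instead of blocking the event loop.
    return await asyncio.to_thread(tiktoken.encoding_for_model, model_name)


async def main() -> None:
    encoding = await load_encoding('gpt-4o')
    print(len(encoding.encode('hello world')))


asyncio.run(main())
```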

            elif key == 'tool_calls' and isinstance(value, list):
                # Chat Completions API: list of ChatCompletionToolCallParam
                # Responses API: list of ResponseFunctionToolCallParam
                for tool_call in value: # pyright: ignore[reportUnknownVariableType]
                    if isinstance(tool_call, dict):
                        # Both ChatCompletionToolCallParam and ResponseFunctionToolCallParam have 'function' field
                        num_tokens += 4
                        function = tool_call.get('function', {}) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
                        if function and isinstance(function, dict):
                            # Both have 'name' and 'arguments' fields
                            num_tokens += 1
                            if 'name' in function and isinstance(function['name'], str):
                                num_tokens += len(encoding.encode(function['name']))
                            if 'arguments' in function and isinstance(function['arguments'], str):
                                num_tokens += len(encoding.encode(function['arguments']))
            elif key == 'name':
                num_tokens += tokens_per_name
            elif key == 'content' and isinstance(value, list):
                # Handle list content (multimodal messages)
                # Chat Completions API: list of ChatCompletionContentPartParam
                # Responses API: list of ResponseInputContentParam
                for content_part in value: # pyright: ignore[reportUnknownVariableType]
                    if isinstance(content_part, dict):
                        # ChatCompletionContentPartTextParam has 'text' field
                        # ResponseInputTextParam has 'text' field
                        if 'text' in content_part and isinstance(content_part['text'], str):
                            num_tokens += len(encoding.encode(content_part['text']))
                        # Note: Images, audio, files are not tokenized as text
                        # They have their own token costs handled by the API

    num_tokens += final_primer
    return num_tokens
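For intuition, a small self-contained sketch of the cookbook-style arithmetic this function adapts (model name and messages are illustrative; exact totals vary by model):

```python
# Self-contained illustration of the cookbook-style count adapted above.
# The overhead constants (3 per message, 3 for the reply primer) follow the
# OpenAI cookbook; totals are model-dependent approximations.
import tiktoken

encoding = tiktoken.encoding_for_model('gpt-4o')
messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'What is the capital of France?'},
]

num_tokens = 0
for message in messages:
    num_tokens += 3  # per-message overhead
    for value in message.values():
        num_tokens += len(encoding.encode(value))  # role and content strings
num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
print(num_tokens)
```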


async def _num_tokens_for_tools(
    tools: list[chat.ChatCompletionToolParam] | list[responses.FunctionToolParam],
    model: OpenAIModelName,
    encoding: tiktoken.Encoding,
) -> int:
    """Return the number of tokens used by a list of tools.

    Based on the OpenAI token counting approach for function calling.
    Supports both Chat Completions API tools (ChatCompletionToolParam) and
    Responses API tools (FunctionToolParam).
    """
    # Initialize function settings to 0
    func_init = 0
    prop_init = 0
    prop_key = 0
    enum_init = 0
    enum_item = 0
    func_end = 0

    if 'gpt-4o' in model or 'gpt-4o-mini' in model:
        # Set function settings for gpt-4o models
        func_init = 7
        prop_init = 3
        prop_key = 3
        enum_init = -3
        enum_item = 3
        func_end = 12
    elif 'gpt-3.5-turbo' in model or 'gpt-4' in model:
        # Set function settings for gpt-3.5-turbo and gpt-4 models
        func_init = 10
        prop_init = 3
        prop_key = 3
        enum_init = -3
        enum_item = 3
        func_end = 12
    else:
        # Default to gpt-4o settings for unknown models
        func_init = 7
        prop_init = 3
        prop_key = 3
        enum_init = -3
        enum_item = 3
        func_end = 12

    func_token_count = 0
    if len(tools) > 0:
        for tool in tools:
            func_token_count += func_init # Add tokens for start of each function

            # Handle both ChatCompletionToolParam and FunctionToolParam structures
            # ChatCompletionToolParam: {'type': 'function', 'function': {'name': ..., 'description': ..., 'parameters': ...}}
            # FunctionToolParam: {'type': 'function', 'name': ..., 'description': ..., 'parameters': ...}
            if 'function' in tool:
                # ChatCompletionToolParam format
                function = tool['function']
                f_name = str(function.get('name', ''))
                f_desc = str(function.get('description', '') or '')
                parameters = function.get('parameters')
            else:
                # FunctionToolParam format (Responses API)
                f_name = str(tool.get('name', ''))
                f_desc = str(tool.get('description', '') or '')
                parameters = tool.get('parameters')

            if f_desc.endswith('.'):
                f_desc = f_desc[:-1]
            line = f'{f_name}:{f_desc}'
            func_token_count += len(encoding.encode(line)) # Add tokens for function name and description

            if parameters and isinstance(parameters, dict):
                properties_raw = parameters.get('properties', {})
                if properties_raw and isinstance(properties_raw, dict) and len(properties_raw) > 0: # pyright: ignore[reportUnknownArgumentType]
                    func_token_count += prop_init # Add tokens for start of properties
                    for key, prop_value in properties_raw.items(): # pyright: ignore[reportUnknownVariableType]
                        if not isinstance(prop_value, dict):
                            continue
                        func_token_count += prop_key # Add tokens for each property
                        p_name = str(key) # pyright: ignore[reportUnknownArgumentType]
                        p_type = str(prop_value.get('type', '') or '') # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
                        p_desc = str(prop_value.get('description', '') or '') # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
                        if 'enum' in prop_value:
                            func_token_count += enum_init # Add tokens if property has enum list
                            for item in prop_value['enum']: # pyright: ignore[reportUnknownVariableType]
                                func_token_count += enum_item
                                func_token_count += len(encoding.encode(str(item))) # pyright: ignore[reportUnknownArgumentType]
                        if p_desc.endswith('.'):
                            p_desc = p_desc[:-1]
                        line = f'{p_name}:{p_type}:{p_desc}'
                        func_token_count += len(encoding.encode(line))
        func_token_count += func_end

    return func_token_count
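To make the heuristic concrete, a hand-worked sketch for one hypothetical tool using the gpt-4o constants above (the tool definition and resulting total are illustrative only, not an authoritative count):

```python
# Hand-worked illustration of the heuristic above for one hypothetical tool,
# using the gpt-4o constants (func_init=7, prop_init=3, prop_key=3, func_end=12).
import tiktoken

encoding = tiktoken.encoding_for_model('gpt-4o')

# Chat Completions format: the function payload sits under a 'function' key.
tool = {
    'type': 'function',
    'function': {
        'name': 'get_weather',
        'description': 'Get the current weather.',
        'parameters': {
            'type': 'object',
            'properties': {'city': {'type': 'string', 'description': 'City name'}},
        },
    },
}

function = tool['function']
count = 7  # func_init for this one function
f_desc = function['description'][:-1]  # trailing '.' is stripped before encoding
count += len(encoding.encode(f'get_weather:{f_desc}'))
count += 3  # prop_init: non-empty 'properties' dict
count += 3  # prop_key for the single 'city' property
count += len(encoding.encode('city:string:City name'))
count += 12  # func_end, added once after all tools
print(count)
```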
13 changes: 11 additions & 2 deletions pydantic_ai_slim/pydantic_ai/usage.py
@@ -267,8 +267,17 @@ class UsageLimits:
"""The maximum number of tokens allowed in requests and responses combined."""
count_tokens_before_request: bool = False
"""If True, perform a token counting pass before sending the request to the model,
to enforce `request_tokens_limit` ahead of time. This may incur additional overhead
(from calling the model's `count_tokens` API before making the actual request) and is disabled by default."""
to enforce `input_tokens_limit` ahead of time. This may incur additional overhead
(from calling the model's `count_tokens` method before making the actual request) and is disabled by default.

Supported by:

- [`OpenAIChatModel`][pydantic_ai.models.openai.OpenAIChatModel] and
[`OpenAIResponsesModel`][pydantic_ai.models.openai.OpenAIResponsesModel] (only for OpenAI models)
- [`AnthropicModel`][pydantic_ai.models.anthropic.AnthropicModel] (excluding Bedrock client)
- [`GoogleModel`][pydantic_ai.models.google.GoogleModel]
- [`BedrockModel`][pydantic_ai.models.bedrock.BedrockModel] (including Anthropic models)
"""

    @property
    @deprecated('`request_tokens_limit` is deprecated, use `input_tokens_limit` instead')
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
@@ -67,7 +67,7 @@ dependencies = [
# WARNING if you add optional groups, please update docs/install.md
logfire = ["logfire[httpx]>=3.14.1"]
# Models
openai = ["openai>=2.11.0"]
openai = ["openai>=2.11.0","tiktoken>=0.12.0"]
cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
google = ["google-genai>=1.51.0"]