139 changes: 94 additions & 45 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Collaborator:
There are a few TODOs in the file that we should address at the same time:

  • # TODO: Should be able to use json_schema
  • # TODO: Port to native "manual JSON" mode

These are related to an old approach for handling output_tools by passing the schemas to the API as user text parts, which I don't think we need anymore, so we should implement the tool, native, and prompted modes the way the other models do.
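For reviewers skimming, this is the same mode dispatch the other model classes use; condensed, the version this diff adds looks roughly like the sketch below (a summary of `_get_response_format` in the diff, not a drop-in replacement):

```python
# Condensed view of the dispatch added in this diff.
def _get_response_format(self, params: ModelRequestParameters) -> MistralResponseFormat | None:
    if params.output_mode == 'native':
        # Native mode: the output schema travels in response_format, not as a tool.
        assert params.output_object is not None
        return MistralResponseFormat(
            type='json_schema', json_schema=self._map_json_schema(params.output_object)
        )
    if params.output_mode == 'prompted' and not params.function_tools:
        # Prompted mode: plain JSON object mode; the schema is conveyed in the prompt.
        return MistralResponseFormat(type='json_object')
    return None  # 'tool' mode (default): output is returned via tool calls
```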

@@ -11,6 +11,7 @@
from typing_extensions import assert_never

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
from .._output import OutputObjectDefinition
from .._run_context import RunContext
from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
from ..exceptions import ModelAPIError, UserError
@@ -62,6 +63,7 @@
Mistral,
OptionalNullable as MistralOptionalNullable,
ReferenceChunk as MistralReferenceChunk,
ResponseFormat as MistralResponseFormat,
TextChunk as MistralTextChunk,
ThinkChunk as MistralThinkChunk,
ToolChoiceEnum as MistralToolChoiceEnum,
@@ -70,6 +72,7 @@
ChatCompletionResponse as MistralChatCompletionResponse,
CompletionEvent as MistralCompletionEvent,
FinishReason as MistralFinishReason,
JSONSchema as MistralJSONSchema,
Messages as MistralMessages,
SDKError,
Tool as MistralTool,
@@ -215,6 +218,32 @@ async def request_stream(
async with response:
yield await self._process_streamed_response(response, model_request_parameters)

def _get_response_format(self, model_request_parameters: ModelRequestParameters) -> MistralResponseFormat | None:
"""Get the response format for Mistral API based on output mode.

Returns None if no special format is needed.
"""
if model_request_parameters.output_mode == 'native':
# Use native JSON schema mode
output_object = model_request_parameters.output_object
assert output_object is not None
json_schema = self._map_json_schema(output_object)
return MistralResponseFormat(type='json_schema', json_schema=json_schema)
elif model_request_parameters.output_mode == 'prompted' and not model_request_parameters.function_tools:
Collaborator:

I don't think we need the `not model_request_parameters.function_tools` check.
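A minimal sketch of the suggested simplification, assuming `response_format` and function tools are independent in Mistral's API:

```python
# Suggested (per review): drop the function_tools guard.
elif model_request_parameters.output_mode == 'prompted':
    # JSON object mode (no schema); any function tools are passed separately.
    return MistralResponseFormat(type='json_object')
```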

# Use JSON object mode (without schema)
return MistralResponseFormat(type='json_object')
else:
return None

def _map_json_schema(self, o: OutputObjectDefinition) -> MistralJSONSchema:
"""Map OutputObjectDefinition to Mistral JSONSchema format."""
return MistralJSONSchema(
name=o.name or 'output',
schema_definition=o.json_schema,
description=o.description or UNSET,
strict=o.strict if o.strict is not None else None,
Collaborator:

This is just `strict=o.strict`, right?

)

async def _completions_create(
self,
messages: list[ModelMessage],
@@ -227,13 +256,25 @@
if model_request_parameters.builtin_tools:
raise UserError('Mistral does not support built-in tools')

# Determine the response format based on output mode
response_format = self._get_response_format(model_request_parameters)

# When using native JSON schema mode, don't use tool-based output
Collaborator:

I don't think we need to branch on this: if `model_request_parameters.output_mode == 'native'`, `model_request_parameters.output_tools` will be empty anyway. Note that we don't have logic like this in the other models.
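A sketch of the unbranched version this suggests, assuming `output_tools` is empty whenever `output_mode == 'native'` (so the two branches collapse into one):

```python
# Suggested (per review): no mode branch needed, since output_tools is already
# empty in native mode and the generic mapping then yields the same result.
tools = self._map_function_and_output_tools_definition(model_request_parameters) or UNSET
tool_choice = self._get_tool_choice(model_request_parameters)
```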

if model_request_parameters.output_mode == 'native':
tools = self._map_function_tools_only(model_request_parameters) or UNSET
tool_choice = self._get_tool_choice_for_functions_only(model_request_parameters)
else:
tools = self._map_function_and_output_tools_definition(model_request_parameters) or UNSET
tool_choice = self._get_tool_choice(model_request_parameters)

try:
response = await self.client.chat.complete_async(
model=str(self._model_name),
messages=self._map_messages(messages, model_request_parameters),
n=1,
tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
tool_choice=self._get_tool_choice(model_request_parameters),
tools=tools,
tool_choice=tool_choice,
response_format=response_format,
stream=False,
max_tokens=model_settings.get('max_tokens', UNSET),
temperature=model_settings.get('temperature', UNSET),
@@ -258,57 +299,41 @@ async def _stream_completions_create(
model_request_parameters: ModelRequestParameters,
) -> MistralEventStreamAsync[MistralCompletionEvent]:
"""Create a streaming completion request to the Mistral model."""
response: MistralEventStreamAsync[MistralCompletionEvent] | None
mistral_messages = self._map_messages(messages, model_request_parameters)

# TODO(Marcelo): We need to replace the current MistralAI client to use the beta client.
# See https://docs.mistral.ai/agents/connectors/websearch/ to support web search.
if model_request_parameters.builtin_tools:
raise UserError('Mistral does not support built-in tools')

if model_request_parameters.function_tools:
# Function Calling
response = await self.client.chat.stream_async(
model=str(self._model_name),
messages=mistral_messages,
n=1,
tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
tool_choice=self._get_tool_choice(model_request_parameters),
temperature=model_settings.get('temperature', UNSET),
top_p=model_settings.get('top_p', 1),
max_tokens=model_settings.get('max_tokens', UNSET),
timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
presence_penalty=model_settings.get('presence_penalty'),
frequency_penalty=model_settings.get('frequency_penalty'),
stop=model_settings.get('stop_sequences', None),
http_headers={'User-Agent': get_user_agent()},
)

elif model_request_parameters.output_tools:
# TODO: Port to native "manual JSON" mode
# Json Mode
parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools]
user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
mistral_messages.append(user_output_format_message)
mistral_messages = self._map_messages(messages, model_request_parameters)

response = await self.client.chat.stream_async(
model=str(self._model_name),
messages=mistral_messages,
response_format={
'type': 'json_object'
}, # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9
stream=True,
http_headers={'User-Agent': get_user_agent()},
)
# Determine the response format based on output mode
response_format = self._get_response_format(model_request_parameters)

# When using native JSON schema mode, don't use tool-based output
Collaborator:

See above; I don't think we need this. And if we do, we should dedupe it between the request/request_stream methods :)
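If the branch does survive, one way to dedupe it between the two methods — a hypothetical helper, not part of this PR:

```python
def _map_tools_and_choice(self, model_request_parameters: ModelRequestParameters):
    """Hypothetical shared helper for _completions_create and _stream_completions_create."""
    if model_request_parameters.output_mode == 'native':
        tools = self._map_function_tools_only(model_request_parameters) or UNSET
        tool_choice = self._get_tool_choice_for_functions_only(model_request_parameters)
    else:
        tools = self._map_function_and_output_tools_definition(model_request_parameters) or UNSET
        tool_choice = self._get_tool_choice(model_request_parameters)
    return tools, tool_choice
```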

if model_request_parameters.output_mode == 'native':
tools = self._map_function_tools_only(model_request_parameters) or UNSET
tool_choice = self._get_tool_choice_for_functions_only(model_request_parameters)
else:
# Stream Mode
response = await self.client.chat.stream_async(
model=str(self._model_name),
messages=mistral_messages,
stream=True,
http_headers={'User-Agent': get_user_agent()},
)
tools = self._map_function_and_output_tools_definition(model_request_parameters) or UNSET
tool_choice = self._get_tool_choice(model_request_parameters)

response = await self.client.chat.stream_async(
model=str(self._model_name),
messages=mistral_messages,
n=1,
tools=tools,
tool_choice=tool_choice,
response_format=response_format,
temperature=model_settings.get('temperature', UNSET),
top_p=model_settings.get('top_p', 1),
max_tokens=model_settings.get('max_tokens', UNSET),
timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
presence_penalty=model_settings.get('presence_penalty'),
frequency_penalty=model_settings.get('frequency_penalty'),
stop=model_settings.get('stop_sequences', None),
stream=True,
http_headers={'User-Agent': get_user_agent()},
)
assert response, 'An unexpected empty response from Mistral.'
return response

@@ -344,6 +369,30 @@ def _map_function_and_output_tools_definition(
]
return tools if tools else None

def _map_function_tools_only(self, model_request_parameters: ModelRequestParameters) -> list[MistralTool] | None:
"""Map only function tools (not output tools) to MistralTool format.

This is used when output is handled via native JSON schema mode instead of tools.
"""
tools = [
MistralTool(
function=MistralFunction(
name=r.name, parameters=r.parameters_json_schema, description=r.description or ''
)
)
for r in model_request_parameters.function_tools
]
return tools if tools else None

def _get_tool_choice_for_functions_only(
self, model_request_parameters: ModelRequestParameters
) -> MistralToolChoiceEnum | None:
"""Get tool choice when only function tools are used (not output tools)."""
if not model_request_parameters.function_tools:
return None
# When using native output mode, we don't force tool use since output is handled separately
return 'auto'

def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
"""Process a non-streamed response, and prepare a message to return."""
assert response.choices, 'Unexpected empty response choice.'
7 changes: 6 additions & 1 deletion pydantic_ai_slim/pydantic_ai/profiles/mistral.py
@@ -1,8 +1,13 @@
from __future__ import annotations as _annotations

from . import ModelProfile
from .openai import OpenAIJsonSchemaTransformer


def mistral_model_profile(model_name: str) -> ModelProfile | None:
"""Get the model profile for a Mistral model."""
return None
return ModelProfile(
json_schema_transformer=OpenAIJsonSchemaTransformer,
supports_json_schema_output=True,
supports_json_object_output=True,
)
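With `supports_json_schema_output` enabled, native structured output should now work end to end against Mistral models. A hedged usage sketch, assuming pydantic-ai's documented `NativeOutput` marker and `mistral:` model prefix:

```python
from pydantic import BaseModel

from pydantic_ai import Agent, NativeOutput


class CityInfo(BaseModel):
    city: str
    country: str


# NativeOutput routes the schema through response_format (json_schema mode)
# instead of an output tool, which the profile above now advertises as supported.
agent = Agent('mistral:mistral-small-latest', output_type=NativeOutput(CityInfo))
result = agent.run_sync('Tell me about Paris.')
print(result.output)  # e.g. CityInfo(city='Paris', country='France')
```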