Mistral: implement native output mode for structured outputs #3662
base: main
Changes from all commits
5c48fc9
5f9bba8
f281dc3
dc835b3
7145d7f
ef9669c
f130196
@@ -11,6 +11,7 @@
from typing_extensions import assert_never

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
from .._output import OutputObjectDefinition
from .._run_context import RunContext
from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
from ..exceptions import ModelAPIError, UserError
@@ -62,6 +63,7 @@
    Mistral,
    OptionalNullable as MistralOptionalNullable,
    ReferenceChunk as MistralReferenceChunk,
    ResponseFormat as MistralResponseFormat,
    TextChunk as MistralTextChunk,
    ThinkChunk as MistralThinkChunk,
    ToolChoiceEnum as MistralToolChoiceEnum,
@@ -70,6 +72,7 @@
    ChatCompletionResponse as MistralChatCompletionResponse,
    CompletionEvent as MistralCompletionEvent,
    FinishReason as MistralFinishReason,
    JSONSchema as MistralJSONSchema,
    Messages as MistralMessages,
    SDKError,
    Tool as MistralTool,
@@ -215,6 +218,32 @@ async def request_stream(
        async with response:
            yield await self._process_streamed_response(response, model_request_parameters)

    def _get_response_format(self, model_request_parameters: ModelRequestParameters) -> MistralResponseFormat | None:
        """Get the response format for Mistral API based on output mode.

        Returns None if no special format is needed.
        """
        if model_request_parameters.output_mode == 'native':
            # Use native JSON schema mode
            output_object = model_request_parameters.output_object
            assert output_object is not None
            json_schema = self._map_json_schema(output_object)
            return MistralResponseFormat(type='json_schema', json_schema=json_schema)
        elif model_request_parameters.output_mode == 'prompted' and not model_request_parameters.function_tools:
Collaborator
I don't think we need …
            # Use JSON object mode (without schema)
            return MistralResponseFormat(type='json_object')
        else:
            return None

    def _map_json_schema(self, o: OutputObjectDefinition) -> MistralJSONSchema:
        """Map OutputObjectDefinition to Mistral JSONSchema format."""
        return MistralJSONSchema(
            name=o.name or 'output',
            schema_definition=o.json_schema,
            description=o.description or UNSET,
            strict=o.strict if o.strict is not None else None,
Collaborator
This is just …
        )

    async def _completions_create(
        self,
        messages: list[ModelMessage],
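For reference, a minimal sketch of the underlying Mistral API call that `_get_response_format` and `_map_json_schema` map onto, written against the `mistralai` SDK. The import paths, model name, and example schema here are assumptions for illustration, not part of this diff:

```python
from mistralai import Mistral, ResponseFormat
from mistralai.models import JSONSchema

client = Mistral(api_key='...')  # placeholder API key

# Ask for JSON that must conform to a schema — the json_schema response format
# that the new native output mode relies on.
response = client.chat.complete(
    model='mistral-large-latest',
    messages=[{'role': 'user', 'content': 'Describe Paris as JSON with city and country.'}],
    response_format=ResponseFormat(
        type='json_schema',
        json_schema=JSONSchema(
            name='CityInfo',
            schema_definition={
                'type': 'object',
                'properties': {'city': {'type': 'string'}, 'country': {'type': 'string'}},
                'required': ['city', 'country'],
                'additionalProperties': False,
            },
            strict=True,
        ),
    ),
)
print(response.choices[0].message.content)
```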
@@ -227,13 +256,25 @@ async def _completions_create(
        if model_request_parameters.builtin_tools:
            raise UserError('Mistral does not support built-in tools')

        # Determine the response format based on output mode
        response_format = self._get_response_format(model_request_parameters)

        # When using native JSON schema mode, don't use tool-based output
Collaborator
I don't think we need to branch on this: if …
        if model_request_parameters.output_mode == 'native':
            tools = self._map_function_tools_only(model_request_parameters) or UNSET
            tool_choice = self._get_tool_choice_for_functions_only(model_request_parameters)
        else:
            tools = self._map_function_and_output_tools_definition(model_request_parameters) or UNSET
            tool_choice = self._get_tool_choice(model_request_parameters)

        try:
            response = await self.client.chat.complete_async(
                model=str(self._model_name),
                messages=self._map_messages(messages, model_request_parameters),
                n=1,
                tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
                tool_choice=self._get_tool_choice(model_request_parameters),
                tools=tools,
                tool_choice=tool_choice,
                response_format=response_format,
                stream=False,
                max_tokens=model_settings.get('max_tokens', UNSET),
                temperature=model_settings.get('temperature', UNSET),
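A minimal sketch of the simplification the review comment above seems to point at, assuming (not verified in this diff) that `model_request_parameters.output_tools` is already empty whenever `output_mode` is `'native'`, so the existing helpers only ever return function tools and no branch on the output mode is needed. This is a fragment of the `_completions_create` body, not a standalone change:

```python
# Sketch only: assumes output_tools is empty when output_mode == 'native',
# so no special "*_only" helpers or output-mode branch are required.
response_format = self._get_response_format(model_request_parameters)
tools = self._map_function_and_output_tools_definition(model_request_parameters) or UNSET
tool_choice = self._get_tool_choice(model_request_parameters)
# ...then pass tools, tool_choice, and response_format to client.chat.complete_async as above.
```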
@@ -258,57 +299,41 @@ async def _stream_completions_create(
        model_request_parameters: ModelRequestParameters,
    ) -> MistralEventStreamAsync[MistralCompletionEvent]:
        """Create a streaming completion request to the Mistral model."""
        response: MistralEventStreamAsync[MistralCompletionEvent] | None
        mistral_messages = self._map_messages(messages, model_request_parameters)

        # TODO(Marcelo): We need to replace the current MistralAI client to use the beta client.
        # See https://docs.mistral.ai/agents/connectors/websearch/ to support web search.
        if model_request_parameters.builtin_tools:
            raise UserError('Mistral does not support built-in tools')

        if model_request_parameters.function_tools:
            # Function Calling
            response = await self.client.chat.stream_async(
                model=str(self._model_name),
                messages=mistral_messages,
                n=1,
                tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
                tool_choice=self._get_tool_choice(model_request_parameters),
                temperature=model_settings.get('temperature', UNSET),
                top_p=model_settings.get('top_p', 1),
                max_tokens=model_settings.get('max_tokens', UNSET),
                timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                presence_penalty=model_settings.get('presence_penalty'),
                frequency_penalty=model_settings.get('frequency_penalty'),
                stop=model_settings.get('stop_sequences', None),
                http_headers={'User-Agent': get_user_agent()},
            )

        elif model_request_parameters.output_tools:
            # TODO: Port to native "manual JSON" mode
            # Json Mode
            parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools]
            user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
            mistral_messages.append(user_output_format_message)
        mistral_messages = self._map_messages(messages, model_request_parameters)

            response = await self.client.chat.stream_async(
                model=str(self._model_name),
                messages=mistral_messages,
                response_format={
                    'type': 'json_object'
                },  # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9
                stream=True,
                http_headers={'User-Agent': get_user_agent()},
            )
        # Determine the response format based on output mode
        response_format = self._get_response_format(model_request_parameters)

        # When using native JSON schema mode, don't use tool-based output
Collaborator
See above; I don't think we need this. And if we do, we should dedupe it between the request/request_stream methods :)
        if model_request_parameters.output_mode == 'native':
            tools = self._map_function_tools_only(model_request_parameters) or UNSET
            tool_choice = self._get_tool_choice_for_functions_only(model_request_parameters)
        else:
            # Stream Mode
            response = await self.client.chat.stream_async(
                model=str(self._model_name),
                messages=mistral_messages,
                stream=True,
                http_headers={'User-Agent': get_user_agent()},
            )
            tools = self._map_function_and_output_tools_definition(model_request_parameters) or UNSET
            tool_choice = self._get_tool_choice(model_request_parameters)

        response = await self.client.chat.stream_async(
            model=str(self._model_name),
            messages=mistral_messages,
            n=1,
            tools=tools,
            tool_choice=tool_choice,
            response_format=response_format,
            temperature=model_settings.get('temperature', UNSET),
            top_p=model_settings.get('top_p', 1),
            max_tokens=model_settings.get('max_tokens', UNSET),
            timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
            presence_penalty=model_settings.get('presence_penalty'),
            frequency_penalty=model_settings.get('frequency_penalty'),
            stop=model_settings.get('stop_sequences', None),
            stream=True,
            http_headers={'User-Agent': get_user_agent()},
        )
        assert response, 'A unexpected empty response from Mistral.'
        return response
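As a possible follow-up to the dedupe suggestion above, a hedged sketch of a shared helper (the name `_tools_and_response_format` is invented here) that both `_completions_create` and `_stream_completions_create` could call instead of repeating the branch:

```python
# Sketch of a method on the Mistral model class, reusing only names from this diff.
def _tools_and_response_format(
    self, model_request_parameters: ModelRequestParameters
) -> tuple[list[MistralTool] | None, MistralToolChoiceEnum | None, MistralResponseFormat | None]:
    """Deduplicated selection of tools, tool_choice, and response_format."""
    response_format = self._get_response_format(model_request_parameters)
    if model_request_parameters.output_mode == 'native':
        # Output is carried by the json_schema response format, so only function tools are sent.
        tools = self._map_function_tools_only(model_request_parameters)
        tool_choice = self._get_tool_choice_for_functions_only(model_request_parameters)
    else:
        tools = self._map_function_and_output_tools_definition(model_request_parameters)
        tool_choice = self._get_tool_choice(model_request_parameters)
    return tools, tool_choice, response_format

# Callers would then do:
#   tools, tool_choice, response_format = self._tools_and_response_format(model_request_parameters)
# and pass `tools or UNSET` to the SDK call.
```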
@@ -344,6 +369,30 @@ def _map_function_and_output_tools_definition(
        ]
        return tools if tools else None

    def _map_function_tools_only(self, model_request_parameters: ModelRequestParameters) -> list[MistralTool] | None:
        """Map only function tools (not output tools) to MistralTool format.

        This is used when output is handled via native JSON schema mode instead of tools.
        """
        tools = [
            MistralTool(
                function=MistralFunction(
                    name=r.name, parameters=r.parameters_json_schema, description=r.description or ''
                )
            )
            for r in model_request_parameters.function_tools
        ]
        return tools if tools else None

    def _get_tool_choice_for_functions_only(
        self, model_request_parameters: ModelRequestParameters
    ) -> MistralToolChoiceEnum | None:
        """Get tool choice when only function tools are used (not output tools)."""
        if not model_request_parameters.function_tools:
            return None
        # When using native output mode, we don't force tool use since output is handled separately
        return 'auto'

    def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        assert response.choices, 'Unexpected empty response choice.'
@@ -1,8 +1,13 @@
from __future__ import annotations as _annotations

from . import ModelProfile
from .openai import OpenAIJsonSchemaTransformer


def mistral_model_profile(model_name: str) -> ModelProfile | None:
    """Get the model profile for a Mistral model."""
    return None
    return ModelProfile(
        json_schema_transformer=OpenAIJsonSchemaTransformer,
        supports_json_schema_output=True,
        supports_json_object_output=True,
    )
There are a few TODOs in the file that we should address at the same time:

# TODO: Should be able to use json_schema
# TODO: Port to native "manual JSON" mode

These are related to an old approach for handling `output_tools` by passing the schemas to the API as user text parts, which I don't think we need anymore, so we should implement the `tool`, `native`, and `prompted` modes the way the other models do.
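To close the loop, a hedged end-to-end sketch of what the PR enables from the user side; the `NativeOutput` marker, the `mistral:` model string, and the example output type are assumptions about the surrounding pydantic-ai API, not something shown in this diff:

```python
from pydantic import BaseModel

from pydantic_ai import Agent, NativeOutput


class CityInfo(BaseModel):
    city: str
    country: str


# With supports_json_schema_output=True in the Mistral profile, NativeOutput should be
# served via the new json_schema response format instead of a synthetic output tool.
agent = Agent('mistral:mistral-large-latest', output_type=NativeOutput(CityInfo))
result = agent.run_sync('Tell me about Paris.')
print(result.output)
```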