Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions openjudge/models/base_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from openjudge.models.schema.oai.response import ChatResponse

TOOL_CHOICE_MODES = ["auto", "none", "any", "required"]
TOOL_CHOICE_MODES = {"auto", "none", "any", "required"}


class BaseChatModel(ABC):
Expand Down Expand Up @@ -113,13 +113,34 @@ def _validate_tool_choice(
raise TypeError(
f"tool_choice must be str, got {type(tool_choice)}",
)

tool_choice = tool_choice.strip()
if not tool_choice:
raise ValueError("`tool_choice` must be a non-empty string.")

if tool_choice in TOOL_CHOICE_MODES:
return

available_functions = [tool["function"]["name"] for tool in tools] if tools else []
if not tools:
raise ValueError(
f"Tool choice '{tool_choice}' is not a built-in mode ({', '.join(TOOL_CHOICE_MODES)}) "
"and no tools were provided."
)

available_functions = set()
for i, tool in enumerate(tools):
if not isinstance(tool, dict):
raise TypeError(f"Tool at index {i} is not a dictionary.")
func = tool.get("function")
if not isinstance(func, dict):
raise TypeError(f"Tool at index {i} missing or invalid 'function' field.")
name = func.get("name")
if not isinstance(name, str):
raise TypeError(f"Tool function name at index {i} is not a string.")
available_functions.add(name)

if tool_choice not in available_functions:
all_options = TOOL_CHOICE_MODES + available_functions
all_options = sorted(TOOL_CHOICE_MODES | available_functions)
raise ValueError(
f"Invalid tool_choice '{tool_choice}'. " f"Available options: {', '.join(sorted(all_options))}",
f"Invalid tool_choice '{tool_choice}'. " f"Available options: {', '.join(all_options)}",
)
42 changes: 28 additions & 14 deletions openjudge/models/openai_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""OpenAI Client."""
import copy
import os
from typing import Any, AsyncGenerator, Callable, Dict, Literal, Type

Expand All @@ -13,28 +14,36 @@
from openjudge.utils.utils import repair_and_load_json


def _format_audio_data_for_qwen_omni(messages: list[dict | ChatMessage]) -> None:
def _format_audio_data_for_qwen_omni(messages: list[dict | ChatMessage]) -> list[dict]:
"""Qwen-omni uses OpenAI-compatible API but requires different audio
data format than OpenAI with "data:;base64," prefix.
Refer to `Qwen-omni documentation
<https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2867839>`_
<https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2867839>`
for more details.

Args:
messages (`list[dict]`):
The list of message dictionaries from OpenAI formatter.
"""
format_data = []
for msg in messages:
msg_dict = msg.to_dict() if isinstance(msg, ChatMessage) else msg
if isinstance(msg_dict.get("content"), list):
for block in msg_dict["content"]:
if (
isinstance(block, dict)
and "input_audio" in block
and isinstance(block["input_audio"].get("data"), str)
):
if not block["input_audio"]["data"].startswith("http"):
try:
msg_copy = copy.deepcopy(msg)
msg_dict = msg_copy.to_dict() if isinstance(msg_copy, ChatMessage) else msg_copy
if isinstance(msg_dict.get("content"), list):
for block in msg_dict["content"]:
if (
isinstance(block, dict)
and "input_audio" in block
and isinstance(block["input_audio"].get("data"), str)
and not block["input_audio"]["data"].startswith("http")
):
block["input_audio"]["data"] = "data:;base64," + block["input_audio"]["data"]
format_data.append(msg_dict)
except Exception as e:
logger.error(f"Failed to format audio data: {type(e).__name__}: {e}", exc_info=True)
format_data.append(msg.to_dict() if isinstance(msg, ChatMessage) else msg)
return format_data


class OpenAIChatModel(BaseChatModel):
Expand Down Expand Up @@ -150,7 +159,7 @@ async def achat(

# Qwen-omni requires different base64 audio format from openai
if "omni" in self.model.lower():
_format_audio_data_for_qwen_omni(messages)
messages = _format_audio_data_for_qwen_omni(messages)

kwargs = {
"model": self.model,
Expand Down Expand Up @@ -187,9 +196,14 @@ async def achat(
kwargs.pop("tool_choice", None)

if "qwen" in self.model:
structured_model = {"type": "json_object"} # type: ignore
logger.warning(
"Qwen models do not support Pydantic structured output via `response_format`. "
"Update the unstructured JSON mode with `response_format={'type': 'json_object'}`."
)
kwargs["response_format"] = {"type": "json_object"}
else:
kwargs["response_format"] = structured_model

kwargs["response_format"] = structured_model
if not self.stream:
response = await self.client.chat.completions.parse(**kwargs)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_openai_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def test_qwen_omni_audio_formatting(self):
]

# Apply the transformation
_format_audio_data_for_qwen_omni(messages)
messages = _format_audio_data_for_qwen_omni(messages)

# Check that the data was formatted correctly
assert messages[0]["content"][0]["input_audio"]["data"].startswith(
Expand Down