From 52790a8de74bedf41e7b5279c02ffdc1c30770ac Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Wed, 27 Nov 2024 10:45:51 -0800 Subject: [PATCH] o1 support for agent chat, and validate model capabilities (#4397) --- .../agents/_assistant_agent.py | 58 +++++++++++++++++-- .../tests/test_assistant_agent.py | 24 ++++++++ .../autogen_ext/models/_openai/_model_info.py | 14 +++++ 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 0870a6c2f3b..1edf86f0061 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -23,6 +23,7 @@ AgentMessage, ChatMessage, HandoffMessage, + MultiModalMessage, TextMessage, ToolCallMessage, ToolCallResultMessage, @@ -113,7 +114,10 @@ class AssistantAgent(BaseChatAgent): async def main() -> None: - model_client = OpenAIChatCompletionClient(model="gpt-4o") + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) agent = AssistantAgent(name="assistant", model_client=model_client) response = await agent.on_messages( @@ -144,7 +148,10 @@ async def get_current_time() -> str: async def main() -> None: - model_client = OpenAIChatCompletionClient(model="gpt-4o") + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time]) await Console( @@ -156,6 +163,39 @@ async def main() -> None: asyncio.run(main()) + + The following example shows how to use `o1-mini` model with the assistant agent. + + .. 
code-block:: python + + import asyncio + from autogen_core.base import CancellationToken + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.messages import TextMessage + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="o1-mini", + # api_key = "your_openai_api_key" + ) + # The system message is not supported by the o1 series model. + agent = AssistantAgent(name="assistant", model_client=model_client, system_message=None) + + response = await agent.on_messages( + [TextMessage(content="What is the capital of France?", source="user")], CancellationToken() + ) + print(response) + + + asyncio.run(main()) + + .. note:: + + The `o1-preview` and `o1-mini` models do not support system message and function calling. + So the `system_message` should be set to `None` and the `tools` and `handoffs` should not be set. + See `o1 beta limitations <https://platform.openai.com/docs/guides/reasoning>`_ for more details. """ def __init__( @@ -166,13 +206,19 @@ def __init__( self, name: str, model_client: ChatCompletionClient, tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None, handoffs: List[Handoff | str] | None = None, description: str = "An agent that provides assistance with ability to use tools.", - system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.", + system_message: str + | None = "You are a helpful AI assistant. Solve tasks using your tools. 
Reply with TERMINATE when the task has been completed.", ): super().__init__(name=name, description=description) self._model_client = model_client - self._system_messages = [SystemMessage(content=system_message)] + if system_message is None: + self._system_messages = [] + else: + self._system_messages = [SystemMessage(content=system_message)] self._tools: List[Tool] = [] if tools is not None: + if model_client.capabilities["function_calling"] is False: + raise ValueError("The model does not support function calling.") for tool in tools: if isinstance(tool, Tool): self._tools.append(tool) @@ -192,6 +238,8 @@ def __init__( self._handoff_tools: List[Tool] = [] self._handoffs: Dict[str, Handoff] = {} if handoffs is not None: + if model_client.capabilities["function_calling"] is False: + raise ValueError("The model does not support function calling, which is needed for handoffs.") for handoff in handoffs: if isinstance(handoff, str): handoff = Handoff(target=handoff) @@ -229,6 +277,8 @@ async def on_messages_stream( ) -> AsyncGenerator[AgentMessage | Response, None]: # Add messages to the model context. for msg in messages: + if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False: + raise ValueError("The model does not support vision.") self._model_context.append(UserMessage(content=msg.content, source=msg.source)) # Inner messages. 
diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 98ee8c3990d..086ea62ae42 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -233,3 +233,27 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)])) assert len(result.messages) == 2 + + +@pytest.mark.asyncio +async def test_invalid_model_capabilities() -> None: + model = "random-model" + model_client = OpenAIChatCompletionClient( + model=model, api_key="", model_capabilities={"vision": False, "function_calling": False, "json_output": False} + ) + + with pytest.raises(ValueError): + agent = AssistantAgent( + name="assistant", + model_client=model_client, + tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + ) + + with pytest.raises(ValueError): + agent = AssistantAgent(name="assistant", model_client=model_client, handoffs=["agent2"]) + + with pytest.raises(ValueError): + agent = AssistantAgent(name="assistant", model_client=model_client) + # Generate a random base64 image. 
+ img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" + await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)])) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/_openai/_model_info.py b/python/packages/autogen-ext/src/autogen_ext/models/_openai/_model_info.py index aea2bfb5d1c..3a837915f42 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/_openai/_model_info.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/_openai/_model_info.py @@ -5,6 +5,8 @@ # Based on: https://platform.openai.com/docs/models/continuous-model-upgrades # This is a moving target, so correctness is checked by the model value returned by openai against expected values at runtime`` _MODEL_POINTERS = { + "o1-preview": "o1-preview-2024-09-12", + "o1-mini": "o1-mini-2024-09-12", "gpt-4o": "gpt-4o-2024-08-06", "gpt-4o-mini": "gpt-4o-mini-2024-07-18", "gpt-4-turbo": "gpt-4-turbo-2024-04-09", @@ -16,6 +18,16 @@ } _MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = { + "o1-preview-2024-09-12": { + "vision": False, + "function_calling": False, + "json_output": False, + }, + "o1-mini-2024-09-12": { + "vision": False, + "function_calling": False, + "json_output": False, + }, "gpt-4o-2024-08-06": { "vision": True, "function_calling": True, @@ -89,6 +101,8 @@ } _MODEL_TOKEN_LIMITS: Dict[str, int] = { + "o1-preview-2024-09-12": 128000, + "o1-mini-2024-09-12": 128000, "gpt-4o-2024-08-06": 128000, "gpt-4o-2024-05-13": 128000, "gpt-4o-mini-2024-07-18": 128000,