Skip to content

Commit

Permalink
o1 support for agent chat, and validate model capabilities (#4397)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu authored Nov 27, 2024
1 parent 3058baf commit 52790a8
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
AgentMessage,
ChatMessage,
HandoffMessage,
MultiModalMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
Expand Down Expand Up @@ -113,7 +114,10 @@ class AssistantAgent(BaseChatAgent):
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
# api_key = "your_openai_api_key"
)
agent = AssistantAgent(name="assistant", model_client=model_client)
response = await agent.on_messages(
Expand Down Expand Up @@ -144,7 +148,10 @@ async def get_current_time() -> str:
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
model_client = OpenAIChatCompletionClient(
model="gpt-4o",
# api_key = "your_openai_api_key"
)
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
await Console(
Expand All @@ -156,6 +163,39 @@ async def main() -> None:
asyncio.run(main())
The following example shows how to use `o1-mini` model with the assistant agent.
.. code-block:: python
import asyncio
from autogen_core.base import CancellationToken
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
async def main() -> None:
model_client = OpenAIChatCompletionClient(
model="o1-mini",
# api_key = "your_openai_api_key"
)
# The system message is not supported by the o1 series model.
agent = AssistantAgent(name="assistant", model_client=model_client, system_message=None)
response = await agent.on_messages(
[TextMessage(content="What is the capital of France?", source="user")], CancellationToken()
)
print(response)
asyncio.run(main())
.. note::
The `o1-preview` and `o1-mini` models do not support system message and function calling.
So the `system_message` should be set to `None` and the `tools` and `handoffs` should not be set.
See `o1 beta limitations <https://platform.openai.com/docs/guides/reasoning#beta-limitations>`_ for more details.
"""

def __init__(
Expand All @@ -166,13 +206,19 @@ def __init__(
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[Handoff | str] | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
system_message: str
| None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
):
super().__init__(name=name, description=description)
self._model_client = model_client
self._system_messages = [SystemMessage(content=system_message)]
if system_message is None:
self._system_messages = []
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
if tools is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
self._tools.append(tool)
Expand All @@ -192,6 +238,8 @@ def __init__(
self._handoff_tools: List[Tool] = []
self._handoffs: Dict[str, Handoff] = {}
if handoffs is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling, which is needed for handoffs.")
for handoff in handoffs:
if isinstance(handoff, str):
handoff = Handoff(target=handoff)
Expand Down Expand Up @@ -229,6 +277,8 @@ async def on_messages_stream(
) -> AsyncGenerator[AgentMessage | Response, None]:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
raise ValueError("The model does not support vision.")
self._model_context.append(UserMessage(content=msg.content, source=msg.source))

# Inner messages.
Expand Down
24 changes: 24 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,27 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
assert len(result.messages) == 2


@pytest.mark.asyncio
async def test_invalid_model_capabilities() -> None:
    """Verify AssistantAgent rejects configurations its model cannot support.

    Uses a client whose declared capabilities disable vision, function
    calling, and JSON output, then checks that each unsupported feature
    (tools, handoffs, multi-modal input) raises ``ValueError``.
    """
    model = "random-model"
    model_client = OpenAIChatCompletionClient(
        model=model, api_key="", model_capabilities={"vision": False, "function_calling": False, "json_output": False}
    )

    # Tools require function calling, which this model does not support.
    with pytest.raises(ValueError):
        AssistantAgent(
            name="assistant",
            model_client=model_client,
            tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
        )

    # Handoffs are implemented via function calling as well.
    with pytest.raises(ValueError):
        AssistantAgent(name="assistant", model_client=model_client, handoffs=["agent2"])

    # Multi-modal input requires vision; construction succeeds, but the
    # error surfaces when the multi-modal message is processed by run().
    with pytest.raises(ValueError):
        agent = AssistantAgent(name="assistant", model_client=model_client)
        # Minimal 1x1 PNG encoded as base64, used as a stand-in image.
        img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
        await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# Based on: https://platform.openai.com/docs/models/continuous-model-upgrades
# This is a moving target, so correctness is checked by the model value returned by openai against expected values at runtime
_MODEL_POINTERS = {
"o1-preview": "o1-preview-2024-09-12",
"o1-mini": "o1-mini-2024-09-12",
"gpt-4o": "gpt-4o-2024-08-06",
"gpt-4o-mini": "gpt-4o-mini-2024-07-18",
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
Expand All @@ -16,6 +18,16 @@
}

_MODEL_CAPABILITIES: Dict[str, ModelCapabilities] = {
"o1-preview-2024-09-12": {
"vision": False,
"function_calling": False,
"json_output": False,
},
"o1-mini-2024-09-12": {
"vision": False,
"function_calling": False,
"json_output": False,
},
"gpt-4o-2024-08-06": {
"vision": True,
"function_calling": True,
Expand Down Expand Up @@ -89,6 +101,8 @@
}

_MODEL_TOKEN_LIMITS: Dict[str, int] = {
"o1-preview-2024-09-12": 128000,
"o1-mini-2024-09-12": 128000,
"gpt-4o-2024-08-06": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4o-mini-2024-07-18": 128000,
Expand Down

0 comments on commit 52790a8

Please sign in to comment.