Skip to content

Commit 8c25461

Browse files
author
Jean-Marc Le Roux
committed
Add the add_tools(), remove_tools() and remove_all_tools() methods for AssistantAgent
1 parent 8b05e03 commit 8c25461

File tree

1 file changed

+42
-19
lines changed

1 file changed

+42
-19
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py

+42-19
Original file line numberDiff line numberDiff line change
@@ -182,24 +182,7 @@ def __init__(
182182
else:
183183
self._system_messages = [SystemMessage(content=system_message)]
184184
self._tools: List[Tool] = []
185-
if tools is not None:
186-
if model_client.capabilities["function_calling"] is False:
187-
raise ValueError("The model does not support function calling.")
188-
for tool in tools:
189-
if isinstance(tool, Tool):
190-
self._tools.append(tool)
191-
elif callable(tool):
192-
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
193-
description = tool.__doc__
194-
else:
195-
description = ""
196-
self._tools.append(FunctionTool(tool, description=description))
197-
else:
198-
raise ValueError(f"Unsupported tool type: {type(tool)}")
199-
# Check if tool names are unique.
200-
tool_names = [tool.name for tool in self._tools]
201-
if len(tool_names) != len(set(tool_names)):
202-
raise ValueError(f"Tool names must be unique: {tool_names}")
185+
self._model_context: List[LLMMessage] = []
203186
# Handoff tools.
204187
self._handoff_tools: List[Tool] = []
205188
self._handoffs: Dict[str, HandoffBase] = {}
@@ -214,6 +197,27 @@ def __init__(
214197
self._handoffs[handoff.name] = handoff
215198
else:
216199
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
200+
if tools is not None:
201+
self.add_tools(tools)
202+
203+
def add_tools(self, tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]) -> None:
204+
if self._model_client.capabilities["function_calling"] is False:
205+
raise ValueError("The model does not support function calling.")
206+
for tool in tools:
207+
if isinstance(tool, Tool):
208+
self._tools.append(tool)
209+
elif callable(tool):
210+
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
211+
description = tool.__doc__
212+
else:
213+
description = ""
214+
self._tools.append(FunctionTool(tool, description=description))
215+
else:
216+
raise ValueError(f"Unsupported tool type: {type(tool)}")
217+
# Check if tool names are unique.
218+
tool_names = [tool.name for tool in self._tools]
219+
if len(tool_names) != len(set(tool_names)):
220+
raise ValueError(f"Tool names must be unique: {tool_names}")
217221
# Check if handoff tool names are unique.
218222
handoff_tool_names = [tool.name for tool in self._handoff_tools]
219223
if len(handoff_tool_names) != len(set(handoff_tool_names)):
@@ -223,7 +227,26 @@ def __init__(
223227
raise ValueError(
224228
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
225229
)
226-
self._model_context: List[LLMMessage] = []
230+
231+
def remove_all_tools(self) -> None:
232+
"""Remove all tools."""
233+
self._tools = []
234+
235+
def remove_tools(self, tool_names: List[str]) -> None:
236+
"""Remove tools by name."""
237+
for name in tool_names:
238+
for tool in self._tools:
239+
if tool.name == name:
240+
self._tools.remove(tool)
241+
break
242+
for tool in self._handoff_tools:
243+
if tool.name == name:
244+
self._handoff_tools.remove(tool)
245+
break
246+
for handoff in self._handoffs.values():
247+
if handoff.name == name:
248+
self._handoffs.pop(handoff.name)
249+
break
227250

228251
@property
229252
def produced_message_types(self) -> List[type[ChatMessage]]:

0 commit comments

Comments
 (0)