Skip to content

Commit

Permalink
Make auto reply method pluggable (#1177)
Browse files Browse the repository at this point in the history
* Make auto reply method pluggable

* allow richer trigger types

* test list
  • Loading branch information
sonichi authored Aug 7, 2023
1 parent 2208dfb commit a603e6d
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 82 deletions.
3 changes: 2 additions & 1 deletion flaml/autogen/agentchat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from .responsive_agent import ResponsiveAgent
from .assistant_agent import AssistantAgent
from .user_proxy_agent import UserProxyAgent
from .groupchat import GroupChatManager
from .groupchat import GroupChat, GroupChatManager

__all__ = [
"Agent",
"ResponsiveAgent",
"AssistantAgent",
"UserProxyAgent",
"GroupChat",
"GroupChatManager",
]
2 changes: 1 addition & 1 deletion flaml/autogen/agentchat/contrib/math_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(
default_auto_reply=default_auto_reply,
**kwargs,
)
self.register_auto_reply(Agent, self._generate_math_reply, 1)
self.register_auto_reply(Agent, MathUserProxyAgent._generate_math_reply, 1)
# fixed var
self._max_invalid_q_per_step = max_invalid_q_per_step

Expand Down
94 changes: 55 additions & 39 deletions flaml/autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,63 @@
from dataclasses import dataclass
import sys
from typing import Dict, List, Optional, Union
from .agent import Agent
from .responsive_agent import ResponsiveAgent


class GroupChatManager(ResponsiveAgent):
"""(WIP) A chat manager agent that can manage a group chat of multiple agents."""
@dataclass
class GroupChat:
"""A group chat class that contains a list of agents and the maximum number of rounds."""

agents: List[Agent]
max_round: int
messages: List[Dict]
max_round: int = 10

def _participant_roles(self):
return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
@property
def agent_names(self) -> List[str]:
"""Return the names of the agents in the group chat."""
return [agent.name for agent in self.agents]

def reset(self):
"""Reset the group chat."""
self.messages.clear()

def agent_by_name(self, name: str) -> Agent:
"""Find the next speaker based on the message."""
return self.agents[self.agent_names.index(name)]

def _select_speaker_msg(self):
def next_agent(self, agent: Agent) -> Agent:
"""Return the next agent in the list."""
return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)]

def select_speaker_msg(self):
"""Return the message for selecting the next speaker."""
return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}. Read the following conversation.
Then select the next role from {self._agent_names} to play. Only return the role."""
Then select the next role from {self.agent_names} to play. Only return the role."""

def select_speaker(self, last_speaker: Agent, selctor: ResponsiveAgent):
"""Select the next speaker."""
selctor.update_system_message(self.select_speaker_msg())
final, name = selctor.generate_oai_reply(self.messages)
if not final:
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
return self.next_agent(last_speaker)
try:
return self.agent_by_name(name)
except ValueError:
return self.next_agent(last_speaker)

def _participant_roles(self):
return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])


class GroupChatManager(ResponsiveAgent):
"""(WIP) A chat manager agent that can manage a group chat of multiple agents."""

def __init__(
self,
max_round: Optional[int] = 10,
groupchat: GroupChat,
name: Optional[str] = "chat_manager",
# unlimited consecutive auto reply by default
max_consecutive_auto_reply: Optional[int] = sys.maxsize,
Expand All @@ -33,56 +70,35 @@ def __init__(
name=name,
max_consecutive_auto_reply=max_consecutive_auto_reply,
human_input_mode=human_input_mode,
system_message=system_message,
**kwargs,
)
self.register_auto_reply(Agent, self._generate_reply_for_participant)
self.max_round = max_round
self._agent_names = []
self._messages = []
self.register_auto_reply(Agent, GroupChatManager.run_chat, context=groupchat, reset_context=GroupChat.reset)
# self._random = random.Random(seed)

def _generate_reply_for_participant(
def run_chat(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
context: Optional[GroupChat] = None,
) -> Union[str, Dict, None]:
self._agent_names = [agent.name for agent in self.agents]
"""Run a group chat."""
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
speaker = sender
for i in range(self.max_round):
for i in range(context.max_round):
# set the name to speaker's name if the role is not function
if message["role"] != "function":
message["name"] = speaker.name
self._messages.append(message)
context.messages.append(message)
# broadcast the message to all agents except the speaker
for agent in self.agents:
for agent in context.agents:
if agent != speaker:
self.send(message, agent, request_reply=False)
if i != self.max_round - 1:
if i != context.max_round - 1:
# speaker selection msg from an agent
speaker = self._select_speaker(speaker)
speaker = context.select_speaker(speaker, self)
speaker.send(speaker.generate_reply(sender=self), self, request_reply=False)
message = self.last_message(speaker)
return True, None

def _select_speaker(self, last_speaker: Agent):
"""Select the next speaker."""
self.update_system_message(self._select_speaker_msg())
final, name = self._generate_oai_reply(self._messages)
if not final:
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
return self.agents[(self._agent_names.index(last_speaker.name) + 1) % len(self._agent_names)]
try:
return self.agent_by_name(name)
except ValueError:
return self.agents[(self._agent_names.index(last_speaker.name) + 1) % len(self._agent_names)]

def agent_by_name(self, name: str) -> Agent:
"""Find the next speaker based on the message."""
return self.agents[self._agent_names.index(name)]

def reset(self):
super().reset()
self._messages.clear()
121 changes: 98 additions & 23 deletions flaml/autogen/agentchat/responsive_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
import copy
import json
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from flaml.autogen import oai
from .agent import Agent
from flaml.autogen.code_utils import DEFAULT_MODEL, UNKNOWN, execute_code, extract_code, infer_lang
Expand Down Expand Up @@ -108,26 +109,64 @@ def __init__(
self._max_consecutive_auto_reply_dict = defaultdict(self.max_consecutive_auto_reply)
self._function_map = {} if function_map is None else function_map
self._default_auto_reply = default_auto_reply
self._class_specific_reply = []
self._reply_func_list = []
self.reply_at_receive = defaultdict(bool)
self.register_auto_reply(Agent, self._generate_oai_reply)
self.register_auto_reply(Agent, self._generate_code_execution_reply)
self.register_auto_reply(Agent, self._generate_function_call_reply)
self.register_auto_reply(Agent, self._check_termination_and_human_reply)
self.register_auto_reply(Agent, ResponsiveAgent.generate_oai_reply)
self.register_auto_reply(Agent, ResponsiveAgent.generate_code_execution_reply)
self.register_auto_reply(Agent, ResponsiveAgent.generate_function_call_reply)
self.register_auto_reply(Agent, ResponsiveAgent.check_termination_and_human_reply)

def register_auto_reply(self, class_type, reply_func: Callable, position: int = 0):
"""Register a class-specific reply function.
def register_auto_reply(
self,
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
reply_func: Callable,
position: Optional[int] = 0,
context: Optional[Any] = None,
reset_context: Optional[Callable] = None,
):
"""Register a reply function.
The class-specific reply function will be called when the sender is an instance of the class_type.
The reply function will be called when the trigger matches the sender.
The function registered later will be checked earlier by default.
To change the order, set the position to a positive integer.
Args:
class_type (Class): the class type.
trigger (Agent class, str, Agent instance, callable, or list): the trigger.
- If a class is provided, the reply function will be called when the sender is an instance of the class.
- If a string is provided, the reply function will be called when the sender's name matches the string.
- If an agent instance is provided, the reply function will be called when the sender is the agent instance.
- If a callable is provided, the reply function will be called when the callable returns True.
- If a list is provided, the reply function will be called when any of the triggers in the list is activated.
reply_func (Callable): the reply function.
The function takes a recipient agent, a list of messages, a sender agent and a context as input and returns a reply message.
```python
def reply_func(
recipient: ResponsiveAgent,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
context: Optional[Any] = None,
) -> Union[str, Dict, None]:
```
position (int): the position of the reply function in the reply function list.
The function registered later will be checked earlier by default.
To change the order, set the position to a positive integer.
context (Any): the context to be passed to the reply function.
When an agent is reset, the context will be reset to the original value.
reset_context (Callable): the function to reset the context.
The function returns None. Signature: ```def reset_context(context: Any)```
"""
self._class_specific_reply.insert(position, (class_type, reply_func))
if not isinstance(trigger, (type, str, Agent, Callable, list)):
raise ValueError("trigger must be a class, a string, an agent, a callable or a list.")
self._reply_func_list.insert(
position,
{
"trigger": trigger,
"reply_func": reply_func,
"context": copy.copy(context),
"init_context": context,
"reset_context": reset_context,
},
)

@property
def system_message(self):
Expand Down Expand Up @@ -362,6 +401,11 @@ def reset(self):
self.clear_history()
self.reset_consecutive_auto_reply_counter()
self.stop_reply_at_receive()
for reply_func_tuple in self._reply_func_list:
if reply_func_tuple["reset_context"] is not None:
reply_func_tuple["reset_context"](reply_func_tuple["context"])
else:
reply_func_tuple["context"] = copy.copy(reply_func_tuple["init_context"])

def stop_reply_at_receive(self, sender: Optional[Agent] = None):
"""Reset the reply_at_receive of the sender."""
Expand All @@ -388,28 +432,34 @@ def clear_history(self, agent: Optional[Agent] = None):
else:
self._oai_messages[agent].clear()

def _generate_oai_reply(
def generate_oai_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
context: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
if self.llm_config is False:
"""Generate a reply using autogen.oai."""
llm_config = self.llm_config if context is None else context
if llm_config is False:
return False, None
if messages is None:
messages = self._oai_messages[sender]

# TODO: #1143 handle token limit exceeded error
response = oai.ChatCompletion.create(
context=messages[-1].pop("context", None), messages=self._oai_system_message + messages, **self.llm_config
context=messages[-1].pop("context", None), messages=self._oai_system_message + messages, **llm_config
)
return True, oai.ChatCompletion.extract_text_or_function_call(response)[0]

def _generate_code_execution_reply(
def generate_code_execution_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
context: Optional[Any] = None,
):
if self._code_execution_config is False:
"""Generate a reply using code execution."""
code_execution_config = context if context is not None else self._code_execution_config
if code_execution_config is False:
return False, None
if messages is None:
messages = self._oai_messages[sender]
Expand All @@ -426,11 +476,15 @@ def _generate_code_execution_reply(
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}"

def _generate_function_call_reply(
def generate_function_call_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
context: Optional[Any] = None,
):
"""Generate a reply using function call."""
if context is None:
context = self
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
Expand All @@ -439,11 +493,15 @@ def _generate_function_call_reply(
return True, func_return
return False, None

def _check_termination_and_human_reply(
def check_termination_and_human_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
context: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
"""Check if the conversation should be terminated, and if human reply is provided."""
if context is None:
context = self
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
Expand Down Expand Up @@ -538,15 +596,32 @@ def generate_reply(
"""
assert messages is not None or sender is not None, "Either messages or sender must be provided."
if sender is not None:
for class_specifc_reply in self._class_specific_reply:
if isinstance(sender, class_specifc_reply[0]) and (
not exclude or class_specifc_reply[1] not in exclude
):
final, reply = class_specifc_reply[1](messages, sender)
for reply_func_tuple in self._reply_func_list:
if exclude and reply_func_tuple["reply_func"] in exclude:
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
final, reply = reply_func_tuple["reply_func"](
self, messages=messages, sender=sender, context=reply_func_tuple["context"]
)
if final:
return reply
return self._default_auto_reply

def _match_trigger(self, trigger, sender):
"""Check if the sender matches the trigger."""
if isinstance(trigger, str):
return trigger == sender.name
elif isinstance(trigger, type):
return isinstance(sender, trigger)
elif isinstance(trigger, Agent):
return trigger == sender
elif isinstance(trigger, Callable):
return trigger(sender)
elif isinstance(trigger, list):
return any(self._match_trigger(t, sender) for t in trigger)
else:
raise ValueError(f"Unsupported trigger type: {type(trigger)}")

def get_human_input(self, prompt: str) -> str:
"""Get human input.
Expand Down
2 changes: 1 addition & 1 deletion flaml/autogen/agentchat/user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class UserProxyAgent(ResponsiveAgent):
UserProxyAgent is a subclass of ResponsiveAgent configured with `human_input_mode` to ALWAYS
and `llm_config` to False. By default, the agent will prompt for human input every time a message is received.
Code execution is enabled by default. LLM-based auto reply is disabled by default.
To modify auto reply, register a method with `register_class_specific_reply`.
To modify auto reply, register a method with (`register_auto_reply`)[responsive_agent#register_auto_reply].
The method should have a similar signature with `_generate_oai_reply` method.
To modify the way to get human input, override `get_human_input` method.
To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`,
Expand Down
Loading

0 comments on commit a603e6d

Please sign in to comment.