Skip to content

Commit

Permalink
- Remove unnecessary message type check
Browse files Browse the repository at this point in the history
- Rename class to `SourceMatchTermination`
  • Loading branch information
thainduy committed Nov 13, 2024
1 parent d09163b commit f3470bc
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
4 changes: 2 additions & 2 deletions python/packages/autogen-ext/src/autogen_ext/task/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from ._agent_name_termination import AgentNameTermination
from ._source_match_termination import SourceMatchTermination

__all__ = ["AgentNameTermination"]
__all__ = ["SourceMatchTermination"]
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Sequence, List

from autogen_agentchat.base import TerminationCondition, TerminatedException
from autogen_agentchat.messages import StopMessage, AgentMessage, ChatMessage
from autogen_agentchat.messages import StopMessage, AgentMessage


class AgentNameTermination(TerminationCondition):
class SourceMatchTermination(TerminationCondition):
"""Terminate the conversation after a specific agent responds.
Args:
Expand All @@ -29,9 +29,8 @@ async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None
return None
last_message = messages[-1]
if last_message.source in self._agents:
if isinstance(last_message, ChatMessage):
self._terminated = True
return StopMessage(content=f"Agent '{last_message.source}' answered", source="AgentNameTermination")
self._terminated = True
return StopMessage(content=f"Agent '{last_message.source}' answered", source="SourceMatchTermination")
return None

async def reset(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from autogen_agentchat.base import TerminatedException
from autogen_agentchat.messages import TextMessage, StopMessage
from autogen_ext.task import AgentNameTermination
from autogen_ext.task import SourceMatchTermination


@pytest.mark.asyncio
async def test_agent_name_termination() -> None:
termination = AgentNameTermination(agents=["Assistant"])
termination = SourceMatchTermination(agents=["Assistant"])
assert await termination([]) is None

continue_messages = [
Expand Down

0 comments on commit f3470bc

Please sign in to comment.