Skip to content

Commit 0b97e1f

Browse files
committed
fix InlineAgent & ComprehendAgent & add unit tests
1 parent 1f17341 commit 0b97e1f

File tree

5 files changed

+480
-20
lines changed

5 files changed

+480
-20
lines changed

python/setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = multi_agent_orchestrator
3-
version = 0.1.6
3+
version = 0.1.7
44
author = Anthony Bernabeu, Corneliu Croitoru
55
66
description = Multi-agent orchestrator framework

python/src/multi_agent_orchestrator/agents/bedrock_inline_agent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# BedrockInlineAgentOptions Dataclass
1616
@dataclass
1717
class BedrockInlineAgentOptions(AgentOptions):
18+
model_id: Optional[str] = None
19+
region: Optional[str] = None
1820
inference_config: Optional[Dict[str, Any]] = None
1921
client: Optional[Any] = None
2022
bedrock_agent_client: Optional[Any] = None
@@ -71,6 +73,8 @@ def __init__(self, options: BedrockInlineAgentOptions):
7173
else:
7274
self.client = boto3.client('bedrock-runtime')
7375

76+
self.model_id: str = options.model_id or BEDROCK_MODEL_ID_CLAUDE_3_HAIKU
77+
7478
# Initialize bedrock agent client
7579
if options.bedrock_agent_client:
7680
self.bedrock_agent_client = options.bedrock_agent_client

python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,39 @@
44
from .agent import Agent, AgentOptions
55
import boto3
66
from botocore.config import Config
7+
import os
8+
from dataclasses import dataclass
9+
710

811
# Type alias for CheckFunction
912
CheckFunction = Callable[[str], str]
1013

14+
@dataclass
1115
class ComprehendFilterAgentOptions(AgentOptions):
12-
def __init__(self,
13-
enable_sentiment_check: bool = True,
14-
enable_pii_check: bool = True,
15-
enable_toxicity_check: bool = True,
16-
sentiment_threshold: float = 0.7,
17-
toxicity_threshold: float = 0.7,
18-
allow_pii: bool = False,
19-
language_code: str = 'en',
20-
**kwargs):
21-
super().__init__(**kwargs)
22-
self.enable_sentiment_check = enable_sentiment_check
23-
self.enable_pii_check = enable_pii_check
24-
self.enable_toxicity_check = enable_toxicity_check
25-
self.sentiment_threshold = sentiment_threshold
26-
self.toxicity_threshold = toxicity_threshold
27-
self.allow_pii = allow_pii
28-
self.language_code = language_code
16+
enable_sentiment_check: bool = True
17+
enable_pii_check: bool = True
18+
enable_toxicity_check: bool = True
19+
sentiment_threshold: float = 0.7
20+
toxicity_threshold: float = 0.7
21+
allow_pii: bool = False
22+
language_code: str = 'en'
23+
region: Optional[str] = None
24+
client: Optional[Any] = None
2925

3026
class ComprehendFilterAgent(Agent):
3127
def __init__(self, options: ComprehendFilterAgentOptions):
3228
super().__init__(options)
3329

34-
config = Config(region_name=options.region) if options.region else None
35-
self.comprehend_client = boto3.client('comprehend', config=config)
30+
if options.client:
31+
self.comprehend_client = options.client
32+
else:
33+
if options.region:
34+
self.client = boto3.client(
35+
'comprehend',
36+
region_name=options.region or os.environ.get('AWS_REGION')
37+
)
38+
else:
39+
self.client = boto3.client('comprehend')
3640

3741
self.custom_checks: List[CheckFunction] = []
3842

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import unittest
2+
from unittest.mock import Mock
3+
import json
4+
from typing import Dict, Any
5+
6+
from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole
7+
from multi_agent_orchestrator.agents import BedrockInlineAgent, BedrockInlineAgentOptions
8+
9+
class TestBedrockInlineAgent(unittest.IsolatedAsyncioTestCase):
10+
async def asyncSetUp(self):
11+
# Mock clients
12+
self.mock_bedrock_client = Mock()
13+
self.mock_bedrock_agent_client = Mock()
14+
15+
# Sample action groups and knowledge bases
16+
self.action_groups = [
17+
{
18+
"actionGroupName": "TestActionGroup1",
19+
"description": "Test action group 1 description"
20+
},
21+
{
22+
"actionGroupName": "TestActionGroup2",
23+
"description": "Test action group 2 description"
24+
}
25+
]
26+
27+
self.knowledge_bases = [
28+
{
29+
"knowledgeBaseId": "kb1",
30+
"description": "Test knowledge base 1"
31+
},
32+
{
33+
"knowledgeBaseId": "kb2",
34+
"description": "Test knowledge base 2"
35+
}
36+
]
37+
38+
# Create agent instance
39+
self.agent = BedrockInlineAgent(
40+
BedrockInlineAgentOptions(
41+
name="Test Agent",
42+
description="Test agent description",
43+
client=self.mock_bedrock_client,
44+
bedrock_agent_client=self.mock_bedrock_agent_client,
45+
action_groups_list=self.action_groups,
46+
knowledge_bases=self.knowledge_bases
47+
)
48+
)
49+
50+
async def test_initialization(self):
51+
"""Test agent initialization and configuration"""
52+
self.assertEqual(self.agent.name, "Test Agent")
53+
self.assertEqual(self.agent.description, "Test agent description")
54+
self.assertEqual(len(self.agent.action_groups_list), 2)
55+
self.assertEqual(len(self.agent.knowledge_bases), 2)
56+
self.assertEqual(self.agent.tool_config['toolMaxRecursions'], 1)
57+
58+
async def test_process_request_without_tool_use(self):
59+
"""Test processing a request that doesn't require tool use"""
60+
# Mock the converse response
61+
mock_response = {
62+
'output': {
63+
'message': {
64+
'role': 'assistant',
65+
'content': [{'text': 'Test response'}]
66+
}
67+
}
68+
}
69+
self.mock_bedrock_client.converse.return_value = mock_response
70+
71+
# Test input
72+
input_text = "Hello"
73+
chat_history = []
74+
75+
# Process request
76+
response = await self.agent.process_request(
77+
input_text=input_text,
78+
user_id='test_user',
79+
session_id='test_session',
80+
chat_history=chat_history
81+
)
82+
83+
# Verify response
84+
self.assertIsInstance(response, ConversationMessage)
85+
self.assertEqual(response.role, ParticipantRole.ASSISTANT.value)
86+
self.assertEqual(response.content[0]['text'], 'Test response')
87+
88+
async def test_process_request_with_tool_use(self):
89+
"""Test processing a request that requires tool use"""
90+
# Mock the converse response with tool use
91+
tool_use_response = {
92+
'output': {
93+
'message': {
94+
'role': 'assistant',
95+
'content': [{
96+
'toolUse': {
97+
'name': 'inline_agent_creation',
98+
'input': {
99+
'action_group_names': ['TestActionGroup1'],
100+
'knowledge_bases': ['kb1'],
101+
'description': 'Test description',
102+
'user_request': 'Test request'
103+
}
104+
}
105+
}]
106+
}
107+
}
108+
}
109+
self.mock_bedrock_client.converse.return_value = tool_use_response
110+
111+
# Mock the inline agent response
112+
mock_completion = {
113+
'chunk': {
114+
'bytes': b'Inline agent response'
115+
}
116+
}
117+
self.mock_bedrock_agent_client.invoke_inline_agent.return_value = {
118+
'completion': [mock_completion]
119+
}
120+
121+
# Test input
122+
input_text = "Use inline agent"
123+
chat_history = []
124+
125+
# Process request
126+
response = await self.agent.process_request(
127+
input_text=input_text,
128+
user_id='test_user',
129+
session_id='test_session',
130+
chat_history=chat_history
131+
)
132+
133+
# Verify response
134+
self.assertIsInstance(response, ConversationMessage)
135+
self.assertEqual(response.role, ParticipantRole.ASSISTANT.value)
136+
self.assertEqual(response.content[0]['text'], 'Inline agent response')
137+
138+
# Verify inline agent was called with correct parameters
139+
self.mock_bedrock_agent_client.invoke_inline_agent.assert_called_once()
140+
call_kwargs = self.mock_bedrock_agent_client.invoke_inline_agent.call_args[1]
141+
self.assertEqual(len(call_kwargs['actionGroups']), 1)
142+
self.assertEqual(len(call_kwargs['knowledgeBases']), 1)
143+
self.assertEqual(call_kwargs['inputText'], 'Test request')
144+
145+
async def test_error_handling(self):
146+
"""Test error handling in process_request"""
147+
# Mock the converse method to raise an exception
148+
self.mock_bedrock_client.converse.side_effect = Exception("Test error")
149+
150+
# Test input
151+
input_text = "Hello"
152+
chat_history = []
153+
154+
# Verify exception is raised
155+
with self.assertRaises(Exception) as context:
156+
await self.agent.process_request(
157+
input_text=input_text,
158+
user_id='test_user',
159+
session_id='test_session',
160+
chat_history=chat_history
161+
)
162+
163+
self.assertTrue("Test error" in str(context.exception))
164+
165+
async def test_system_prompt_formatting(self):
166+
"""Test system prompt formatting and template replacement"""
167+
# Test with custom variables
168+
test_variables = {
169+
'test_var': 'test_value'
170+
}
171+
self.agent.set_system_prompt(
172+
template="Test template with {{test_var}}",
173+
variables=test_variables
174+
)
175+
176+
self.assertEqual(self.agent.system_prompt, "Test template with test_value")
177+
178+
async def test_inline_agent_tool_handler(self):
179+
"""Test the inline agent tool handler"""
180+
# Mock response content
181+
response = ConversationMessage(
182+
role=ParticipantRole.ASSISTANT.value,
183+
content=[{
184+
'toolUse': {
185+
'name': 'inline_agent_creation',
186+
'input': {
187+
'action_group_names': ['TestActionGroup1'],
188+
'knowledge_bases': ['kb1'],
189+
'description': 'Test description',
190+
'user_request': 'Test request'
191+
}
192+
}
193+
}]
194+
)
195+
196+
# Mock inline agent response
197+
mock_completion = {
198+
'chunk': {
199+
'bytes': b'Handler test response'
200+
}
201+
}
202+
self.mock_bedrock_agent_client.invoke_inline_agent.return_value = {
203+
'completion': [mock_completion]
204+
}
205+
206+
# Call handler
207+
result = await self.agent.inline_agent_tool_handler(
208+
session_id='test_session',
209+
response=response,
210+
conversation=[]
211+
)
212+
213+
# Verify result
214+
self.assertIsInstance(result, ConversationMessage)
215+
self.assertEqual(result.content[0]['text'], 'Handler test response')
216+
217+
async def test_custom_prompt_template(self):
218+
"""Test custom prompt template setup"""
219+
custom_template = "Custom template {{test_var}}"
220+
custom_variables = {"test_var": "test_value"}
221+
222+
self.agent.set_system_prompt(
223+
template=custom_template,
224+
variables=custom_variables
225+
)
226+
227+
self.assertEqual(self.agent.prompt_template, custom_template)
228+
self.assertEqual(self.agent.custom_variables, custom_variables)
229+
self.assertEqual(self.agent.system_prompt, "Custom template test_value")
230+
231+
if __name__ == '__main__':
232+
unittest.main()

0 commit comments

Comments
 (0)