Skip to content

Commit 5f3bb68

Browse files
committed
add unittest
1 parent dd375e9 commit 5f3bb68

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

tests/toolkit_basic_test.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
# pylint: disable=too-many-lines
23
# mypy: disable-error-code="index"
34
"""Test toolkit module in agentscope."""
45
import asyncio
@@ -974,6 +975,102 @@ def example_func(
974975
"Received: a=1, b=test, c=[1, 2, 3], d=xyz",
975976
)
976977

978+
async def test_func_name_parameter(self) -> None:
979+
"""Test func_name parameter for custom tool renaming."""
980+
# Test 1: Regular function with func_name
981+
self.toolkit.register_tool_function(
982+
sync_func,
983+
func_name="custom_sync_func",
984+
)
985+
self.assertIn("custom_sync_func", self.toolkit.tools)
986+
self.assertNotIn("sync_func", self.toolkit.tools)
987+
988+
# Verify the JSON schema uses the custom name
989+
schemas = self.toolkit.get_json_schemas()
990+
self.assertEqual(schemas[0]["function"]["name"], "custom_sync_func")
991+
992+
# Verify original_name is set when func_name is provided
993+
tool_obj = self.toolkit.tools["custom_sync_func"]
994+
self.assertEqual(tool_obj.original_name, "sync_func")
995+
996+
# Test 2: Regular function without func_name (backward compatibility)
997+
def another_func(x: int) -> ToolResponse:
998+
"""Another test function."""
999+
return ToolResponse(content=[TextBlock(type="text", text=str(x))])
1000+
1001+
self.toolkit.register_tool_function(another_func)
1002+
self.assertIn("another_func", self.toolkit.tools)
1003+
tool_obj = self.toolkit.tools["another_func"]
1004+
self.assertIsNone(tool_obj.original_name)
1005+
1006+
# Test 3: Partial function with func_name
1007+
partial_func = partial(sync_func, arg1=10)
1008+
self.toolkit.register_tool_function(
1009+
partial_func,
1010+
func_name="custom_partial_func",
1011+
)
1012+
self.assertIn("custom_partial_func", self.toolkit.tools)
1013+
tool_obj = self.toolkit.tools["custom_partial_func"]
1014+
self.assertEqual(tool_obj.original_name, "sync_func")
1015+
1016+
# Test 4: func_name with namesake_strategy="rename"
1017+
self.toolkit.register_tool_function(
1018+
sync_func,
1019+
func_name="custom_sync_func", # Already exists
1020+
namesake_strategy="rename",
1021+
)
1022+
# Should create a new name with random suffix
1023+
renamed_tools = [
1024+
name
1025+
for name in self.toolkit.tools
1026+
if name.startswith("custom_sync_func_")
1027+
]
1028+
self.assertEqual(len(renamed_tools), 1)
1029+
renamed_name = renamed_tools[0]
1030+
tool_obj = self.toolkit.tools[renamed_name]
1031+
# original_name should be "sync_func" (the true original function name)
1032+
# because original_name records the actual function name, not the
1033+
# func_name
1034+
self.assertEqual(tool_obj.original_name, "sync_func")
1035+
1036+
# Test 5: func_name with namesake_strategy="rename" but no func_name
1037+
# (should use original function name as original_name)
1038+
def test_func() -> ToolResponse:
1039+
"""Test function."""
1040+
return ToolResponse(content=[TextBlock(type="text", text="test")])
1041+
1042+
self.toolkit.register_tool_function(test_func)
1043+
self.toolkit.register_tool_function(
1044+
test_func,
1045+
namesake_strategy="rename",
1046+
)
1047+
# Find the renamed tool
1048+
renamed_test_tools = [
1049+
name
1050+
for name in self.toolkit.tools
1051+
if name.startswith("test_func_")
1052+
]
1053+
self.assertEqual(len(renamed_test_tools), 1)
1054+
renamed_test_name = renamed_test_tools[0]
1055+
tool_obj = self.toolkit.tools[renamed_test_name]
1056+
# original_name should be "test_func" (the original function name)
1057+
self.assertEqual(tool_obj.original_name, "test_func")
1058+
1059+
# Test 6: Verify tool can be called with custom name
1060+
res = await self.toolkit.call_tool_function(
1061+
ToolUseBlock(
1062+
type="tool_use",
1063+
id="123",
1064+
name="custom_sync_func",
1065+
input={"arg1": 42},
1066+
),
1067+
)
1068+
async for chunk in res:
1069+
self.assertEqual(
1070+
chunk.content[0]["text"],
1071+
"arg1: 42, arg2: None",
1072+
)
1073+
9771074
async def asyncTearDown(self) -> None:
9781075
"""Clean up after each test."""
9791076
self.toolkit = None

0 commit comments

Comments
 (0)