|
1 | 1 | # -*- coding: utf-8 -*- |
| 2 | +# pylint: disable=too-many-lines |
2 | 3 | # mypy: disable-error-code="index" |
3 | 4 | """Test toolkit module in agentscope.""" |
4 | 5 | import asyncio |
@@ -974,6 +975,102 @@ def example_func( |
974 | 975 | "Received: a=1, b=test, c=[1, 2, 3], d=xyz", |
975 | 976 | ) |
976 | 977 |
|
| 978 | + async def test_func_name_parameter(self) -> None: |
| 979 | + """Test func_name parameter for custom tool renaming.""" |
| 980 | + # Test 1: Regular function with func_name |
| 981 | + self.toolkit.register_tool_function( |
| 982 | + sync_func, |
| 983 | + func_name="custom_sync_func", |
| 984 | + ) |
| 985 | + self.assertIn("custom_sync_func", self.toolkit.tools) |
| 986 | + self.assertNotIn("sync_func", self.toolkit.tools) |
| 987 | + |
| 988 | + # Verify the JSON schema uses the custom name |
| 989 | + schemas = self.toolkit.get_json_schemas() |
| 990 | + self.assertEqual(schemas[0]["function"]["name"], "custom_sync_func") |
| 991 | + |
| 992 | + # Verify original_name is set when func_name is provided |
| 993 | + tool_obj = self.toolkit.tools["custom_sync_func"] |
| 994 | + self.assertEqual(tool_obj.original_name, "sync_func") |
| 995 | + |
| 996 | + # Test 2: Regular function without func_name (backward compatibility) |
| 997 | + def another_func(x: int) -> ToolResponse: |
| 998 | + """Another test function.""" |
| 999 | + return ToolResponse(content=[TextBlock(type="text", text=str(x))]) |
| 1000 | + |
| 1001 | + self.toolkit.register_tool_function(another_func) |
| 1002 | + self.assertIn("another_func", self.toolkit.tools) |
| 1003 | + tool_obj = self.toolkit.tools["another_func"] |
| 1004 | + self.assertIsNone(tool_obj.original_name) |
| 1005 | + |
| 1006 | + # Test 3: Partial function with func_name |
| 1007 | + partial_func = partial(sync_func, arg1=10) |
| 1008 | + self.toolkit.register_tool_function( |
| 1009 | + partial_func, |
| 1010 | + func_name="custom_partial_func", |
| 1011 | + ) |
| 1012 | + self.assertIn("custom_partial_func", self.toolkit.tools) |
| 1013 | + tool_obj = self.toolkit.tools["custom_partial_func"] |
| 1014 | + self.assertEqual(tool_obj.original_name, "sync_func") |
| 1015 | + |
| 1016 | + # Test 4: func_name with namesake_strategy="rename" |
| 1017 | + self.toolkit.register_tool_function( |
| 1018 | + sync_func, |
| 1019 | + func_name="custom_sync_func", # Already exists |
| 1020 | + namesake_strategy="rename", |
| 1021 | + ) |
| 1022 | + # Should create a new name with random suffix |
| 1023 | + renamed_tools = [ |
| 1024 | + name |
| 1025 | + for name in self.toolkit.tools |
| 1026 | + if name.startswith("custom_sync_func_") |
| 1027 | + ] |
| 1028 | + self.assertEqual(len(renamed_tools), 1) |
| 1029 | + renamed_name = renamed_tools[0] |
| 1030 | + tool_obj = self.toolkit.tools[renamed_name] |
| 1031 | + # original_name should be "sync_func" (the true original function name) |
| 1032 | + # because original_name records the actual function name, not the |
| 1033 | + # func_name |
| 1034 | + self.assertEqual(tool_obj.original_name, "sync_func") |
| 1035 | + |
| 1036 | + # Test 5: func_name with namesake_strategy="rename" but no func_name |
| 1037 | + # (should use original function name as original_name) |
| 1038 | + def test_func() -> ToolResponse: |
| 1039 | + """Test function.""" |
| 1040 | + return ToolResponse(content=[TextBlock(type="text", text="test")]) |
| 1041 | + |
| 1042 | + self.toolkit.register_tool_function(test_func) |
| 1043 | + self.toolkit.register_tool_function( |
| 1044 | + test_func, |
| 1045 | + namesake_strategy="rename", |
| 1046 | + ) |
| 1047 | + # Find the renamed tool |
| 1048 | + renamed_test_tools = [ |
| 1049 | + name |
| 1050 | + for name in self.toolkit.tools |
| 1051 | + if name.startswith("test_func_") |
| 1052 | + ] |
| 1053 | + self.assertEqual(len(renamed_test_tools), 1) |
| 1054 | + renamed_test_name = renamed_test_tools[0] |
| 1055 | + tool_obj = self.toolkit.tools[renamed_test_name] |
| 1056 | + # original_name should be "test_func" (the original function name) |
| 1057 | + self.assertEqual(tool_obj.original_name, "test_func") |
| 1058 | + |
| 1059 | + # Test 6: Verify tool can be called with custom name |
| 1060 | + res = await self.toolkit.call_tool_function( |
| 1061 | + ToolUseBlock( |
| 1062 | + type="tool_use", |
| 1063 | + id="123", |
| 1064 | + name="custom_sync_func", |
| 1065 | + input={"arg1": 42}, |
| 1066 | + ), |
| 1067 | + ) |
| 1068 | + async for chunk in res: |
| 1069 | + self.assertEqual( |
| 1070 | + chunk.content[0]["text"], |
| 1071 | + "arg1: 42, arg2: None", |
| 1072 | + ) |
| 1073 | + |
977 | 1074 | async def asyncTearDown(self) -> None: |
978 | 1075 | """Clean up after each test.""" |
979 | 1076 | self.toolkit = None |
0 commit comments