From 813157dd75cc95275c51d90bc6cfb7382d88ccc2 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 16 Dec 2024 15:49:48 +0100 Subject: [PATCH] feat: make `Tool.from_function` more configurable (#155) * feat: make Tool.from_function more configurable * better test --- haystack_experimental/dataclasses/tool.py | 32 +++++++++-------- test/dataclasses/test_tool.py | 44 +++++++++++++++++++---- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index 4f28f184..33719524 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -4,7 +4,7 @@ import inspect from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional from haystack.lazy_imports import LazyImport from haystack.utils import deserialize_callable, serialize_callable @@ -106,7 +106,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "Tool": return cls(**data) @classmethod - def from_function(cls, function: Callable, docstring_as_desc: bool = True) -> "Tool": + def from_function(cls, function: Callable, name: Optional[str] = None, description: Optional[str] = None) -> "Tool": """ Create a Tool instance from a function. @@ -144,8 +144,11 @@ def get_weather( The function to be converted into a Tool. The function must include type hints for all parameters. If a parameter is annotated using `typing.Annotated`, its metadata will be used as parameter description. - :param docstring_as_desc: - Whether to use the function's docstring as the tool description. + :param name: + The name of the tool. If not provided, the name of the function will be used. + :param description: + The description of the tool. If not provided, the docstring of the function will be used. + To intentionally leave the description empty, pass an empty string. :returns: The Tool created from the function. @@ -155,9 +158,8 @@ def get_weather( :raises SchemaGenerationError: If there is an error generating the JSON schema for the Tool. """ - tool_description = "" - if docstring_as_desc and function.__doc__: - tool_description = function.__doc__ + + tool_description = description if description is not None else (function.__doc__ or "") signature = inspect.signature(function) @@ -165,17 +167,17 @@ def get_weather( fields: Dict[str, Any] = {} descriptions = {} - for name, param in signature.parameters.items(): + for param_name, param in signature.parameters.items(): if param.annotation is param.empty: - raise ValueError(f"Function '{function.__name__}': parameter '{name}' does not have a type hint.") + raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.") # if the parameter has not a default value, Pydantic requires an Ellipsis (...) # to explicitly indicate that the parameter is required default = param.default if param.default is not param.empty else ... - fields[name] = (param.annotation, default) + fields[param_name] = (param.annotation, default) if hasattr(param.annotation, "__metadata__"): - descriptions[name] = param.annotation.__metadata__[0] + descriptions[param_name] = param.annotation.__metadata__[0] # create Pydantic model and generate JSON schema try: @@ -190,11 +192,11 @@ def get_weather( _remove_title_from_schema(schema) # add parameters descriptions to the schema - for name, description in descriptions.items(): - if name in schema["properties"]: - schema["properties"][name]["description"] = description + for param_name, param_description in descriptions.items(): + if param_name in schema["properties"]: + schema["properties"][param_name]["description"] = param_description - return Tool(name=function.__name__, description=tool_description, parameters=schema, function=function) + return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) def _remove_title_from_schema(schema: Dict[str, Any]): diff --git a/test/dataclasses/test_tool.py b/test/dataclasses/test_tool.py index 727c5963..da97df79 100644 --- a/test/dataclasses/test_tool.py +++ b/test/dataclasses/test_tool.py @@ -26,6 +26,10 @@ def get_weather_report(city: str) -> str: parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} +def function_with_docstring(city: str) -> str: + """Get weather report for a city.""" + return f"Weather report for {city}: 20°C, sunny" + class TestTool: def test_init(self): @@ -96,14 +100,9 @@ def test_from_dict(self): assert tool.parameters == parameters assert tool.function == get_weather_report - def test_from_function_docstring_as_desc(self): - def function_with_docstring(city: str) -> str: - """Get weather report for a city.""" - return f"Weather report for {city}: 20°C, sunny" - + def test_from_function_description_from_docstring(self): tool = Tool.from_function( function=function_with_docstring, - docstring_as_desc=True, ) assert tool.name == "function_with_docstring" @@ -115,9 +114,10 @@ def function_with_docstring(city: str) -> str: } assert tool.function == function_with_docstring + def test_from_function_with_empty_description(self): another_tool = Tool.from_function( function=function_with_docstring, - docstring_as_desc=False, + description="", ) assert another_tool.name == "function_with_docstring" @@ -129,6 +129,36 @@ def function_with_docstring(city: str) -> str: } assert another_tool.function == function_with_docstring + def test_from_function_with_custom_description(self): + another_tool = Tool.from_function( + function=function_with_docstring, + description="custom description", + ) + + assert another_tool.name == "function_with_docstring" + assert another_tool.description == "custom description" + assert another_tool.parameters == { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + assert another_tool.function == function_with_docstring + + def test_from_function_with_custom_name(self): + tool = Tool.from_function( + function=function_with_docstring, + name="custom_name", + ) + + assert tool.name == "custom_name" + assert tool.description == "Get weather report for a city." + assert tool.parameters == { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + assert tool.function == function_with_docstring + def test_from_function_missing_type_hint(self): def function_missing_type_hint(city) -> str: return f"Weather report for {city}: 20°C, sunny"