Skip to content

Commit

Permalink
feat: make Tool.from_function more configurable (#155)
Browse files Browse the repository at this point in the history
* feat: make Tool.from_function more configurable

* better test
  • Loading branch information
anakin87 authored Dec 16, 2024
1 parent 46d3795 commit 813157d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 22 deletions.
32 changes: 17 additions & 15 deletions haystack_experimental/dataclasses/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import inspect
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional

from haystack.lazy_imports import LazyImport
from haystack.utils import deserialize_callable, serialize_callable
Expand Down Expand Up @@ -106,7 +106,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "Tool":
return cls(**data)

@classmethod
def from_function(cls, function: Callable, docstring_as_desc: bool = True) -> "Tool":
def from_function(cls, function: Callable, name: Optional[str] = None, description: Optional[str] = None) -> "Tool":
"""
Create a Tool instance from a function.
Expand Down Expand Up @@ -144,8 +144,11 @@ def get_weather(
The function to be converted into a Tool.
The function must include type hints for all parameters.
If a parameter is annotated using `typing.Annotated`, its metadata will be used as parameter description.
:param docstring_as_desc:
Whether to use the function's docstring as the tool description.
:param name:
The name of the tool. If not provided, the name of the function will be used.
:param description:
The description of the tool. If not provided, the docstring of the function will be used.
To intentionally leave the description empty, pass an empty string.
:returns:
The Tool created from the function.
Expand All @@ -155,27 +158,26 @@ def get_weather(
:raises SchemaGenerationError:
If there is an error generating the JSON schema for the Tool.
"""
tool_description = ""
if docstring_as_desc and function.__doc__:
tool_description = function.__doc__

tool_description = description if description is not None else (function.__doc__ or "")

signature = inspect.signature(function)

# collect fields (types and defaults) and descriptions from function parameters
fields: Dict[str, Any] = {}
descriptions = {}

for name, param in signature.parameters.items():
for param_name, param in signature.parameters.items():
if param.annotation is param.empty:
raise ValueError(f"Function '{function.__name__}': parameter '{name}' does not have a type hint.")
raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.")

# if the parameter has not a default value, Pydantic requires an Ellipsis (...)
# to explicitly indicate that the parameter is required
default = param.default if param.default is not param.empty else ...
fields[name] = (param.annotation, default)
fields[param_name] = (param.annotation, default)

if hasattr(param.annotation, "__metadata__"):
descriptions[name] = param.annotation.__metadata__[0]
descriptions[param_name] = param.annotation.__metadata__[0]

# create Pydantic model and generate JSON schema
try:
Expand All @@ -190,11 +192,11 @@ def get_weather(
_remove_title_from_schema(schema)

# add parameters descriptions to the schema
for name, description in descriptions.items():
if name in schema["properties"]:
schema["properties"][name]["description"] = description
for param_name, param_description in descriptions.items():
if param_name in schema["properties"]:
schema["properties"][param_name]["description"] = param_description

return Tool(name=function.__name__, description=tool_description, parameters=schema, function=function)
return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function)


def _remove_title_from_schema(schema: Dict[str, Any]):
Expand Down
44 changes: 37 additions & 7 deletions test/dataclasses/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def get_weather_report(city: str) -> str:

parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}

def function_with_docstring(city: str) -> str:
"""Get weather report for a city."""
return f"Weather report for {city}: 20°C, sunny"


class TestTool:
def test_init(self):
Expand Down Expand Up @@ -96,14 +100,9 @@ def test_from_dict(self):
assert tool.parameters == parameters
assert tool.function == get_weather_report

def test_from_function_docstring_as_desc(self):
def function_with_docstring(city: str) -> str:
"""Get weather report for a city."""
return f"Weather report for {city}: 20°C, sunny"

def test_from_function_description_from_docstring(self):
tool = Tool.from_function(
function=function_with_docstring,
docstring_as_desc=True,
)

assert tool.name == "function_with_docstring"
Expand All @@ -115,9 +114,10 @@ def function_with_docstring(city: str) -> str:
}
assert tool.function == function_with_docstring

def test_from_function_with_empty_description(self):
another_tool = Tool.from_function(
function=function_with_docstring,
docstring_as_desc=False,
description="",
)

assert another_tool.name == "function_with_docstring"
Expand All @@ -129,6 +129,36 @@ def function_with_docstring(city: str) -> str:
}
assert another_tool.function == function_with_docstring

def test_from_function_with_custom_description(self):
another_tool = Tool.from_function(
function=function_with_docstring,
description="custom description",
)

assert another_tool.name == "function_with_docstring"
assert another_tool.description == "custom description"
assert another_tool.parameters == {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
}
assert another_tool.function == function_with_docstring

def test_from_function_with_custom_name(self):
tool = Tool.from_function(
function=function_with_docstring,
name="custom_name",
)

assert tool.name == "custom_name"
assert tool.description == "Get weather report for a city."
assert tool.parameters == {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
}
assert tool.function == function_with_docstring

def test_from_function_missing_type_hint(self):
def function_missing_type_hint(city) -> str:
return f"Weather report for {city}: 20°C, sunny"
Expand Down

0 comments on commit 813157d

Please sign in to comment.