Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(wip) Easier way to add IDs to custom tool schema body #2534

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/danswer/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"

# defaults to False
Expand Down
1 change: 1 addition & 0 deletions backend/danswer/tools/custom/custom_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def build_custom_tools_from_openapi_schema_and_headers(

url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)

return [
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
]
Expand Down
3 changes: 3 additions & 0 deletions backend/danswer/tools/custom/openapi_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class MethodSpec(BaseModel):
summary: str
path: str
method: str
body_schema: dict[str, Any] = {}
spec: dict[str, Any]

def get_request_body_schema(self) -> dict[str, Any]:
Expand Down Expand Up @@ -87,6 +88,8 @@ def to_tool_definition(self) -> dict[str, Any]:
tool_definition["function"]["parameters"]["properties"].update(
{param["name"]: param["schema"] for param in path_param_schemas}
)
print(tool_definition)
print("")
return tool_definition

def validate_spec(self) -> None:
Expand Down
219 changes: 219 additions & 0 deletions backend/tests/unit/danswer/tools/custom/test_custom_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import unittest
from unittest.mock import patch

import pytest

from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.custom.custom_tool import validate_openapi_schema
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import ToolResponse


class TestCustomTool(unittest.TestCase):
"""
Test suite for CustomTool functionality.
This class tests the creation, running, and result handling of custom tools
based on OpenAPI schemas.
"""

def setUp(self):
"""
Set up the test environment before each test method.
Initializes an OpenAPI schema and DynamicSchemaInfo for testing.
"""
self.openapi_schema = {
"openapi": "3.0.0",
"info": {
"version": "1.0.0",
"title": "Assistants API",
"description": "An API for managing assistants",
},
"servers": [
{"url": "http://localhost:8080/CHAT_SESSION_ID/test/MESSAGE_ID"},
],
"paths": {
"/assistant/{assistant_id}": {
"GET": {
"summary": "Get a specific Assistant",
"operationId": "getAssistant",
"parameters": [
{
"name": "assistant_id",
"in": "path",
"required": True,
"schema": {"type": "string"},
}
],
},
"POST": {
"summary": "Create a new Assistant",
"operationId": "createAssistant",
"parameters": [
{
"name": "assistant_id",
"in": "path",
"required": True,
"schema": {"type": "string"},
}
],
"requestBody": {
"required": True,
"content": {
"application/json": {"schema": {"type": "object"}}
},
},
},
}
},
}
validate_openapi_schema(self.openapi_schema)
self.dynamic_schema_info = DynamicSchemaInfo(chat_session_id=10, message_id=20)

@patch("danswer.tools.custom.custom_tool.requests.request")
def test_custom_tool_run_get(self, mock_request):
"""
Test the GET method of a custom tool.
Verifies that the tool correctly constructs the URL and makes the GET request.
"""
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema, dynamic_schema_info=self.dynamic_schema_info
)

result = list(tools[0].run(assistant_id="123"))
expected_url = f"http://localhost:8080/{self.dynamic_schema_info.chat_session_id}/test/{self.dynamic_schema_info.message_id}/assistant/123"
mock_request.assert_called_once_with("GET", expected_url, json=None, headers={})

self.assertEqual(
len(result), 1, "Expected exactly one result from the tool run"
)
self.assertEqual(
result[0].id,
CUSTOM_TOOL_RESPONSE_ID,
"Tool response ID does not match expected value",
)
self.assertEqual(
result[0].response.tool_name,
"getAssistant",
"Tool name in response does not match expected value",
)

@patch("danswer.tools.custom.custom_tool.requests.request")
def test_custom_tool_run_post(self, mock_request):
"""
Test the POST method of a custom tool.
Verifies that the tool correctly constructs the URL and makes the POST request with the given body.
"""
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema, dynamic_schema_info=self.dynamic_schema_info
)

result = list(tools[1].run(assistant_id="456"))
expected_url = f"http://localhost:8080/{self.dynamic_schema_info.chat_session_id}/test/{self.dynamic_schema_info.message_id}/assistant/456"
mock_request.assert_called_once_with(
"POST", expected_url, json=None, headers={}
)

self.assertEqual(
len(result), 1, "Expected exactly one result from the tool run"
)
self.assertEqual(
result[0].id,
CUSTOM_TOOL_RESPONSE_ID,
"Tool response ID does not match expected value",
)
self.assertEqual(
result[0].response.tool_name,
"createAssistant",
"Tool name in response does not match expected value",
)

@patch("danswer.tools.custom.custom_tool.requests.request")
def test_custom_tool_with_headers(self, mock_request):
"""
Test the custom tool with custom headers.
Verifies that the tool correctly includes the custom headers in the request.
"""
custom_headers = [
{"key": "Authorization", "value": "Bearer token123"},
{"key": "Custom-Header", "value": "CustomValue"},
]
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema,
custom_headers=custom_headers,
dynamic_schema_info=self.dynamic_schema_info,
)

list(tools[0].run(assistant_id="123"))
expected_url = f"http://localhost:8080/{self.dynamic_schema_info.chat_session_id}/test/{self.dynamic_schema_info.message_id}/assistant/123"
expected_headers = {
"Authorization": "Bearer token123",
"Custom-Header": "CustomValue",
}
mock_request.assert_called_once_with(
"GET", expected_url, json=None, headers=expected_headers
)

@patch("danswer.tools.custom.custom_tool.requests.request")
def test_custom_tool_with_empty_headers(self, mock_request):
"""
Test the custom tool with an empty list of custom headers.
Verifies that the tool correctly handles an empty list of headers.
"""
custom_headers = []
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema,
custom_headers=custom_headers,
dynamic_schema_info=self.dynamic_schema_info,
)

list(tools[0].run(assistant_id="123"))
expected_url = f"http://localhost:8080/{self.dynamic_schema_info.chat_session_id}/test/{self.dynamic_schema_info.message_id}/assistant/123"
mock_request.assert_called_once_with("GET", expected_url, json=None, headers={})

def test_invalid_openapi_schema(self):
"""
Test that an invalid OpenAPI schema raises a ValueError.
"""
invalid_schema = {
"openapi": "3.0.0",
"info": {
"version": "1.0.0",
"title": "Invalid API",
},
# Missing required 'paths' key
}

with self.assertRaises(ValueError) as _:
validate_openapi_schema(invalid_schema)

def test_custom_tool_final_result(self):
"""
Test the final_result method of a custom tool.
Verifies that the method correctly extracts and returns the tool result.
"""
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema, dynamic_schema_info=self.dynamic_schema_info
)

mock_response = ToolResponse(
id=CUSTOM_TOOL_RESPONSE_ID,
response=CustomToolCallSummary(
tool_name="getAssistant",
tool_result={"id": "789", "name": "Final Assistant"},
),
)

final_result = tools[0].final_result(mock_response)
self.assertEqual(
final_result,
{"id": "789", "name": "Final Assistant"},
"Final result does not match expected output",
)


if __name__ == "__main__":
pytest.main([__file__])
2 changes: 1 addition & 1 deletion deployment/docker_compose/docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

Expand Down
2 changes: 1 addition & 1 deletion deployment/docker_compose/docker-compose.gpu-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432"
- "5433"
volumes:
- db_volume:/var/lib/postgresql/data

Expand Down
Loading