Skip to content

Commit e5841a7

Browse files
authored
Improve OpenAPITool corner cases handling (missing operationId, servers under paths, etc) (#37)
* Improve corner cases handling * Add unit tests for server order resolution and missing operationId
1 parent 6c62db7 commit e5841a7

File tree

5 files changed

+149
-65
lines changed

5 files changed

+149
-65
lines changed

haystack_experimental/components/tools/openapi/_schema_conversion.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77

88
from haystack.lazy_imports import LazyImport
99

10-
from haystack_experimental.components.tools.openapi.types import OpenAPISpecification
10+
from haystack_experimental.components.tools.openapi.types import (
11+
VALID_HTTP_METHODS,
12+
OpenAPISpecification,
13+
path_to_operation_id,
14+
)
1115

1216
with LazyImport("Run 'pip install jsonref'") as jsonref_import:
1317
# pylint: disable=import-error
@@ -96,11 +100,14 @@ def _openapi_to_functions(
96100
f"at least {MIN_REQUIRED_OPENAPI_SPEC_VERSION}."
97101
)
98102
functions: List[Dict[str, Any]] = []
99-
for paths in service_openapi_spec["paths"].values():
100-
for path_spec in paths.values():
101-
function_dict = parse_endpoint_fn(path_spec, parameters_name)
102-
if function_dict:
103-
functions.append(function_dict)
103+
for path, path_value in service_openapi_spec["paths"].items():
104+
for path_key, operation_spec in path_value.items():
105+
if path_key.lower() in VALID_HTTP_METHODS:
106+
if "operationId" not in operation_spec:
107+
operation_spec["operationId"] = path_to_operation_id(path, path_key)
108+
function_dict = parse_endpoint_fn(operation_spec, parameters_name)
109+
if function_dict:
110+
functions.append(function_dict)
104111
return functions
105112

106113

haystack_experimental/components/tools/openapi/types.py

+50-57
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,24 @@
2323
]
2424

2525

26+
def path_to_operation_id(path: str, http_method: str = "get") -> str:
27+
"""
28+
Converts a path to an operationId.
29+
30+
:param path: The path to convert.
31+
:param http_method: The HTTP method to use for the operationId.
32+
:returns: The operationId.
33+
"""
34+
if http_method.lower() not in VALID_HTTP_METHODS:
35+
raise ValueError(f"Invalid HTTP method: {http_method}")
36+
return path.replace("/", "_").lstrip("_").rstrip("_") + "_" + http_method.lower()
37+
38+
2639
class LLMProvider(Enum):
2740
"""
2841
LLM providers supported by `OpenAPITool`.
2942
"""
43+
3044
OPENAI = "openai"
3145
ANTHROPIC = "anthropic"
3246
COHERE = "cohere"
@@ -50,18 +64,18 @@ def from_str(string: str) -> "LLMProvider":
5064
@dataclass
5165
class Operation:
5266
"""
53-
Represents an operation in an OpenAPI specification
54-
55-
See https://spec.openapis.org/oas/latest.html#paths-object for details.
56-
Path objects can contain multiple operations, each with a unique combination of path and method.
57-
58-
:param path: Path of the operation.
59-
:param method: HTTP method of the operation.
60-
:param operation_dict: Operation details from OpenAPI spec
61-
:param spec_dict: The encompassing OpenAPI specification.
62-
:param security_requirements: A list of security requirements for the operation.
63-
:param request_body: Request body details.
64-
:param parameters: Parameters for the operation.
67+
Represents an operation in an OpenAPI specification
68+
69+
See https://spec.openapis.org/oas/latest.html#paths-object for details.
70+
Path objects can contain multiple operations, each with a unique combination of path and method.
71+
72+
:param path: Path of the operation.
73+
:param method: HTTP method of the operation.
74+
:param operation_dict: Operation details from OpenAPI spec
75+
:param spec_dict: The encompassing OpenAPI specification.
76+
:param security_requirements: A list of security requirements for the operation.
77+
:param request_body: Request body details.
78+
:param parameters: Parameters for the operation.
6579
"""
6680

6781
path: str
@@ -105,8 +119,12 @@ def get_server(self, server_index: int = 0) -> str:
105119
:returns: The server URL.
106120
:raises ValueError: If no servers are found in the specification.
107121
"""
108-
servers = self.operation_dict.get("servers", []) or self.spec_dict.get(
109-
"servers", []
122+
# servers can be defined at the operation level, path level, or at the root level
123+
# search for servers in the following order: operation, path, root
124+
servers = (
125+
self.operation_dict.get("servers", [])
126+
or self.spec_dict.get("paths", {}).get(self.path, {}).get("servers", [])
127+
or self.spec_dict.get("servers", [])
110128
)
111129
if not servers:
112130
raise ValueError("No servers found in the provided specification.")
@@ -136,11 +154,7 @@ def __init__(self, spec_dict: Dict[str, Any]):
136154
f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}"
137155
)
138156
# just a crude sanity check, by no means a full validation
139-
if (
140-
"openapi" not in spec_dict
141-
or "paths" not in spec_dict
142-
or "servers" not in spec_dict
143-
):
157+
if "openapi" not in spec_dict or "paths" not in spec_dict:
144158
raise ValueError(
145159
"Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.",
146160
spec_dict,
@@ -201,51 +215,30 @@ def from_url(cls, url: str) -> "OpenAPISpecification":
201215
) from e
202216
return cls.from_str(content)
203217

204-
def find_operation_by_id(
205-
self, op_id: str, method: Optional[str] = None
206-
) -> Operation:
218+
def find_operation_by_id(self, op_id: str) -> Operation:
207219
"""
208220
Find an Operation by operationId.
209221
210222
:param op_id: The operationId of the operation.
211-
:param method: The HTTP method of the operation.
212223
:returns: The matching operation
213224
:raises ValueError: If no operation is found with the given operationId.
214225
"""
215-
for path, path_item in self.spec_dict.get("paths", {}).items():
216-
op: Operation = self.get_operation_item(path, path_item, method)
217-
if op_id in op.operation_dict.get("operationId", ""):
218-
return self.get_operation_item(path, path_item, method)
219-
raise ValueError(
220-
f"No operation found with operationId {op_id}, method {method}"
221-
)
222-
223-
def get_operation_item(
224-
self, path: str, path_item: Dict[str, Any], method: Optional[str] = None
225-
) -> Operation:
226-
"""
227-
Gets a particular Operation item from the OpenAPI specification given the path and method.
228-
229-
:param path: The path of the operation.
230-
:param path_item: The path item from the OpenAPI specification.
231-
:param method: The HTTP method of the operation.
232-
:returns: The operation
233-
"""
234-
if method:
235-
operation_dict = path_item.get(method.lower(), {})
236-
if not operation_dict:
237-
raise ValueError(
238-
f"No operation found for method {method} at path {path}"
239-
)
240-
return Operation(path, method.lower(), operation_dict, self.spec_dict)
241-
if len(path_item) == 1:
242-
method, operation_dict = next(iter(path_item.items()))
243-
return Operation(path, method, operation_dict, self.spec_dict)
244-
if len(path_item) > 1:
245-
raise ValueError(
246-
f"Multiple operations found at path {path}, method parameter is required."
247-
)
248-
raise ValueError(f"No operations found at path {path} and method {method}")
226+
for path, path_value in self.spec_dict.get("paths", {}).items():
227+
operations = {
228+
method: operation_dict
229+
for method, operation_dict in path_value.items()
230+
if method.lower() in VALID_HTTP_METHODS
231+
}
232+
233+
for method, operation_dict in operations.items():
234+
if (
235+
operation_dict.get(
236+
"operationId", path_to_operation_id(path, method)
237+
)
238+
== op_id
239+
):
240+
return Operation(path, method, operation_dict, self.spec_dict)
241+
raise ValueError(f"No operation found with operationId {op_id}")
249242

250243
def get_security_schemes(self) -> Dict[str, Dict[str, Any]]:
251244
"""

test/components/tools/openapi/test_openapi_client_edge_cases.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration
8+
from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient
99
from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec
1010

1111

@@ -26,4 +26,29 @@ def test_missing_operation_id(self, test_files_path):
2626
with pytest.raises(ValueError, match="No operation found with operationId"):
2727
client.invoke(payload)
2828

29-
# TODO: Add more tests for edge cases
29+
def test_missing_operation_id_in_operation(self, test_files_path):
30+
"""
31+
Test that the tool definition is generated correctly when the operationId is missing in the specification.
32+
"""
33+
config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml"),
34+
request_sender=FastAPITestClient(None))
35+
36+
tools = config.get_tools_definitions(),
37+
tool_def = tools[0][0]
38+
assert tool_def["type"] == "function"
39+
assert tool_def["function"]["name"] == "missing-operation-id_get"
40+
41+
def test_servers_order(self, test_files_path):
42+
"""
43+
Test that servers defined in different locations in the specification are used correctly.
44+
"""
45+
46+
config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml"),
47+
request_sender=FastAPITestClient(None))
48+
49+
op = config.openapi_spec.find_operation_by_id("servers-order-path")
50+
assert op.get_server() == "https://inpath.example.com"
51+
op = config.openapi_spec.find_operation_by_id("servers-order-operation")
52+
assert op.get_server() == "https://inoperation.example.com"
53+
op = config.openapi_spec.find_operation_by_id("missing-operation-id_get")
54+
assert op.get_server() == "http://localhost"

test/components/tools/openapi/test_openapi_tool.py

+21
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,24 @@ def test_run_live_cohere(self):
201201
assert isinstance(json_response, dict)
202202
except json.JSONDecodeError:
203203
pytest.fail("Response content is not valid JSON")
204+
205+
@pytest.mark.integration
206+
@pytest.mark.parametrize("provider", ["openai", "anthropic", "cohere"])
207+
def test_run_live_meteo_forecast(self, provider: str):
208+
tool = OpenAPITool(
209+
generator_api=LLMProvider.from_str(provider),
210+
spec="https://raw.githubusercontent.com/open-meteo/open-meteo/main/openapi.yml"
211+
)
212+
results = tool.run(messages=[ChatMessage.from_user(
213+
"weather forecast for latitude 52.52 and longitude 13.41 and set hourly=temperature_2m")])
214+
215+
assert isinstance(results["service_response"], list)
216+
assert len(results["service_response"]) == 1
217+
assert isinstance(results["service_response"][0], ChatMessage)
218+
219+
try:
220+
json_response = json.loads(results["service_response"][0].content)
221+
assert isinstance(json_response, dict)
222+
assert "hourly" in json_response
223+
except json.JSONDecodeError:
224+
pytest.fail("Response content is not valid JSON")

test/test_files/yaml/openapi_edge_cases.yml

+38
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,44 @@ paths:
88
/missing-operation-id:
99
get:
1010
summary: Missing operationId
11+
parameters:
12+
- name: name
13+
in: path
14+
required: true
15+
schema:
16+
type: string
1117
responses:
1218
'200':
1319
description: OK
20+
21+
/servers-order-in-path:
22+
servers:
23+
- url: https://inpath.example.com
24+
get:
25+
summary: Servers order
26+
operationId: servers-order-path
27+
parameters:
28+
- name: name
29+
in: path
30+
required: true
31+
schema:
32+
type: string
33+
responses:
34+
'200':
35+
description: OK
36+
37+
/servers-order-in-operation:
38+
get:
39+
summary: Servers order
40+
operationId: servers-order-operation
41+
parameters:
42+
- name: name
43+
in: path
44+
required: true
45+
schema:
46+
type: string
47+
responses:
48+
'200':
49+
description: OK
50+
servers:
51+
- url: https://inoperation.example.com

0 commit comments

Comments
 (0)