Skip to content

Commit

Permalink
fix: deduplication of openapi params (#2788)
Browse files Browse the repository at this point in the history
Previous behavior would consider two parameters with the same name but declared in different places (eg., header, cookie) as an error.

This PR incorporates the "param_in" value when validating params for openapi spec so that it would only be an error to have multiple different parameters, of the same name declared in the same place.

Closes #2662
  • Loading branch information
peterschutt authored Nov 28, 2023
1 parent cb8afc2 commit 84710a1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
16 changes: 8 additions & 8 deletions litestar/_openapi/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, route_handler: BaseRouteHandler) -> None:
route_handler: Associated route handler
"""
self.route_handler = route_handler
self._parameters: dict[str, Parameter] = {}
self._parameters: dict[tuple[str, str], Parameter] = {}

def add(self, parameter: Parameter) -> None:
"""Add a ``Parameter`` to the collection.
Expand All @@ -50,18 +50,18 @@ def add(self, parameter: Parameter) -> None:
``ImproperlyConfiguredException``.
"""

if parameter.name not in self._parameters:
if (parameter.name, parameter.param_in) not in self._parameters:
# because we are defining routes as unique per path, we have to handle here a situation when there is an optional
# path parameter. e.g. get(path=["/", "/{param:str}"]). When parsing the parameter for path, the route handler
# would still have a kwarg called param:
# def handler(param: str | None) -> ...
if parameter.param_in != ParamType.QUERY or all(
"{" + parameter.name + ":" not in path for path in self.route_handler.paths
f"{{{parameter.name}:" not in path for path in self.route_handler.paths
):
self._parameters[parameter.name] = parameter
self._parameters[(parameter.name, parameter.param_in)] = parameter
return

pre_existing = self._parameters[parameter.name]
pre_existing = self._parameters[(parameter.name, parameter.param_in)]
if parameter == pre_existing:
return

Expand Down Expand Up @@ -206,13 +206,13 @@ def create_parameter_for_handler(
dependency_providers = route_handler.resolve_dependencies()
layered_parameters = route_handler.resolve_layered_parameters()

unique_handler_fields = tuple(
unique_handler_fields = (
(k, v) for k, v in handler_fields.items() if k not in RESERVED_KWARGS and k not in layered_parameters
)
unique_layered_fields = tuple(
unique_layered_fields = (
(k, v) for k, v in layered_parameters.items() if k not in RESERVED_KWARGS and k not in handler_fields
)
intersection_fields = tuple(
intersection_fields = (
(k, v) for k, v in handler_fields.items() if k not in RESERVED_KWARGS and k in layered_parameters
)

Expand Down
35 changes: 33 additions & 2 deletions tests/unit/test_openapi/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Generic, Optional, TypeVar
from typing import Generic, Optional, TypeVar, cast

import msgspec
import pytest
import yaml
from typing_extensions import Annotated

from litestar import Controller, get, post
from litestar import Controller, Litestar, get, post
from litestar.app import DEFAULT_OPENAPI_CONFIG
from litestar.enums import MediaType, OpenAPIMediaType, ParamType
from litestar.openapi import OpenAPIConfig, OpenAPIController
from litestar.openapi.spec import Parameter as OpenAPIParameter
from litestar.params import Parameter
from litestar.serialization.msgspec_hooks import decode_json, encode_json, get_serializer
from litestar.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND
from litestar.testing import create_test_client
Expand Down Expand Up @@ -291,3 +293,32 @@ def handler_foo_int() -> Foo[int]:
}
},
}


def test_allow_multiple_parameters_with_same_name_but_different_location() -> None:
"""Test that we can support params with the same name if they are in different locations, e.g., cookie and header.
https://github.com/litestar-org/litestar/issues/2662
"""

@post("/test")
async def route(
name: Annotated[Optional[str], Parameter(cookie="name")] = None, # noqa: UP007
name_header: Annotated[Optional[str], Parameter(header="name")] = None, # noqa: UP007
) -> str:
return name or name_header or ""

app = Litestar(route_handlers=[route], debug=True)
assert app.openapi_schema.paths is not None
schema = app.openapi_schema
paths = schema.paths
assert paths is not None
path = paths["/test"]
assert path.post is not None
parameters = path.post.parameters
assert parameters is not None
assert len(parameters) == 2
assert all(isinstance(param, OpenAPIParameter) for param in parameters)
params = cast("list[OpenAPIParameter]", parameters)
assert all(param.name == "name" for param in params)
assert tuple(param.param_in for param in params) == ("cookie", "header")

0 comments on commit 84710a1

Please sign in to comment.