Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Enum OAS generation (#3518) #3525

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 39 additions & 23 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import copy
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum, EnumMeta
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -41,9 +41,7 @@
)
from litestar._openapi.schema_generation.utils import (
_get_normalized_schema_key,
_should_create_enum_schema,
_should_create_literal_schema,
_type_or_first_not_none_inner_type,
get_json_schema_formatted_examples,
)
from litestar.datastructures import SecretBytes, SecretString, UploadFile
Expand Down Expand Up @@ -181,22 +179,6 @@
return name


def create_enum_schema(annotation: EnumMeta, include_null: bool = False) -> Schema:
"""Create a schema instance for an enum.

Args:
annotation: An enum.
include_null: Whether to include null as a possible value.

Returns:
A schema instance.
"""
enum_values: list[str | int | None] = [v.value for v in annotation] # type: ignore[var-annotated]
if include_null and None not in enum_values:
enum_values.append(None)
return Schema(type=_types_in_list(enum_values), enum=enum_values)


def _iter_flat_literal_args(annotation: Any) -> Iterable[Any]:
"""Iterate over the flattened arguments of a Literal.

Expand Down Expand Up @@ -327,18 +309,20 @@

if plugin_for_annotation := self.get_plugin_for(field_definition):
result = self.for_plugin(field_definition, plugin_for_annotation)
elif _should_create_enum_schema(field_definition):
annotation = _type_or_first_not_none_inner_type(field_definition)
result = create_enum_schema(annotation, include_null=field_definition.is_optional)
elif _should_create_literal_schema(field_definition):
annotation = (
make_non_optional_union(field_definition.annotation)
if field_definition.is_optional
else field_definition.annotation
)
result = create_literal_schema(annotation, include_null=field_definition.is_optional)
result = create_literal_schema(
annotation,
include_null=field_definition.is_optional,
)
elif field_definition.is_optional:
result = self.for_optional_field(field_definition)
elif field_definition.is_enum:
result = self.for_enum_field(field_definition)
elif field_definition.is_union:
result = self.for_union_field(field_definition)
elif field_definition.is_type_var:
Expand Down Expand Up @@ -550,6 +534,38 @@
)
return schema

def for_enum_field(
self,
field_definition: FieldDefinition,
) -> Schema | Reference:
"""Create a schema instance for an enum.

Args:
field_definition: A signature field instance.

Returns:
A schema or reference instance.
"""
enum_type: None | OpenAPIType | list[OpenAPIType] = None
if issubclass(field_definition.annotation, str): # StrEnum
enum_type = OpenAPIType.STRING
elif issubclass(field_definition.annotation, int): # IntEnum
enum_type = OpenAPIType.INTEGER

Check warning on line 553 in litestar/_openapi/schema_generation/schema.py

View check run for this annotation

Codecov / codecov/patch

litestar/_openapi/schema_generation/schema.py#L553

Added line #L553 was not covered by tests

enum_values: list[Any] = [v.value for v in field_definition.annotation] # pyright: ignore
if enum_type is None:
enum_type = _types_in_list(enum_values)

key = _get_normalized_schema_key(field_definition.annotation)

schema = self.schema_registry.get_schema_for_key(key)
schema.type = enum_type
schema.enum = enum_values
schema.title = get_name(field_definition.annotation)
schema.description = field_definition.annotation.__doc__

return self.schema_registry.get_reference_for_key(key) or schema

def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schema | Reference:
if field.kwarg_definition and field.is_const and field.has_default and schema.const is None:
schema.const = field.default
Expand Down
45 changes: 0 additions & 45 deletions litestar/_openapi/schema_generation/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Mapping, _GenericAlias # type: ignore[attr-defined]

from litestar.utils.helpers import get_name
Expand All @@ -12,55 +11,11 @@
from litestar.typing import FieldDefinition

__all__ = (
"_type_or_first_not_none_inner_type",
"_should_create_enum_schema",
"_should_create_literal_schema",
"_get_normalized_schema_key",
)


def _type_or_first_not_none_inner_type(field_definition: FieldDefinition) -> Any:
"""Get the first inner type that is not None.

This is a narrow focussed utility to be used when we know that a field definition either represents
a single type, or a single type in a union with `None`, and we want the single type.

Args:
field_definition: A field definition instance.

Returns:
A field definition instance.
"""
if not field_definition.is_optional:
return field_definition.annotation
inner = next((t for t in field_definition.inner_types if not t.is_none_type), None)
if inner is None:
raise ValueError("Field definition has no inner type that is not None")
return inner.annotation


def _should_create_enum_schema(field_definition: FieldDefinition) -> bool:
"""Predicate to determine if we should create an enum schema for the field def, or not.

This returns true if the field definition is an enum, or if the field definition is a union
of an enum and ``None``.

When an annotation is ``SomeEnum | None`` we should create a schema for the enum that includes ``null``
in the enum values.

Args:
field_definition: A field definition instance.

Returns:
A boolean
"""
return field_definition.is_subclass_of(Enum) or (
field_definition.is_optional
and len(field_definition.args) == 2
and field_definition.has_inner_subclass_of(Enum)
)


def _should_create_literal_schema(field_definition: FieldDefinition) -> bool:
"""Predicate to determine if we should create a literal schema for the field def, or not.

Expand Down
5 changes: 5 additions & 0 deletions litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import abc, deque
from copy import deepcopy
from dataclasses import dataclass, is_dataclass, replace
from enum import Enum
from inspect import Parameter, Signature
from typing import (
Any,
Expand Down Expand Up @@ -432,6 +433,10 @@ def is_typeddict_type(self) -> bool:

return is_typeddict(self.origin or self.annotation)

@property
def is_enum(self) -> bool:
return self.is_subclass_of(Enum)

@property
def type_(self) -> Any:
"""The type of the annotation with all the wrappers removed, including the generic types."""
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_openapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_persons(
from_date: Optional[Union[int, datetime, date]] = None,
to_date: Optional[Union[int, datetime, date]] = None,
gender: Optional[Union[Gender, List[Gender]]] = Parameter(
examples=[Example(value="M"), Example(value=["M", "O"])]
examples=[Example(value=Gender.MALE), Example(value=[Gender.MALE, Gender.OTHER])]
),
# header parameter
secret_header: str = Parameter(header="secret"),
Expand Down
19 changes: 6 additions & 13 deletions tests/unit/test_openapi/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from litestar.exceptions import ImproperlyConfiguredException
from litestar.handlers import HTTPRouteHandler
from litestar.openapi import OpenAPIConfig
from litestar.openapi.spec import Example, OpenAPI, Schema
from litestar.openapi.spec import Example, OpenAPI, Reference, Schema
from litestar.openapi.spec.enums import OpenAPIType
from litestar.params import Dependency, Parameter
from litestar.routes import BaseRoute
from litestar.testing import create_test_client
from litestar.utils import find_index
from tests.unit.test_openapi.utils import Gender

if TYPE_CHECKING:
from litestar.openapi.spec.parameter import Parameter as OpenAPIParameter
Expand Down Expand Up @@ -104,22 +105,14 @@ def test_create_parameters(person_controller: Type[Controller]) -> None:
assert gender.schema == Schema(
one_of=[
Schema(type=OpenAPIType.NULL),
Schema(
type=OpenAPIType.STRING,
enum=["M", "F", "O", "A"],
examples=["M"],
),
Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_Gender"),
Schema(
type=OpenAPIType.ARRAY,
items=Schema(
type=OpenAPIType.STRING,
enum=["M", "F", "O", "A"],
examples=["F"],
),
examples=[["A"]],
items=Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_Gender"),
examples=[[Gender.MALE]],
),
],
examples=["M", ["M", "O"]],
examples=[Gender.MALE, [Gender.MALE, Gender.OTHER]],
)
assert not gender.required

Expand Down
50 changes: 22 additions & 28 deletions tests/unit/test_openapi/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from datetime import date, datetime, timezone
from enum import Enum, auto
from typing import ( # type: ignore[attr-defined]
from typing import (
TYPE_CHECKING,
Any,
Dict,
Expand All @@ -13,8 +13,7 @@
Tuple,
TypedDict,
TypeVar,
Union,
_GenericAlias, # pyright: ignore
Union, # pyright: ignore
)

import annotated_types
Expand All @@ -29,7 +28,7 @@
KWARG_DEFINITION_ATTRIBUTE_TO_OPENAPI_PROPERTY_MAP,
SchemaCreator,
)
from litestar._openapi.schema_generation.utils import _get_normalized_schema_key, _type_or_first_not_none_inner_type
from litestar._openapi.schema_generation.utils import _get_normalized_schema_key
from litestar.app import DEFAULT_OPENAPI_CONFIG, Litestar
from litestar.di import Provide
from litestar.enums import ParamType
Expand All @@ -40,7 +39,6 @@
from litestar.pagination import ClassicPagination, CursorPagination, OffsetPagination
from litestar.params import KwargDefinition, Parameter, ParameterKwarg
from litestar.testing import create_test_client
from litestar.types.builtin_types import NoneType
from litestar.typing import FieldDefinition
from litestar.utils.helpers import get_name
from tests.helpers import get_schema_for_field_definition
Expand Down Expand Up @@ -452,12 +450,26 @@ def test_schema_tuple_with_union() -> None:
def test_optional_enum() -> None:
class Foo(Enum):
A = 1
B = 2
B = "b"

schema = get_schema_for_field_definition(FieldDefinition.from_annotation(Optional[Foo]))
assert schema.type is not None
assert set(schema.type) == {OpenAPIType.INTEGER, OpenAPIType.NULL}
assert schema.enum == [1, 2, None]
creator = SchemaCreator(plugins=openapi_schema_plugins)
schema = creator.for_field_definition(FieldDefinition.from_annotation(Optional[Foo]))
assert isinstance(schema, Schema)
assert schema.type is None
assert schema.one_of is not None
null_schema = schema.one_of[0]
assert isinstance(null_schema, Schema)
assert null_schema.type is not None
assert null_schema.type is OpenAPIType.NULL
enum_ref = schema.one_of[1]
assert isinstance(enum_ref, Reference)
assert enum_ref.ref == "#/components/schemas/tests_unit_test_openapi_test_schema_test_optional_enum.Foo"
enum_schema = creator.schema_registry.from_reference(enum_ref).schema
assert enum_schema.type
assert set(enum_schema.type) == {OpenAPIType.INTEGER, OpenAPIType.STRING}
assert enum_schema.enum
assert enum_schema.enum[0] == 1
assert enum_schema.enum[1] == "b"


def test_optional_literal() -> None:
Expand All @@ -467,24 +479,6 @@ def test_optional_literal() -> None:
assert schema.enum == [1, None]


@pytest.mark.parametrize(
("in_type", "out_type"),
[
(FieldDefinition.from_annotation(Optional[int]), int),
(FieldDefinition.from_annotation(Union[None, int]), int),
(FieldDefinition.from_annotation(int), int),
# hack to create a union of NoneType, NoneType to hit a branch for coverage
(FieldDefinition.from_annotation(_GenericAlias(Union, (NoneType, NoneType))), ValueError),
],
)
def test_type_or_first_not_none_inner_type_utility(in_type: Any, out_type: Any) -> None:
if out_type is ValueError:
with pytest.raises(out_type):
_type_or_first_not_none_inner_type(in_type)
else:
assert _type_or_first_not_none_inner_type(in_type) == out_type


def test_not_generating_examples_property() -> None:
with_examples = SchemaCreator(generate_examples=True)
without_examples = with_examples.not_generating_examples
Expand Down
Loading