Skip to content

Commit

Permalink
Tin/better union hooks (#499)
Browse files Browse the repository at this point in the history
* Improve union structure hook handling

* Improve typeddict coverage

* Skip test on 3.9 and 3.10
  • Loading branch information
Tinche authored Feb 10, 2024
1 parent 066ace9 commit 4f4a6e9
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 46 deletions.
34 changes: 14 additions & 20 deletions src/cattrs/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
IterableValidationNote,
StructureHandlerNotFoundError,
)
from .fns import identity, raise_error
from .fns import Predicate, identity, raise_error
from .gen import (
AttributeOverride,
DictStructureFn,
Expand Down Expand Up @@ -174,6 +174,7 @@ def __init__(
self._prefer_attrib_converters = prefer_attrib_converters

self.detailed_validation = detailed_validation
self._union_struct_registry: dict[Any, Callable[[Any, type[T]], T]] = {}

# Create a per-instance cache.
if unstruct_strat is UnstructureStrategy.AS_DICT:
Expand Down Expand Up @@ -246,7 +247,8 @@ def __init__(
(is_supported_union, self._gen_attrs_union_structure, True),
(
lambda t: is_union_type(t) and t in self._union_struct_registry,
self._structure_union,
self._union_struct_registry.__getitem__,
True,
),
(is_optional, self._structure_optional),
(has, self._structure_attrs),
Expand All @@ -266,9 +268,6 @@ def __init__(

self._dict_factory = dict_factory

# Unions are instances now, not classes. We use different registries.
self._union_struct_registry: dict[Any, Callable[[Any, type[T]], T]] = {}

self._unstruct_copy_skip = self._unstructure_func.get_num_fns()
self._struct_copy_skip = self._structure_func.get_num_fns()

Expand Down Expand Up @@ -330,7 +329,7 @@ def register_unstructure_hook(
return None

def register_unstructure_hook_func(
self, check_func: Callable[[Any], bool], func: UnstructureHook
self, check_func: Predicate, func: UnstructureHook
) -> None:
"""Register a class-to-primitive converter function for a class, using
a function to check if it's a match.
Expand All @@ -339,25 +338,25 @@ def register_unstructure_hook_func(

@overload
def register_unstructure_hook_factory(
self, predicate: Callable[[Any], bool]
self, predicate: Predicate
) -> Callable[[UnstructureHookFactory], UnstructureHookFactory]:
...

@overload
def register_unstructure_hook_factory(
self, predicate: Callable[[Any], bool]
self, predicate: Predicate
) -> Callable[[ExtendedUnstructureHookFactory], ExtendedUnstructureHookFactory]:
...

@overload
def register_unstructure_hook_factory(
self, predicate: Callable[[Any], bool], factory: UnstructureHookFactory
self, predicate: Predicate, factory: UnstructureHookFactory
) -> UnstructureHookFactory:
...

@overload
def register_unstructure_hook_factory(
self, predicate: Callable[[Any], bool], factory: ExtendedUnstructureHookFactory
self, predicate: Predicate, factory: ExtendedUnstructureHookFactory
) -> ExtendedUnstructureHookFactory:
...

Expand Down Expand Up @@ -473,7 +472,7 @@ def register_structure_hook(
self._structure_func.register_cls_list([(cl, func)])

def register_structure_hook_func(
self, check_func: Callable[[type[T]], bool], func: StructureHook
self, check_func: Predicate, func: StructureHook
) -> None:
"""Register a class-to-primitive converter function for a class, using
a function to check if it's a match.
Expand All @@ -482,25 +481,25 @@ def register_structure_hook_func(

@overload
def register_structure_hook_factory(
self, predicate: Callable[[Any, bool]]
self, predicate: Predicate
) -> Callable[[StructureHookFactory, StructureHookFactory]]:
...

@overload
def register_structure_hook_factory(
self, predicate: Callable[[Any, bool]]
self, predicate: Predicate
) -> Callable[[ExtendedStructureHookFactory, ExtendedStructureHookFactory]]:
...

@overload
def register_structure_hook_factory(
self, predicate: Callable[[Any], bool], factory: StructureHookFactory
self, predicate: Predicate, factory: StructureHookFactory
) -> StructureHookFactory:
...

@overload
def register_structure_hook_factory(
self, predicate: Callable[[Any], bool], factory: ExtendedStructureHookFactory
self, predicate: Predicate, factory: ExtendedStructureHookFactory
) -> ExtendedStructureHookFactory:
...

Expand Down Expand Up @@ -903,11 +902,6 @@ def _structure_optional(self, obj, union):
# We can't actually have a Union of a Union, so this is safe.
return self._structure_func.dispatch(other)(obj, other)

def _structure_union(self, obj, union):
"""Deal with structuring a union."""
handler = self._union_struct_registry[union]
return handler(obj, union)

def _structure_tuple(self, obj: Any, tup: type[T]) -> T:
"""Deal with structuring into a tuple."""
tup_params = None if tup in (Tuple, tuple) else tup.__args__
Expand Down
17 changes: 6 additions & 11 deletions src/cattrs/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from attrs import Factory, define

from ._compat import TypeAlias
from .fns import Predicate

if TYPE_CHECKING:
from .converters import BaseConverter

T = TypeVar("T")

TargetType: TypeAlias = Any
UnstructuredValue: TypeAlias = Any
StructuredValue: TypeAlias = Any
Expand Down Expand Up @@ -46,12 +45,12 @@ class FunctionDispatch:

_converter: BaseConverter
_handler_pairs: list[
tuple[Callable[[Any], bool], Callable[[Any, Any], Any], bool, bool]
tuple[Predicate, Callable[[Any, Any], Any], bool, bool]
] = Factory(list)

def register(
self,
predicate: Callable[[Any], bool],
predicate: Predicate,
func: Callable[..., Any],
is_generator=False,
takes_converter=False,
Expand Down Expand Up @@ -148,13 +147,9 @@ def register_cls_list(self, cls_and_handler, direct: bool = False) -> None:
def register_func_list(
self,
pred_and_handler: list[
tuple[Callable[[Any], bool], Any]
| tuple[Callable[[Any], bool], Any, bool]
| tuple[
Callable[[Any], bool],
Callable[[Any, BaseConverter], Any],
Literal["extended"],
]
tuple[Predicate, Any]
| tuple[Predicate, Any, bool]
| tuple[Predicate, Callable[[Any, BaseConverter], Any], Literal["extended"]]
],
):
"""
Expand Down
6 changes: 5 additions & 1 deletion src/cattrs/fns.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Useful internal functions."""
from typing import NoReturn, Type, TypeVar
from typing import Any, Callable, NoReturn, Type, TypeVar

from ._compat import TypeAlias
from .errors import StructureHandlerNotFoundError

T = TypeVar("T")

Predicate: TypeAlias = Callable[[Any], bool]
"""A predicate function determines if a type can be handled."""


def identity(obj: T) -> T:
"""The identity function."""
Expand Down
14 changes: 3 additions & 11 deletions src/cattrs/gen/typeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,11 +565,9 @@ def _required_keys(cls: type) -> set[str]:
# gathering required keys. *sigh*
own_annotations = cls.__dict__.get("__annotations__", {})
required_keys = set()
for base in cls.__mro__[1:]:
if base in (object, dict):
# These have no required keys for sure.
continue
required_keys |= _required_keys(base)
# On 3.8 - 3.10, typing.TypedDict doesn't put typeddict superclasses
# in the MRO, therefore we cannot handle non-required keys properly
# in some situations. Oh well.
for key in getattr(cls, "__required_keys__", []):
annotation_type = own_annotations[key]
annotation_origin = get_origin(annotation_type)
Expand Down Expand Up @@ -597,13 +595,7 @@ def _required_keys(cls: type) -> set[str]:

own_annotations = cls.__dict__.get("__annotations__", {})
required_keys = set()
superclass_keys = set()
for base in cls.__mro__[1:]:
required_keys |= _required_keys(base)
superclass_keys |= base.__dict__.get("__annotations__", {}).keys()
for key in own_annotations:
if key in superclass_keys:
continue
annotation_type = own_annotations[key]

if is_annotated(annotation_type):
Expand Down
2 changes: 2 additions & 0 deletions tests/_compat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys

is_py38 = sys.version_info[:2] == (3, 8)
is_py39 = sys.version_info[:2] == (3, 9)
is_py39_plus = sys.version_info >= (3, 9)
is_py310 = sys.version_info[:2] == (3, 10)
is_py310_plus = sys.version_info >= (3, 10)
is_py311_plus = sys.version_info >= (3, 11)
is_py312_plus = sys.version_info >= (3, 12)
Expand Down
26 changes: 24 additions & 2 deletions tests/test_typeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from hypothesis import assume, given
from hypothesis.strategies import booleans
from pytest import raises
from typing_extensions import NotRequired
from typing_extensions import NotRequired, Required

from cattrs import BaseConverter, Converter
from cattrs._compat import ExtensionsTypedDict, get_notrequired_base, is_generic
Expand All @@ -24,7 +24,7 @@
make_dict_unstructure_fn,
)

from ._compat import is_py38, is_py311_plus
from ._compat import is_py38, is_py39, is_py310, is_py311_plus
from .typeddicts import (
generic_typeddicts,
simple_typeddicts,
Expand Down Expand Up @@ -263,6 +263,28 @@ def test_required(
assert restructured == instance


@pytest.mark.skipif(is_py39 or is_py310, reason="Sigh")
def test_required_keys() -> None:
"""We don't support the full gamut of functionality on 3.8.
When using `typing.TypedDict` we have only partial functionality;
this test tests only a subset of this.
"""
c = mk_converter()

class Base(TypedDict, total=False):
a: Required[datetime]

class Sub(Base):
b: int

fn = make_dict_unstructure_fn(Sub, c)

with raises(KeyError):
# This needs to raise since 'a' is missing, and it's Required.
fn({"b": 1})


@given(simple_typeddicts(min_attrs=1, total=True), booleans())
def test_omit(cls_and_instance: Tuple[type, Dict], detailed_validation: bool) -> None:
"""`override(omit=True)` works."""
Expand Down
2 changes: 1 addition & 1 deletion tests/typeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def simple_typeddicts(
note(
"\n".join(
[
"class HypTypedDict(TypedDict):",
f"class HypTypedDict(TypedDict{'' if total else ', total=False'}):",
*[f" {n}: {a}" for n, a in attrs_dict.items()],
]
)
Expand Down

0 comments on commit 4f4a6e9

Please sign in to comment.