Skip to content

Commit

Permalink
Merge pull request #201 from scrapinghub/annotated-support
Browse files Browse the repository at this point in the history
Support annotated deps in serialization and testing.
  • Loading branch information
wRAR authored Mar 4, 2024
2 parents 7b4e6a5 + 52371ec commit 46d3a4e
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 7 deletions.
5 changes: 5 additions & 0 deletions docs/api-reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ Fields
.. automodule:: web_poet.fields
:members:

typing.Annotated support
========================

.. automodule:: web_poet.annotated
:members:

Utils
=====
Expand Down
46 changes: 46 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import Type

import attrs
Expand All @@ -14,6 +15,7 @@
Stats,
WebPage,
)
from web_poet.annotated import AnnotatedInstance
from web_poet.page_inputs.url import _Url
from web_poet.serialization import (
SerializedDataFileStorage,
Expand Down Expand Up @@ -214,3 +216,47 @@ def test_httpclient_empty(tmp_path) -> None:
assert (directory / "HttpClient.exists").exists()
read_serialized_deps = storage.read()
assert "HttpClient" in read_serialized_deps


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="No Annotated support in Python < 3.9"
)
def test_annotated(book_list_html_response) -> None:
    """Check that an annotated dependency survives a serialization round trip.

    The response is serialized wrapped in an ``AnnotatedInstance`` and the
    page object rebuilt by ``deserialize()`` must equal one built directly
    from the same inputs.
    """
    from typing import Annotated

    @attrs.define
    class MyWebPage(ItemPage):
        response: Annotated[HttpResponse, "foo", 42]
        url: ResponseUrl

    url = ResponseUrl("http://books.toscrape.com/index.html")
    wrapped_response = AnnotatedInstance(book_list_html_response, ("foo", 42))

    # Page object built directly from the plain (unwrapped) inputs.
    expected_po = MyWebPage(book_list_html_response, url)

    # Page object rebuilt from the serialized form of the wrapped inputs.
    restored_po = deserialize(MyWebPage, serialize([wrapped_response, url]))

    _assert_pages_equal(expected_po, restored_po)


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="No Annotated support in Python < 3.9"
)
def test_annotated_duplicate(book_list_html_response) -> None:
    """Check that two wrappers around the same dependency type are rejected.

    Both ``AnnotatedInstance`` objects wrap an ``HttpResponse``, so
    ``serialize()`` is expected to raise ``ValueError``.
    """
    url = ResponseUrl("http://books.toscrape.com/index.html")
    conflicting_deps = [
        AnnotatedInstance(book_list_html_response, ("foo", 42)),
        AnnotatedInstance(book_list_html_response, ("bar",)),
        url,
    ]
    expected_message = "Several instances of AnnotatedInstance for HttpResponse were"

    with pytest.raises(ValueError, match=expected_message):
        serialize(conflicting_deps)
26 changes: 26 additions & 0 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from zyte_common_items import Item, Metadata, Product

from web_poet import HttpClient, HttpRequest, HttpResponse, WebPage, field
from web_poet.annotated import AnnotatedInstance
from web_poet.exceptions import HttpRequestError, HttpResponseError, Retry, UseFallback
from web_poet.page_inputs.client import _SavedResponseData
from web_poet.testing import Fixture
Expand Down Expand Up @@ -537,3 +538,28 @@ def test_page_object_exception_none(pytester, book_list_html_response) -> None:
assert fixture.exception_path.exists()
result = pytester.runpytest()
result.assert_outcomes(failed=1)


# The tests below state there is no Annotated support before Python 3.9, so
# the page class that needs it is only defined on newer interpreters.
if sys.version_info >= (3, 9):
    from typing import Annotated

    @attrs.define(kw_only=True)
    class MyAnnotatedItemPage(MyItemPage):
        # The response dependency carries extra annotation metadata.
        response: Annotated[HttpResponse, "foo", 42]

        async def to_item(self) -> dict:
            """Return the fixed item the fixture test expects."""
            item = {"foo": "bar"}
            return item


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="No Annotated support in Python < 3.9"
)
def test_annotated(pytester, book_list_html_response) -> None:
    """Save a fixture with an annotated page input and run it under pytest.

    The run is expected to report exactly three passed tests.
    """
    annotated_response = AnnotatedInstance(book_list_html_response, ("foo", 42))
    _save_fixture(
        pytester,
        page_cls=MyAnnotatedItemPage,
        page_inputs=[annotated_response],
        expected_output={"foo": "bar"},
    )
    outcome = pytester.runpytest()
    outcome.assert_outcomes(passed=3)
26 changes: 26 additions & 0 deletions web_poet/annotated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class AnnotatedInstance:
    """Wrapper for instances of annotated dependencies.

    It is used when both the dependency value and the dependency annotation are
    needed.

    :param result: The wrapped dependency instance.
    :type result: Any
    :param metadata: The copy of the annotation.
    :type metadata: Tuple[Any, ...]
    """

    result: Any
    metadata: Tuple[Any, ...]

    def get_annotated_cls(self):
        """Rebuild the :class:`typing.Annotated` type this instance stands for.

        The origin is the runtime type of ``result`` and the annotation
        arguments are the items of ``metadata``, in order.
        """
        # Imported locally rather than at module level; presumably so this
        # module stays importable on Python versions without Annotated.
        from typing import Annotated

        annotated_args = (type(self.result), *self.metadata)
        return Annotated[annotated_args]
32 changes: 26 additions & 6 deletions web_poet/serialization/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from typing import Any, Callable, Dict, Iterable, Tuple, Type, TypeVar, Union

import andi
from andi.typeutils import strip_annotated

import web_poet
from web_poet import Injectable
from web_poet.annotated import AnnotatedInstance
from web_poet.pages import is_injectable
from web_poet.utils import get_fq_class_name

Expand Down Expand Up @@ -135,7 +137,19 @@ def serialize(deps: Iterable[Any]) -> SerializedData:
cls = dep.__class__
if is_injectable(cls):
raise ValueError(f"Injectable type {cls} passed to serialize()")
result[_get_name_for_class(cls)] = serialize_leaf(dep)
if cls is AnnotatedInstance:
key = f"AnnotatedInstance {_get_name_for_class(dep.result.__class__)}"
else:
key = _get_name_for_class(cls)

if key in result:
cls_name = cls.__name__
if cls is AnnotatedInstance:
cls_name = f"AnnotatedInstance for {dep.result.__class__.__name__}"
raise ValueError(
f"Several instances of {cls_name} were passed to serialize()."
)
result[key] = serialize_leaf(dep)
return result


Expand Down Expand Up @@ -179,15 +193,21 @@ def deserialize(cls: Type[InjectableT], data: SerializedData) -> InjectableT:
deps: Dict[Callable, Any] = {}

for dep_type_name, dep_data in data.items():
dep_type = load_class(dep_type_name)
deps[dep_type] = deserialize_leaf(dep_type, dep_data)
if dep_type_name.startswith("AnnotatedInstance "):
annotated_result = deserialize_leaf(AnnotatedInstance, dep_data)
dep_type = annotated_result.get_annotated_cls()
deserialized_dep = annotated_result.result
else:
dep_type = load_class(dep_type_name)
deserialized_dep = deserialize_leaf(dep_type, dep_data)
deps[dep_type] = deserialized_dep

externally_provided = deps.keys()
externally_provided = {strip_annotated(cls) for cls in deps.keys()}
plan = andi.plan(
cls, is_injectable=is_injectable, externally_provided=externally_provided
)
for fn_or_cls, kwargs_spec in plan[:-1]:
if fn_or_cls in externally_provided:
if strip_annotated(fn_or_cls) in externally_provided:
continue
deps[fn_or_cls] = fn_or_cls(**kwargs_spec.kwargs(deps))
deps[strip_annotated(fn_or_cls)] = fn_or_cls(**kwargs_spec.kwargs(deps))
return cls(**plan.final_kwargs(deps))
34 changes: 33 additions & 1 deletion web_poet/serialization/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Dict, List, Optional, Type, cast
from typing import Any, Dict, List, Optional, Type, cast

from .. import (
HttpClient,
Expand All @@ -10,12 +10,15 @@
PageParams,
Stats,
)
from ..annotated import AnnotatedInstance
from ..exceptions import HttpError
from ..page_inputs.client import _SavedResponseData
from ..page_inputs.url import _Url
from .api import (
SerializedLeafData,
_get_name_for_class,
deserialize_leaf,
load_class,
register_serialization,
serialize_leaf,
)
Expand Down Expand Up @@ -193,3 +196,32 @@ def _deserialize_Stats(cls: Type[Stats], data: SerializedLeafData) -> Stats:


register_serialization(_serialize_Stats, _deserialize_Stats)


def _serialize_AnnotatedInstance(o: AnnotatedInstance) -> SerializedLeafData:
    """Serialize an :class:`AnnotatedInstance` into leaf data.

    The output holds the JSON-encoded metadata (``metadata.json``), the name
    of the wrapped value's class (``result_type.txt``), and every entry of the
    wrapped value's own leaf serialization under a ``result-`` key prefix.
    """
    wrapped = o.result
    leaf: SerializedLeafData = {
        "metadata.json": _format_json(o.metadata).encode(),
        "result_type.txt": _get_name_for_class(type(wrapped)).encode(),
    }
    # Namespace the wrapped value's entries so they cannot clash with the
    # two bookkeeping entries above.
    leaf.update(
        {"result-" + name: blob for name, blob in serialize_leaf(wrapped).items()}
    )
    return leaf


def _deserialize_AnnotatedInstance(
    cls: Type[AnnotatedInstance], data: SerializedLeafData
) -> AnnotatedInstance:
    """Rebuild an :class:`AnnotatedInstance` from its serialized leaf data.

    :param cls: The class to instantiate.
    :param data: Leaf data containing ``metadata.json`` (JSON-encoded
        metadata), ``result_type.txt`` (the class name of the wrapped value)
        and the wrapped value's own leaf entries under a ``result-`` prefix.
    :return: An instance wrapping the deserialized value, with the metadata
        restored as a tuple.
    :raises KeyError: If ``metadata.json`` or ``result_type.txt`` is missing.
    """
    # JSON has no tuple type, so the metadata array decodes to a list. Convert
    # it back so it matches the declared ``Tuple[Any, ...]`` field and a
    # round-tripped instance compares equal to one built with a tuple.
    metadata = tuple(json.loads(data["metadata.json"]))
    result_type = load_class(data["result_type.txt"].decode())

    # Strip the "result-" prefix to recover the wrapped value's own leaf data;
    # entries without the prefix belong to the wrapper and are skipped.
    prefix = "result-"
    serialized_result = {
        k[len(prefix) :]: v for k, v in data.items() if k.startswith(prefix)
    }
    result: Any = deserialize_leaf(result_type, serialized_result)

    return cls(result=result, metadata=metadata)


# Register the serializer/deserializer pair for AnnotatedInstance.
register_serialization(_serialize_AnnotatedInstance, _deserialize_AnnotatedInstance)

0 comments on commit 46d3a4e

Please sign in to comment.