diff --git a/marshmallow_jsonapi/query_fields.py b/marshmallow_jsonapi/query_fields.py new file mode 100644 index 0000000..63ff1d0 --- /dev/null +++ b/marshmallow_jsonapi/query_fields.py @@ -0,0 +1,150 @@ +""" +Includes fields designed solely for parsing query/URL parameters from JSON API requests +""" +import typing +from enum import Enum + +import marshmallow as ma +import querystring_parser.parser as qsp +from webargs import core, fields +from webargs.fields import DelimitedList, String, Dict + + +class NestedQueryParserMixin: + """ + Mixin for creating a JSON API-compatible parser from a regular Webargs parser + + Examples: :: + from marshmallow_jsonapi.query_fields import NestedQueryParserMixin, JsonApiRequestSchema + from webargs.flaskparser import FlaskParser + + class FlaskJsonApiParser(FlaskParser, NestedQueryParserMixin): + pass + + parser = FlaskJsonApiParser() + + @parser.use_args(JsonApiRequestSchema()) + def greet(args): + return 'You requested to include these relationships: ' + ', '.join(args['include']) + """ + + def parse_querystring(self, req, name, field): + return core.get_value(qsp.parse(req.query_string), name, field) + + +class SortDirection(Enum): + """ + The direction to sort a field by + """ + + ASCENDING = 1 + DESCENDING = 2 + + +class SortItem(typing.NamedTuple): + """ + Represents a single entry in the list of fields to sort by + """ + + field: str + direction: SortDirection + + +class SortField(fields.Field): + """ + Marshmallow field that parses and dumps a JSON API sort parameter + """ + + def _serialize(self, value, attr, obj, **kwargs): + if value.direction == SortDirection.DESCENDING: + return "-" + value.field + else: + return value.field + + def _deserialize(self, value, attr, data, **kwargs): + if value.startswith("-"): + return SortItem(value[1:], SortDirection.DESCENDING) + else: + return SortItem(value, SortDirection.ASCENDING) + + +class PagePaginationSchema(ma.Schema): + number = fields.Integer() + size = fields.Integer() + + +class OffsetPaginationSchema(ma.Schema): + offset = fields.Integer() + limit = fields.Integer() + + +class Include(DelimitedList): + """ + The value of the include parameter MUST be a comma-separated (U+002C COMMA, “,”) list of relationship paths. + A relationship path is a dot-separated (U+002E FULL-STOP, “.”) list of relationship names. + + .. seealso:: + `JSON API Specification, Inclusion of Related Resources `_ + JSON API specification for the include request parameter + """ + + def __init__(self): + super().__init__(String(), data_key="include", delimiter=",", as_string=True) + + +class Fields(Dict): + """ + The value of the fields parameter MUST be a comma-separated (U+002C COMMA, “,”) list that refers to the name(s) of + the fields to be returned. + + .. seealso:: + `JSON API Specification, Sparse Fieldsets `_ + JSON API specification for the fields request parameter + """ + + def __init__(self): + super().__init__( + keys=String(), + values=DelimitedList(String(), delimiter=",", as_string=True), + data_key="fields", + ) + + +class Sort(DelimitedList): + """ + An endpoint MAY support requests to sort the primary data with a sort query parameter. + The value for sort MUST represent sort fields. + An endpoint MAY support multiple sort fields by allowing comma-separated (U+002C COMMA, “,”) sort fields. + Sort fields SHOULD be applied in the order specified. + + .. seealso:: + `JSON API Specification, Sorting `_ + JSON API specification for the sort request parameter + """ + + def __init__(self): + super().__init__(SortField(), data_key="sort", delimiter=",", as_string=True) + + +class Filter(Dict): + def __init__(self): + super().__init__( + keys=String(), + values=DelimitedList(String(), delimiter=",", as_string=True), + data_key="filter", + ) + + +class PagePagination(fields.Nested): + def __init__(self): + super().__init__(PagePaginationSchema(), data_key="page") + + +class OffsetPagination(fields.Nested): + def __init__(self): + super().__init__(OffsetPaginationSchema(), data_key="page") + + +class CursorPagination(fields.Nested): + def __init__(self, cursor_field): + super().__init__(core.dict2schema({"cursor": cursor_field}), data_key="page") diff --git a/setup.py b/setup.py index 75d095e..c67c229 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,11 @@ import re from setuptools import setup, find_packages -INSTALL_REQUIRES = ("marshmallow>=2.15.2",) +INSTALL_REQUIRES = ( + "marshmallow>=2.15.2", + "webargs>=5.5.1", + "querystring-parser>=1.2.4", +) EXTRAS_REQUIRE = { "tests": ["pytest", "mock", "faker==2.0.2", "Flask==1.1.1"], "lint": ["flake8==3.7.8", "flake8-bugbear==19.8.0", "pre-commit~=1.18"], diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 0000000..4184654 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,192 @@ +from typing import NamedTuple + +import pytest +from marshmallow import fields, Schema +from webargs.core import Parser, MARSHMALLOW_VERSION_INFO + +from marshmallow_jsonapi import query_fields as qf + + +class CompleteSchema(Schema): + sort = qf.Sort() + include = qf.Include() + fields = qf.Fields() + page = qf.PagePagination() + filter = qf.Filter() + + +class MockRequest(NamedTuple): + """ + A fake request object that has only a query string + """ + + query_string: str + + +class TestQueryParser: + def test_nested_field(self): + """ + Check that the query string parser can do what JSON API demands of it: parsing `param[key]` into a dictionary + """ + parser = qf.NestedQueryParserMixin() + request = MockRequest( + "include=author&fields[articles]=title,body,author&fields[people]=name" + ) + + assert parser.parse_querystring(request, "include", None) == "author" + assert parser.parse_querystring(request, "fields", None) == { + "articles": "title,body,author", + "people": "name", + } + + +@pytest.mark.parametrize( + ("field", "serialized", "deserialized"), + ( + ( + qf.SortField(), + "title", + qf.SortItem(field="title", direction=qf.SortDirection.ASCENDING), + ), + ( + qf.SortField(), + "-title", + qf.SortItem(field="title", direction=qf.SortDirection.DESCENDING), + ), + (qf.Include(), "author,comments.author", ["author", "comments.author"]), + ( + qf.Fields(), + {"articles": "title,body", "people": "name"}, + {"articles": ["title", "body"], "people": ["name"]}, + ), + ( + qf.Sort(), + "-created,title", + [ + qf.SortItem(field="created", direction=qf.SortDirection.DESCENDING), + qf.SortItem(field="title", direction=qf.SortDirection.ASCENDING), + ], + ), + (qf.PagePagination(), {"number": 3, "size": 1}, {"number": 3, "size": 1}), + (qf.OffsetPagination(), {"offset": 3, "limit": 1}, {"offset": 3, "limit": 1}), + ( + qf.CursorPagination(fields.Integer()), + {"cursor": -1}, + {"cursor": -1}, + ), # A Twitter-api style cursor + ( + qf.Filter(), + {"post": "1,2", "author": "12"}, + {"post": ["1", "2"], "author": ["12"]}, + ), + ), +) +def test_serialize_deserialize_field(field, serialized, deserialized): + """ + Tests all new fields, ensuring they serialize and deserialize as expected + :param field: + :param serialized: + :param deserialized: + :return: + """ + if isinstance(field, fields.Dict) and MARSHMALLOW_VERSION_INFO[0] < 3: + pytest.skip("Marshmallow<3 doesn't support dictionary deserialization") + + assert field.serialize("some_field", dict(some_field=deserialized)) == serialized + assert field.deserialize(serialized) == deserialized + + +class TestPagePaginationSchema: + def test_validate(self): + schema = qf.PagePaginationSchema() + assert schema.validate({"number": 3, "size": 1}) == {} + + +class TestOffsetPagePaginationSchema: + def test_validate(self): + schema = qf.OffsetPaginationSchema() + assert schema.validate({"offset": 3, "limit": 1}) == {} + + +class TestCompleteSchema: + def test_validate(self): + schema = CompleteSchema() + + assert ( + schema.validate( + { + "sort": "-created,title", + "include": "author,comments.author", + "fields": {"articles": "title,body", "people": "name"}, + "page": {"number": 3, "size": 1}, + "filter": {"post": "1,2", "author": "12"}, + } + ) + == {} + ) + + +@pytest.mark.skipif( + MARSHMALLOW_VERSION_INFO[0] < 3, + reason="Marshmallow<3 doesn't support dictionary deserialization", +) +@pytest.mark.parametrize( + ("query", "expected"), + ( + ("include=author", {"include": ["author"]}), + ( + "include=author&fields[articles]=title,body,author&fields[people]=name", + { + "fields": {"articles": ["title", "body", "author"], "people": ["name"]}, + "include": ["author"], + }, + ), + ( + "include=author&fields[articles]=title,body&fields[people]=name", + { + "fields": {"articles": ["title", "body"], "people": ["name"]}, + "include": ["author"], + }, + ), + ("page[number]=3&page[size]=1", {"page": {"size": 1, "number": 3}}), + ("include=comments.author", {"include": ["comments.author"]}), + ( + "sort=age", + {"sort": [qf.SortItem(field="age", direction=qf.SortDirection.ASCENDING)]}, + ), + ( + "sort=age,name", + { + "sort": [ + qf.SortItem(field="age", direction=qf.SortDirection.ASCENDING), + qf.SortItem(field="name", direction=qf.SortDirection.ASCENDING), + ] + }, + ), + ( + "sort=-created,title", + { + "sort": [ + qf.SortItem(field="created", direction=qf.SortDirection.DESCENDING), + qf.SortItem(field="title", direction=qf.SortDirection.ASCENDING), + ] + }, + ), + ), +) +def test_jsonapi_examples(query, expected): + """ + Tests example query strings from the JSON API specification + """ + request = MockRequest(query) + + class TestParser(qf.NestedQueryParserMixin, Parser): + pass + + parser = TestParser() + + @parser.use_args(CompleteSchema(), locations=("query",), req=request) + def handle(args): + return args + + assert handle() == expected