diff --git a/pandagg/node/mappings/abstract.py b/pandagg/node/mappings/abstract.py
index 6b7718a1..cb7bc800 100644
--- a/pandagg/node/mappings/abstract.py
+++ b/pandagg/node/mappings/abstract.py
@@ -10,16 +10,26 @@ class Field(Node):
     _type_name = "field"
     KEY = None
 
-    def __init__(self, **body):
+    def __init__(self, multiple=None, nullable=True, **body):
+        """
+        :param multiple: boolean, default None, if True field must be an array, if False field must be a single item
+        :param nullable: boolean, default True, if False a `None` value will be considered as invalid.
+        :param body: field body
+        """
         super(Node, self).__init__()
         self._subfield = body.pop("_subfield", False)
         self._body = body.copy()
+        self._multiple = multiple
+        self._nullable = nullable
 
     def line_repr(self, depth, **kwargs):
         if self.KEY is None:
             return "_", ""
         return "", self._display_pattern % self.KEY.capitalize()
 
+    def is_valid_value(self, v):
+        raise NotImplementedError()
+
     @property
     def body(self):
         b = self._body.copy()
@@ -55,6 +65,9 @@ def __init__(self, **body):
         self.properties = properties or {}
         super(ComplexField, self).__init__(**body)
 
+    def is_valid_value(self, v):
+        return isinstance(v, dict)
+
 
 class RegularField(Field):
     KEY = None
@@ -65,3 +78,7 @@ def __init__(self, **body):
             raise ValueError("Invalid fields %s" % fields)
         self.fields = fields
         super(RegularField, self).__init__(**body)
+
+    def is_valid_value(self, v):
+        # TODO - implement per field type
+        return True
diff --git a/pandagg/response.py b/pandagg/response.py
index 2f97a4ae..51415430 100644
--- a/pandagg/response.py
+++ b/pandagg/response.py
@@ -158,7 +158,8 @@ def _parse_group_by(
         row_as_tuple=False,
         with_single_bucket_groups=False,
     ):
-        """Recursive parsing of succession of grouping aggregation clauses.
+        """
+        Recursive parsing of succession of grouping aggregation clauses.
 
         Yields each row for which last bucket aggregation generated buckets.
         """
@@ -203,7 +204,8 @@ def _parse_group_by(
                     yield nrow, nraw_bucket
 
     def _normalize_buckets(self, agg_response, agg_name=None):
-        """Recursive function to parse aggregation response as a normalized entities.
+        """
+        Recursive function to parse aggregation response as normalized entities.
         Each response bucket is represented as a dict with keys (key, level, value, children)::
 
             {
@@ -236,7 +238,8 @@ def _normalize_buckets(self, agg_response, agg_name=None):
             yield result
 
     def _grouping_agg(self, name=None):
-        """Return aggregation node that used as grouping node.
+        """
+        Return aggregation node that is used as grouping node.
         Note: in case there is only a nested aggregation below that node, group-by nested clause.
         """
         if name is not None:
@@ -268,7 +271,8 @@ def to_tabular(
         normalize=True,
         with_single_bucket_groups=False,
     ):
-        """Build tabular view of ES response grouping levels (rows) until 'grouped_by' aggregation node included is
+        """
+        Build tabular view of ES response grouping levels (rows) until 'grouped_by' aggregation node included is
         reached, and using children aggregations of grouping level as values for each of generated groups (columns).
 
         Suppose an aggregation of this shape (A & B bucket aggregations)::
diff --git a/pandagg/tree/mappings.py b/pandagg/tree/mappings.py
index d2f47001..3921e08d 100644
--- a/pandagg/tree/mappings.py
+++ b/pandagg/tree/mappings.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-
-
+from pandagg.node import Object, Nested
 from pandagg.node.mappings.abstract import Field, RegularField, ComplexField
 
 
@@ -207,3 +206,69 @@ def _insert(self, pid, properties, is_subfield):
                         % field_name
                     )
                 self._insert(field.identifier, field.fields, True)
+
+    def validate_document(self, d):
+        """
+        Validate document `d` against these mappings, starting at the root.
+
+        :param d: document as a dict, `None` is handled as an empty document
+        :raises ValueError: on the first field breaking its `nullable`,
+            `multiple` or type constraint
+        """
+        self._validate_document(d, pid=self.root)
+
+    def _validate_document(self, d, pid, path=""):
+        """
+        Recursively validate `d` against the fields declared under node `pid`.
+
+        :param d: document, or sub-document held by an object/nested field
+        :param pid: identifier of the node whose children are checked
+        :param path: dotted path leading to `pid`, used in error messages
+        :raises ValueError: if `d` is not a dict, or on the first invalid value
+        """
+        if d is None:
+            d = {}
+        if not isinstance(d, dict):
+            raise ValueError(
+                "Invalid document type, expected dict, got <%s> at '%s'"
+                % (type(d), path)
+            )
+        for field_name, field in self.children(pid):
+            full_path = ".".join([path, field_name]) if path else field_name
+            field_value = d.get(field_name)
+            # only `None` and `[]` are null: 0, False and "" are real values
+            if not field._nullable and (field_value is None or field_value == []):
+                raise ValueError("Field <%s> cannot be null" % full_path)
+
+            if field._multiple is True:
+                if field_value is not None:
+                    if not isinstance(field_value, list):
+                        raise ValueError("Field <%s> should be a array" % full_path)
+                    field_value_list = field_value
+                else:
+                    field_value_list = []
+                if not field._nullable and all(v is None for v in field_value_list):
+                    # deal with case: [None]
+                    raise ValueError("Field <%s> cannot be null" % full_path)
+            elif field._multiple is False:
+                if isinstance(field_value, list):
+                    raise ValueError("Field <%s> should not be an array" % full_path)
+                field_value_list = [] if field_value is None else [field_value]
+            else:
+                # field._multiple is None -> no restriction
+                if isinstance(field_value, list):
+                    field_value_list = field_value
+                else:
+                    field_value_list = [field_value]
+
+            for value in field_value_list:
+                # nullable check has been done beforehands
+                if value is not None:
+                    if not field.is_valid_value(value):
+                        raise ValueError(
+                            "Field <%s> value <%s> is not compatible with field of type %s"
+                            % (full_path, value, field.KEY)
+                        )
+                # recurse even on `None` so that non-nullable children are reported
+                if isinstance(field, (Object, Nested)):
+                    self._validate_document(value, field.identifier, path=full_path)
diff --git a/tests/tree/mapping/test_mappings.py b/tests/tree/mapping/test_mappings.py
index c7a575a0..fcdb2a4f 100644
--- a/tests/tree/mapping/test_mappings.py
+++ b/tests/tree/mapping/test_mappings.py
@@ -131,3 +131,87 @@ def test_node_path(self):
         self.assertEqual(
             mapping_tree.get_path(node.identifier), "local_metrics.dataset.support_test"
         )
+
+    def test_validate_doc(self):
+        tts = [
+            {
+                "name": "non nullable",
+                "properties": {"pizza": Keyword(nullable=False)},
+                "documents_expected_results": [
+                    ({"pizza": "yolo"}, None),
+                    ({"pizza": None}, "Field <pizza> cannot be null"),
+                    ({}, "Field <pizza> cannot be null"),
+                    ({"pizza": ["yo", "lo"]}, None),
+                ],
+            },
+            {
+                "name": "nullable",
+                "properties": {"pizza": Keyword(nullable=True)},
+                "documents_expected_results": [
+                    ({"pizza": "yolo"}, None),
+                    ({"pizza": None}, None),
+                    ({}, None),
+                    ({"pizza": ["yo", "lo"]}, None),
+                ],
+            },
+            {
+                "name": "multiple nullable",
+                "properties": {"pizza": Keyword(multiple=True)},
+                "documents_expected_results": [
+                    ({"pizza": "yolo"}, "Field <pizza> should be a array"),
+                    ({"pizza": None}, None),
+                    ({}, None),
+                    ({"pizza": ["yo", "lo"]}, None),
+                ],
+            },
+            {
+                "name": "multiple non nullable",
+                "properties": {"pizza": Keyword(multiple=True, nullable=False)},
+                "documents_expected_results": [
+                    ({"pizza": "yolo"}, "Field <pizza> should be a array"),
+                    ({"pizza": None}, "Field <pizza> cannot be null"),
+                    ({}, "Field <pizza> cannot be null"),
+                    ({"pizza": ["yo", "lo"]}, None),
+                ],
+            },
+            {
+                "name": "non multiple",
+                "properties": {"pizza": Keyword(multiple=False)},
+                "documents_expected_results": [
+                    ({"pizza": "yolo"}, None),
+                    ({"pizza": None}, None),
+                    ({}, None),
+                    ({"pizza": ["yo", "lo"]}, "Field <pizza> should not be an array"),
+                ],
+            },
+            {
+                "name": "nested multiple non nullable",
+                "properties": {
+                    "some_good": Object(
+                        properties={"pizza": Keyword(multiple=True, nullable=False)}
+                    )
+                },
+                "documents_expected_results": [
+                    (
+                        {"some_good": {"pizza": "yolo"}},
+                        "Field <some_good.pizza> should be a array",
+                    ),
+                    (
+                        {"some_good": {"pizza": None}},
+                        "Field <some_good.pizza> cannot be null",
+                    ),
+                    ({}, "Field <some_good.pizza> cannot be null"),
+                    ({"some_good": {"pizza": ["yo", "lo"]}}, None),
+                ],
+            },
+        ]
+        for tt in tts:
+            mappings = Mappings(properties=tt["properties"])
+            for doc, expected_error in tt["documents_expected_results"]:
+                if expected_error:
+                    with self.assertRaises(ValueError, msg=tt["name"]) as e:
+                        mappings.validate_document(doc)
+                    self.assertEqual(e.exception.args, (expected_error,), tt["name"])
+                else:
+                    # must not raise error
+                    mappings.validate_document(doc)