diff --git a/examples/subclass.py b/examples/subclass.py index 730d218..f557a34 100644 --- a/examples/subclass.py +++ b/examples/subclass.py @@ -111,9 +111,7 @@ def get( return NoneNode() if single else [NoneNode()] if single and item.__class__.__name__ != cls.__name__: - raise Exception( - f"Found {item.__class__.__name__}, and not {cls.__name__}." - ) + raise Exception(f"Found {item.__class__.__name__}, and not {cls.__name__}.") return item @@ -137,9 +135,7 @@ def make_filter(key, value): else: return fltr.format(value) - filters = [ - make_filter(k, v) for k, v in kwargs.items() if k in cls._filters - ] + filters = [make_filter(k, v) for k, v in kwargs.items() if k in cls._filters] if hasattr(cls, "_required_filters"): filters += list(cls._required_filters) filters = " and ".join((filter(lambda x: x, filters))) @@ -194,10 +190,7 @@ class Person(BaseAbstract): "other": f'eq(gender, "{Gender.OTHER.value}")', }, "family": ("family", "uid(family)"), - "living": { - "true": "(not has(death_year))", - "false": "has(death_year)", - }, + "living": {"true": "(not has(death_year))", "false": "has(death_year)"}, } _subqueries = { "family": """ @@ -284,16 +277,8 @@ class Person(BaseAbstract): @classmethod def _get_parents(cls, person, step=1): - father = ( - Person.get(person.father.uid) - if hasattr(person, "father") - else None - ) - mother = ( - Person.get(person.mother.uid) - if hasattr(person, "mother") - else None - ) + father = Person.get(person.father.uid) if hasattr(person, "father") else None + mother = Person.get(person.mother.uid) if hasattr(person, "mother") else None generation = [ {"person": father, "step": step}, {"person": mother, "step": step}, @@ -317,6 +302,4 @@ def _get_parents(cls, person, step=1): @property def ancestors(self): - return [{"person": self, "step": 0}] + self.__class__._get_parents( - self - ) + return [{"person": self, "step": 0}] + self.__class__._get_parents(self) diff --git a/pydiggy/__init__.py b/pydiggy/__init__.py index 
0b7ba1a..ec64d53 100644 --- a/pydiggy/__init__.py +++ b/pydiggy/__init__.py @@ -6,7 +6,7 @@ __email__ = "admhpkns@gmail.com" __version__ = "0.1.0" -from pydiggy.node import Facets, Node, get_node, is_facets +from pydiggy.node import Facets, Node, is_facets, get_node_type, NodeTypeRegistry from pydiggy.operations import generate_mutation, hydrate, query, run_mutation from pydiggy._types import count, exact, geo, index, lang, reverse, uid, upsert @@ -17,12 +17,13 @@ "Facets", "generate_mutation", "geo", - "get_node", + "get_node_type", "hydrate", "is_facets", "index", "lang", "Node", + "NodeTypeRegistry", "query", "reverse", "run_mutation", diff --git a/pydiggy/_types.py b/pydiggy/_types.py index 35ca921..d961c96 100644 --- a/pydiggy/_types.py +++ b/pydiggy/_types.py @@ -26,6 +26,11 @@ class Tokenizer(DirectiveArgument): class Directive: + """ + A directive adds extra instructions to a schema or query. Annotated + with the '@' symbol and optional arguments in parens. + """ + def __str__(self): args = [] if "__annotations__" in self.__class__.__dict__: @@ -78,7 +83,7 @@ def __init__(self, name=None, many=False, with_facets=False): upsert = type("upsert", (Directive,), {}) lang = type("lang", (Directive,), {}) -DGRAPH_TYPES = { # Unsupported dgraph type: password, geo +DGRAPH_TYPES = { # TODO: add dgraph type 'password' "uid": "uid", "geo": "geo", "str": "string", diff --git a/pydiggy/cli.py b/pydiggy/cli.py index 57819f8..745344c 100644 --- a/pydiggy/cli.py +++ b/pydiggy/cli.py @@ -7,7 +7,7 @@ from pydgraph import Operation from pydiggy.connection import get_client -from pydiggy.node import Node +from pydiggy.node import Node, NodeTypeRegistry @click.group() @@ -40,12 +40,12 @@ def generate(module, run, host, port): click.echo(f"Generating schema for: {module}") importlib.import_module(module) - num_nodes = len(Node._nodes) + num_nodes = len(NodeTypeRegistry._node_types) click.echo(f"\nNodes found: ({num_nodes})") - for node in Node._nodes: + for node in 
NodeTypeRegistry._node_types: click.echo(f" - {node._get_name()}") - schema, unknown = Node._generate_schema() + schema, unknown = NodeTypeRegistry._generate_schema() if not run: click.echo("\nYour schema:\n~~~~~~~~\n") diff --git a/pydiggy/node.py b/pydiggy/node.py index 5458621..45eac82 100644 --- a/pydiggy/node.py +++ b/pydiggy/node.py @@ -1,9 +1,10 @@ from __future__ import annotations import copy -import inspect import json -from collections import namedtuple +import inspect +import re +from collections import namedtuple, ChainMap from copy import deepcopy from dataclasses import dataclass, field from datetime import datetime @@ -69,9 +69,21 @@ def is_computed(node: Node) -> bool: return False +def get_node_type(name: str) -> Node: + """ + Retrieve a registered node class. + + Example: Region = get_node_type("Region") + + This is a safe method to make sure that any models used have been + registered in the NodeTypeRegistry + """ + registered = {x.__name__: x for x in NodeTypeRegistry._node_types} + return registered.get(name, None) + + def _force_instance( - directive: Union[Directive, reverse, count, upsert, lang], - prop_type: str = None, + directive: Union[Directive, reverse, count, upsert, lang], prop_type: str = None ) -> Directive: # TODO: # - Make sure directive is an instance of, or a class defined as a directive @@ -89,24 +101,9 @@ def _force_instance( return directive(*args) -def get_node(name: str) -> Node: - """ - Retrieve a registered node class. - - Example: Region = get_node("Region") - - This is a safe method to make sure that any models used have been - declared as a Node. 
- """ - registered = {x.__name__: x for x in Node._nodes} - return registered.get(name, None) - - class NodeMeta(type): def __new__(cls, name, bases, attrs, **kwargs): - directives = [ - x for x in attrs if x in attrs.get("__annotations__", {}).keys() - ] + directives = [x for x in attrs if x in attrs.get("__annotations__", {}).keys()] attrs["_directives"] = dict() attrs["_instances"] = dict() @@ -136,24 +133,43 @@ def __new__(cls, name, bases, attrs, **kwargs): class Node(metaclass=NodeMeta): - uid: int + _instances = dict() + _staged = dict() - _i = _count() - _nodes = [] - _staged = {} + # uid is not used as a class variable, it is part of the required + # schema of a Node, and is thus part of the superclass. + uid : int + + @classmethod + def _get_staged(cls): + return cls._staged + + @classmethod + def _clear_staged(cls): + cls._staged = {} + # NodeTypeRegistry._i needs to be updated when clearing staged nodes + + @classmethod + def _reset(cls) -> None: + cls._instances = dict() + + @classmethod + def _get_name(cls) -> str: + return cls.__name__ def __init_subclass__(cls, is_abstract: bool = False) -> None: if not is_abstract: - cls._register_node(cls) + NodeTypeRegistry._register_node_type(cls) def __init__(self, uid=None, **kwargs): + if uid is None: # TODO: # - There probably should be another property that is set here # so that it is possible to identify with a boolean if the instance # is brand new (and has never been committed to the DB) or if it # is being freshly generated - uid = next(self._generate_uid()) + uid = next(NodeTypeRegistry._generate_uid()) self.uid = uid self._dirty = set() @@ -162,7 +178,7 @@ def __init__(self, uid=None, **kwargs): # - perhaps this code to generate self._annotations belongs somewhere # else. 
Regardless, there is a lot of cleanup in this module that # could probably use this property instead of running get_type_hints - localns = {x.__name__: x for x in Node._nodes} + localns = {x.__name__: x for x in NodeTypeRegistry._node_types} localns.update({"List": List, "Union": Union, "Tuple": Tuple}) self._annotations = get_type_hints( self.__class__, globalns=globals(), localns=localns @@ -196,7 +212,7 @@ def __hash__(self): return hash(self.uid) def __getattr__(self, attribute): - if attribute in self.__annotations__: + if (not attribute == "__annotations__") and (attribute in self.__annotations__): raise MissingAttribute(self, attribute) super().__getattribute__(attribute) @@ -211,15 +227,14 @@ def __setattr__(self, name: str, value: Any): orig = self.__dict__.get(name, None) self.__dict__[name] = value + # TODO: This causes issues with Node types that want private variables if hasattr(self, "_init") and self._init and not name.startswith("_"): self._dirty.add(name) if name in self._directives and any( isinstance(d, reverse) for d in self._directives[name] ): directive = list( - filter( - lambda d: isinstance(d, reverse), self._directives[name] - ) + filter(lambda d: isinstance(d, reverse), self._directives[name]) )[0] reverse_name = directive.name if directive.name else f"_{name}" @@ -254,26 +269,262 @@ def _assign(obj, key, value, do_many, remove=False): if value is not None: _assign(value, reverse_name, self, directive.many) elif orig: - _assign( - orig, reverse_name, self, directive.many, remove=True + _assign(orig, reverse_name, self, directive.many, remove=True) + + @classmethod + def create(cls, **kwargs) -> Node: + """ + Constructor for creating a node. 
+ """ + instance = cls() + for k, v in kwargs.items(): + setattr(instance, k, v) + return instance + + @classmethod + def _explode( + cls, + instance: Node, + max_depth: Optional[int] = None, + depth: int = 0, + include: List[str] = None, + ) -> Dict[str, Any]: + """ + Explode a Node object into a mapping + """ + # TODO: + # - Candidate for refactoring + # - Should the default max_depth be None? + obj = {"_type": instance.__class__.__name__} + + if not isinstance(instance, Node) and not is_facets(instance): + if is_facets(instance): + return instance._asdict() + raise Exception("Cannot explode a non-Node object") + + if is_facets(instance): + data = list(instance._asdict().items()) + else: + data = list(instance.__dict__.items()) + if include: + for prop in include: + data.append((prop, getattr(instance, prop, None))) + + data = filter(lambda x: x[1] is not None, data) + + annotations = ( + instance.obj._annotations if is_facets(instance) else instance._annotations + ) + for key, value in data: + if ( + is_facets(instance) + or key in annotations.keys() + or key == "uid" + or (include and key in include) + ): + if isinstance(value, (str, int, float, bool)): + obj[key] = value + elif issubclass(value.__class__, Node): + explode = depth < max_depth if max_depth is not None else True + if explode: + obj[key] = cls._explode( + value, depth=(depth + 1), max_depth=max_depth + ) + else: + obj[key] = str(value) + elif isinstance(value, (list,)): + explode = depth < max_depth if max_depth is not None else True + if explode: + obj[key] = [ + cls._explode(x, depth=(depth + 1), max_depth=max_depth) + for x in value + ] + else: + obj[key] = str(value) + elif is_computed(value): + obj.update({key: value._asdict()}) + elif is_facets(value): + prop_type = annotations[key] + is_list_type = ( + True + if isinstance(prop_type, _GenericAlias) + and prop_type.__origin__ in (list, tuple) + else False ) + if is_list_type: + if key not in obj: + obj[key] = [] + 
obj[key].append(value._asdict()) + else: + item = value._asdict() + item.update({"is_facets": True}) + obj[key] = value._asdict() + return obj + + def to_json(self, include: List[str] = None) -> Dict[str, Any]: + # TODO: + # - Should this be renamed? It is a little misleading. Perhaps to_dict() + # would make more sense. + return self.__class__._explode(self, include=include) + + def stage(self) -> None: + """ + Identify a node instance that it is primed and ready to be migrated + """ + self.edges = {} + + for arg, _ in self._annotations.items(): + if not arg.startswith("_") and arg != "uid": + val = getattr(self, arg, None) + if val is not None: + self.edges[arg] = val + self._staged[self.uid] = self + + def save( + self, client: PyDiggyClient = None, host: str = None, port: int = None + ) -> None: + # TODO: + # - User self._annotations + localns = {x.__name__: x for x in NodeTypeRegistry._node_types} + localns.update({"List": List, "Union": Union, "Tuple": Tuple}) + annotations = get_type_hints(self, globalns=globals(), localns=localns) + + if client is None: + client = get_client(host=host, port=9080) + + def _make_obj(node, pred, obj): + # TODO: Remove this in favor of the _make_obj in operations + annotation = annotations.get(pred, "") + if hasattr(annotation, "__origin__") and annotation.__origin__ == list: + annotation = annotation.__args__[0] + + try: + if annotation == str: + obj = re.sub('"', '\\"', obj.rstrip()) + obj = f'"{obj}"' + elif annotation == bool: + obj = f'"{str(obj).lower()}"' + elif annotation in (int,): + obj = f'"{int(obj)}"^^' + elif annotation in (float,) or isinstance(obj, float): + obj = f'"{obj}"^^' + elif NodeTypeRegistry._is_node_type(obj.__class__): + obj, passed = _parse_subject(obj.uid) + staged = NodeTypeRegistry._get_staged() + + if ( + obj not in staged + and passed not in staged + and not isinstance(passed, int) + ): + raise NotStaged(f"<{node.__class__.__name__} {pred}={obj}>") + except ValueError: + raise ValueError( + f"Incorrect value type. 
Received <{node.__class__.__name__} {pred}={obj}>. Expecting <{node.__class__.__name__} {pred}={annotation.__name__}>" + ) + + if isinstance(obj, (tuple, set)): + obj = list(obj) + + return obj + + setters = [] + deleters = [] + + saveable = (x for x in self._dirty if x != "computed" and x in annotations) + + for pred in saveable: + obj = getattr(self, pred) + subject, passed = _parse_subject(self.uid) + if not isinstance(obj, list): + obj = [obj] + + for o in obj: + if issubclass(o.__class__, Enum): + o = o.value + + facets = [] + if is_facets(o): + for facet in o.__class__._fields[1:]: + val = _raw_value(getattr(o, facet)) + facets.append(f"{facet}={val}") + o = o.obj + + if not isinstance(o, (list, tuple, set)): + out = [o] + else: + out = o + + for output in out: + if output is None: + line = f"{subject} <{pred}> * ." + deleters.append(line) + continue + + is_node_type = NodeTypeRegistry._is_node_type(output.__class__) + output = _make_obj(self, pred, output) + + # Temporary measure until dgraph 1.1 with 1:1 uid + if is_node_type: + prop_type = annotations[pred] + is_list_type = ( + True + if isinstance(prop_type, _GenericAlias) + and prop_type.__origin__ in (list, tuple) + else False + ) + if not is_list_type: + line = f"{subject} <{pred}> * ." + transaction = client.txn() + try: + transaction.mutate(del_nquads=line) + transaction.commit() + finally: + transaction.discard() + + if facets: + facets = ", ".join(facets) + line = f"{subject} <{pred}> {output} ({facets}) ." + else: + line = f"{subject} <{pred}> {output} ." 
+ setters.append(line) + + set_mutations = "\n".join(setters) + delete_mutations = "\n".join(deleters) + transaction = client.txn() + + try: + if set_mutations or delete_mutations: + o = transaction.mutate( + set_nquads=set_mutations, del_nquads=delete_mutations + ) + transaction.commit() + finally: + transaction.discard() + + +class NodeTypeRegistry: + _i = _count() + _node_types = [] @classmethod - def _reset(cls) -> None: - cls._i = _count() - cls._instances = dict() + def _generate_uid(cls) -> str: + i = next(cls._i) + yield f"unsaved.{i}" @classmethod - def _register_node(cls, node: Node) -> None: - cls._nodes.append(node) + def _reset(cls) -> None: + cls._i = _count() + for node_type in cls._node_types: + node_type._reset() @classmethod - def _get_name(cls) -> str: - return cls.__name__ + def _register_node_type(cls, node_type: type) -> None: + cls._node_types.append(node_type) @classmethod def _generate_schema(cls) -> str: - nodes = cls._nodes + nodes = cls._node_types edges = {} schema = [] type_schema = [] @@ -314,33 +565,24 @@ def _generate_schema(cls) -> str: prop_type = prop_type.__args__[0] prop_type = PropType( - prop_type, - is_list_type, - node._directives.get(prop_name, []), + prop_type, is_list_type, node._directives.get(prop_name, []) ) if prop_name in edges: - if prop_type != edges.get( - prop_name - ) and not cls._is_node_type(prop_type[0]): + if prop_type != edges.get(prop_name) and not cls._is_node_type( + prop_type[0] + ): # Check if there is a type conflict if ( - edges.get(prop_name).directives - != prop_type.directives + edges.get(prop_name).directives != prop_type.directives and all( - ( - inspect.isclass(x) - and issubclass(x, Directive) - ) + (inspect.isclass(x) and issubclass(x, Directive)) or issubclass(x.__class__, Directive) for x in edges.get(prop_name).directives ) and all( - ( - inspect.isclass(x) - and issubclass(x, Directive) - ) + (inspect.isclass(x) and issubclass(x, Directive)) or issubclass(x.__class__, Directive) for x in 
prop_type.directives ) @@ -356,9 +598,7 @@ def _generate_schema(cls) -> str: edges[prop_name] = prop_type elif cls._is_node_type(prop_type[0]): edges[prop_name] = PropType( - "uid", - is_list_type, - node._directives.get(prop_name, []), + "uid", is_list_type, node._directives.get(prop_name, []) ) else: if prop_name != "uid": @@ -408,17 +648,19 @@ def _get_type_name(cls, schema_type): @classmethod def _get_staged(cls): - return cls._staged + all_staged = [node_type._get_staged() for node_type in cls._node_types] + return dict(ChainMap(*all_staged)) @classmethod def _clear_staged(cls): - cls._staged = {} + for node_type in cls._node_types: + node_type._clear_staged() cls._i = _count() @classmethod def _hydrate( cls, raw: Dict[str, Any], types: Dict[str, Node] = None - ) -> Node: + ) -> Node: # -> Dict[str: List[Node]] # TODO: # - Accept types that are passed. Loop thru them and register if needed # and raising an exception if they are not valid. @@ -426,8 +668,8 @@ def _hydrate( # complexity. 
# - Should create a Facets type so that the type annotation of this function # is _hydrate(cls, raw: str, types: Dict[str, Node] = None) -> Union[Node, Facets] - registered = {x.__name__: x for x in Node._nodes} - localns = {x.__name__: x for x in Node._nodes} + registered = {x.__name__: x for x in NodeTypeRegistry._node_types} + localns = {x.__name__: x for x in NodeTypeRegistry._node_types} localns.update({"List": List, "Union": Union, "Tuple": Tuple}) if "_type" in raw and raw.get("_type") in registered: @@ -446,14 +688,10 @@ def _hydrate( computed = {} pred_items = [ - (pred, value) - for pred, value in raw.items() - if not pred.startswith("_") + (pred, value) for pred, value in raw.items() if not pred.startswith("_") ] - annotations = get_type_hints( - k, globalns=globals(), localns=localns - ) + annotations = get_type_hints(k, globalns=globals(), localns=localns) for pred, value in pred_items: """ The pred falls into one of three categories: @@ -483,8 +721,8 @@ def _hydrate( # Will probably need to be revisited when # Dgraph v. 
1.1 is released raise Exception("Unknown data") - node = get_node(value[0].get("_type")) - value = node._hydrate(value[0]) + node_type = get_node_type(value[0].get("_type")) + value = node_type._hydrate(value[0]) elif isinstance(value, dict): value = cls._hydrate(value) @@ -496,22 +734,16 @@ def _hydrate( for x in value: keys = deepcopy(list(x.keys())) value_facet_data = [ - (k.split("|")[1], x.pop(k)) - for k in keys - if "|" in k + (k.split("|")[1], x.pop(k)) for k in keys if "|" in k ] - item = get_node(x.get("_type"))._hydrate(x) + item = NodeTypeRegistry._hydrate(x) if value_facet_data: item = Facets(item, **dict(value_facet_data)) delay.append((item, p, value_facet_data)) elif isinstance(value, dict): delay.append( - ( - get_node(value.get("_type"))._hydrate(value), - p, - None, - ) + (get_node_type(value.get("_type"))._hydrate(value), p, None) ) else: if pred.endswith("_uid"): @@ -545,264 +777,16 @@ def json(cls) -> Dict[str, List[Node]]: """ # TODO: # - Probably should be renamed - # - Instrad of being a List[Node], it should probably be a Set + # - Instead of being a List[Node], it should probably be a Set return { x.__name__: list( - map(partial(cls._explode, max_depth=0), x._instances.values()) + map(partial(x._explode, max_depth=0), x._instances.values()) ) - for x in cls._nodes + for x in cls._node_types if len(x._instances) > 0 } - @classmethod - def _explode( - cls, - instance: Node, - max_depth: Optional[int] = None, - depth: int = 0, - include: List[str] = None, - ) -> Dict[str, Any]: - """ - Explode a Node object into a mapping - """ - # TODO: - # - Candidate for refactoring - # - Should the default max_depth be None?
- obj = {"_type": instance.__class__.__name__} - - if not isinstance(instance, Node) and not is_facets(instance): - if is_facets(instance): - return instance._asdict() - raise Exception("Cannot explode a non-Node object") - - if is_facets(instance): - data = list(instance._asdict().items()) - else: - data = list(instance.__dict__.items()) - if include: - for prop in include: - data.append((prop, getattr(instance, prop, None))) - - data = filter(lambda x: x[1] is not None, data) - - annotations = ( - instance.obj._annotations - if is_facets(instance) - else instance._annotations - ) - for key, value in data: - if ( - is_facets(instance) - or key in annotations.keys() - or key == "uid" - or (include and key in include) - ): - if isinstance(value, (str, int, float, bool)): - obj[key] = value - elif issubclass(value.__class__, Node): - explode = ( - depth < max_depth if max_depth is not None else True - ) - if explode: - obj[key] = cls._explode( - value, depth=(depth + 1), max_depth=max_depth - ) - else: - obj[key] = str(value) - elif isinstance(value, (list,)): - explode = ( - depth < max_depth if max_depth is not None else True - ) - if explode: - obj[key] = [ - cls._explode( - x, depth=(depth + 1), max_depth=max_depth - ) - for x in value - ] - else: - obj[key] = str(value) - elif is_computed(value): - obj.update({key: value._asdict()}) - elif is_facets(value): - prop_type = annotations[key] - is_list_type = ( - True - if isinstance(prop_type, _GenericAlias) - and prop_type.__origin__ in (list, tuple) - else False - ) - if is_list_type: - if key not in obj: - obj[key] = [] - obj[key].append(value._asdict()) - else: - item = value._asdict() - item.update({"is_facets": True}) - obj[key] = value._asdict() - return obj - - @classmethod - def create(cls, **kwargs) -> Node: - """ - Constructor for creating a node. 
- """ - instance = cls() - for k, v in kwargs.items(): - setattr(instance, k, v) - return instance - @staticmethod def _is_node_type(cls) -> bool: """Check if a class is a """ return inspect.isclass(cls) and issubclass(cls, Node) - - def _generate_uid(self) -> str: - i = next(self._i) - yield f"unsaved.{i}" - - def to_json(self, include: List[str] = None) -> Dict[str, Any]: - # TODO: - # - Should this be renamed? It is a little misleading. Perhaps to_dict() - # would make more sense. - return self.__class__._explode(self, include=include) - - def stage(self) -> None: - """ - Identify a node instance that it is primed and ready to be migrated - """ - self.edges = {} - - for arg, _ in self._annotations.items(): - if not arg.startswith("_") and arg != "uid": - val = getattr(self, arg, None) - if val is not None: - self.edges[arg] = val - self._staged[self.uid] = self - - def save( - self, client: PyDiggyClient = None, host: str = None, port: int = None - ) -> None: - # TODO: - # - User self._annotations - localns = {x.__name__: x for x in Node._nodes} - localns.update({"List": List, "Union": Union, "Tuple": Tuple}) - annotations = get_type_hints(self, globalns=globals(), localns=localns) - - if client is None: - client = get_client(host=host, port=9080) - - def _make_obj(node, pred, obj): - annotation = annotations.get(pred, "") - if ( - hasattr(annotation, "__origin__") - and annotation.__origin__ == list - ): - annotation = annotation.__args__[0] - - try: - if annotation == str: - obj = f'"{obj}"' - elif annotation == bool: - obj = f'"{str(obj).lower()}"' - elif annotation in (int,): - obj = f'"{int(obj)}"^^' - elif annotation in (float,) or isinstance(obj, float): - obj = f'"{obj}"^^' - elif Node._is_node_type(obj.__class__): - obj, passed = _parse_subject(obj.uid) - staged = Node._get_staged() - - if ( - obj not in staged - and passed not in staged - and not isinstance(passed, int) - ): - raise NotStaged( - f"<{node.__class__.__name__} {pred}={obj}>" - ) - except 
ValueError: - raise ValueError( - f"Incorrect value type. Received <{node.__class__.__name__} {pred}={obj}>. Expecting <{node.__class__.__name__} {pred}={annotation.__name__}>" - ) - - if isinstance(obj, (tuple, set)): - obj = list(obj) - - return obj - - setters = [] - deleters = [] - - saveable = ( - x for x in self._dirty if x != "computed" and x in annotations - ) - - for pred in saveable: - obj = getattr(self, pred) - subject, passed = _parse_subject(self.uid) - if not isinstance(obj, list): - obj = [obj] - - for o in obj: - if issubclass(o.__class__, Enum): - o = o.value - - facets = [] - if is_facets(o): - for facet in o.__class__._fields[1:]: - val = _raw_value(getattr(o, facet)) - facets.append(f"{facet}={val}") - o = o.obj - - if not isinstance(o, (list, tuple, set)): - out = [o] - else: - out = o - - for output in out: - if output is None: - line = f"{subject} <{pred}> * ." - deleters.append(line) - continue - - is_node_type = self._is_node_type(output.__class__) - output = _make_obj(self, pred, output) - - # Temporary measure until dgraph 1.1 with 1:1 uid - if is_node_type: - prop_type = annotations[pred] - is_list_type = ( - True - if isinstance(prop_type, _GenericAlias) - and prop_type.__origin__ in (list, tuple) - else False - ) - if not is_list_type: - line = f"{subject} <{pred}> * ." - transaction = client.txn() - try: - transaction.mutate(del_nquads=line) - transaction.commit() - finally: - transaction.discard() - - if facets: - facets = ", ".join(facets) - line = f"{subject} <{pred}> {output} ({facets}) ." - else: - line = f"{subject} <{pred}> {output} ." 
- setters.append(line) - - set_mutations = "\n".join(setters) - delete_mutations = "\n".join(deleters) - transaction = client.txn() - - try: - if set_mutations or delete_mutations: - o = transaction.mutate( - set_nquads=set_mutations, del_nquads=delete_mutations - ) - transaction.commit() - finally: - transaction.discard() diff --git a/pydiggy/operations.py b/pydiggy/operations.py index bc33e10..9db8244 100644 --- a/pydiggy/operations.py +++ b/pydiggy/operations.py @@ -1,20 +1,18 @@ import json as _json +import re from datetime import datetime from enum import Enum from typing import List, Tuple, Union, get_type_hints, Dict, Any from pydiggy.connection import get_client, PyDiggyClient from pydiggy.exceptions import NotStaged -from pydiggy.node import Node +from pydiggy.node import Node, NodeTypeRegistry from pydiggy._types import * # noqa from pydiggy.utils import _parse_subject, _raw_value def _make_obj(node, pred, obj): - localns = {x.__name__: x for x in Node._nodes} - localns.update({"List": List, "Union": Union, "Tuple": Tuple}) - annotations = get_type_hints(node, globalns=globals(), localns=localns) - annotation = annotations.get(pred, "") + annotation = node._annotations.get(pred, "") if hasattr(annotation, "__origin__") and annotation.__origin__ == list: annotation = annotation.__args__[0] @@ -24,9 +22,9 @@ def _make_obj(node, pred, obj): # TODO: # - integreate utils._rdf_value try: - if Node._is_node_type(obj.__class__): + if NodeTypeRegistry._is_node_type(obj.__class__): uid, passed = _parse_subject(obj.uid) - staged = Node._get_staged() + staged = NodeTypeRegistry._get_staged() if ( uid not in staged @@ -51,6 +49,7 @@ def _make_obj(node, pred, obj): elif isinstance(obj, datetime): obj = f'"{obj.isoformat()}"' else: + obj = re.sub('"', '\\"', obj.rstrip()) obj = f'"{obj}"' except ValueError: raise ValueError( @@ -129,14 +128,14 @@ def hydrate(data: str, types: List[Node] = None) -> Dict[str, List[Node]]: output = {} # data = data.get(data_set) - 
registered = {x.__name__: x for x in Node._nodes} + registered = {x.__name__: x for x in NodeTypeRegistry._node_types} for func_name, raw_data in data.items(): hydrated = [] for raw in raw_data: if "_type" in raw and raw.get("_type") in registered: cls = registered.get(raw.get("_type")) - hydrated.append(cls._hydrate(raw, types=types)) + hydrated.append(NodeTypeRegistry._hydrate(raw, types=types)) output[func_name] = hydrated diff --git a/pydiggy/utils.py b/pydiggy/utils.py index 38f5f32..4de1953 100644 --- a/pydiggy/utils.py +++ b/pydiggy/utils.py @@ -1,3 +1,6 @@ +import re + + def _parse_subject(uid): if isinstance(uid, int): return f"<{hex(uid)}>", uid @@ -6,7 +9,11 @@ def _parse_subject(uid): def _rdf_value(value): + """ + Translates a value into a string annotated with an RDF type + """ if isinstance(value, str): + value = re.sub('"', '\\"', value.rstrip()) value = f'"{value}"' elif isinstance(value, bool): value = f'"{str(value).lower()}"' diff --git a/tests/test_mutation.py b/tests/test_mutation.py index 78f7183..f80c9d9 100644 --- a/tests/test_mutation.py +++ b/tests/test_mutation.py @@ -1,12 +1,10 @@ -from pydiggy import Facets, generate_mutation - -# import pytest +from pydiggy import Facets, generate_mutation, NodeTypeRegistry def test_mutations(RegionClass): Region = RegionClass - Region._reset() + NodeTypeRegistry._reset() por = Region(uid=0x11, name="Portugal") spa = Region(uid=0x12, name="Spain") @@ -54,3 +52,21 @@ def test_mutations(RegionClass): pprint.pprint(mutation) assert control == mutation + + +def test__mutation__with__quotes(RegionClass): + Region = RegionClass + + NodeTypeRegistry._reset() + + florida = Region(name="Florida 'The \"Sunshine\" State'") + + florida.stage() + + mutation = generate_mutation() + + control = """_:unsaved.0 "true" . +_:unsaved.0 <_type> "Region" . 
+_:unsaved.0 "Florida 'The \\"Sunshine\\" State'" .""" + + assert mutation == control diff --git a/tests/test_node.py b/tests/test_node.py index 0f244c5..316ca95 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -4,13 +4,13 @@ # import pytest from pprint import pprint as print -from pydiggy import Facets, Node +from pydiggy import Facets, Node, NodeTypeRegistry def test__node__to__json(RegionClass): Region = RegionClass - Region._reset() + NodeTypeRegistry._reset() por = Region(uid=0x11, name="Portugal") spa = Region(uid=0x12, name="Spain") @@ -22,7 +22,7 @@ def test__node__to__json(RegionClass): gas.borders = [Facets(spa, foo="bar", hello="world"), mar] mar.borders = [spa, gas] - regions = Node.json().get("Region") + regions = NodeTypeRegistry.json().get("Region") control = [ {"_type": "Region", "borders": "[]", "name": "Portugal", "uid": 17}, @@ -48,3 +48,23 @@ def test__node__to__json(RegionClass): ] assert regions == control + + +def test__node__with__quotes(RegionClass): + Region = RegionClass + + NodeTypeRegistry._reset() + + florida = Region(name="Florida 'The \"Sunshine\" State'") + + regions = NodeTypeRegistry.json().get("Region") + + control = [ + { + "_type": "Region", + "name": "Florida 'The \"Sunshine\" State'", + "uid": "unsaved.0", + } + ] + + assert regions == control diff --git a/tests/test_operations.py b/tests/test_operations.py index e47bbf3..21d02dd 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -1,4 +1,4 @@ -from pydiggy import operations +from pydiggy import operations, NodeTypeRegistry def test__parse_subject(): @@ -8,12 +8,12 @@ def test__parse_subject(): subject = operations._parse_subject(123) assert subject == ("<0x7b>", 123) - subject = operations._parse_subject(0x7b) + subject = operations._parse_subject(0x7B) assert subject == ("<0x7b>", 123) def test__make_obj(TypeTestClass): - TypeTestClass._reset() + NodeTypeRegistry._reset() node = TypeTestClass() node.stage() diff --git a/tests/test_pydiggy.py 
b/tests/test_pydiggy.py index 0675f83..9f17125 100644 --- a/tests/test_pydiggy.py +++ b/tests/test_pydiggy.py @@ -6,7 +6,7 @@ import pytest from click.testing import CliRunner -from pydiggy import Node, cli +from pydiggy import Node, cli, NodeTypeRegistry @pytest.fixture @@ -24,7 +24,7 @@ def test_command_line_interface_has_commands(runner, commands): def test_dry_run_generate_schema(runner): - Node._nodes = [] + NodeTypeRegistry._node_types = [] result = runner.invoke(cli.main, ["generate", "tests.fakeapp", "--no-run"]) assert result.exit_code == 0 assert "Nodes found: (1)" in result.output