Skip to content

Commit

Permalink
feat: avoid using class attributes to prevent incosistent state
Browse files Browse the repository at this point in the history
  • Loading branch information
Ale-Cas committed Jan 31, 2024
1 parent d053a84 commit ab36887
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 71 deletions.
98 changes: 43 additions & 55 deletions src/querytyper/mongo/query.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
"""MongoQuery implementation."""
import re
from collections.abc import Iterable
from collections import UserDict
from typing import Any, Dict, Generic, Type, TypeVar, Union, cast

T = TypeVar("T")
DictStrAny = Dict[str, Any]

class _BaseQuery(UserDict):

class MongoQuery(DictStrAny):
def __init__(self, *args: Any) -> None:
"""
Initialize a query object.
"""
if not len(args) == 1:
raise TypeError(
f"The initializer takes 1 positional argument but {len(args)} were given."
)
arg = args[0]
if not isinstance(arg, (dict, bool)):
raise TypeError(
f"The initializer argument must be a dictionary like object, {type(arg)} is not supported."
)
if isinstance(arg, dict):
super().__init__(arg)
else:
arg

@property # type: ignore[misc]
def __class__(self) -> Type[dict]: # type: ignore[override]
"""Return true if isinstance(self, dict)."""
return dict

class MongoQuery(_BaseQuery):
"""
MongoQuery is the core `querytyper` class to write MongoDB queries.
Expand All @@ -24,68 +49,24 @@ class MongoQuery(DictStrAny):
```
"""

_query_dict: DictStrAny = {}

def __init__(self, *args: Any, **kwargs: DictStrAny) -> None:
"""
Initialize a query object.
"""
for arg in args:
if not isinstance(arg, (QueryCondition, dict, bool)):
raise TypeError(
f"MongoQuery argument must be a QueryCondition, dict or a boolean value, {type(arg)} is not supported."
)
if isinstance(arg, QueryCondition):
MongoQuery._query_dict.update(arg)
super().__init__(MongoQuery._query_dict)
# clean up the class query dict after each instantiation
MongoQuery._query_dict = {}

def __del__(self) -> None:
"""MongoQuery destructor."""
MongoQuery._query_dict = {}

def __or__(
self,
other: "MongoQuery",
) -> "MongoQuery":
"""Overload | operator."""
MongoQuery._query_dict = {"$or": [self, other]}
return MongoQuery()
return MongoQuery({"$or": [self, other]})


class QueryCondition(DictStrAny):
class QueryCondition(_BaseQuery):
"""Class to represent a single query condition."""

def __init__(self, *args: Any, **kwargs: DictStrAny) -> None:
"""
Initialize a QueryCondition instance.
It should receive a dict as only argument.
Example
-------
```python
QueryCondition({"field": "value"})
```
It also overloads dict __init__ typing.
"""
arg = args[0]
if len(args) != 1 or not isinstance(arg, dict):
raise TypeError("QueryCondition must receive only one dict as input.")
if isinstance(arg, dict):
super().__init__(**arg)
for k, v in arg.items():
self.__setitem__(k, v)

def __and__(
self,
other: Union["QueryCondition", bool],
) -> "QueryCondition":
"""Overload & operator."""
if isinstance(other, QueryCondition):
MongoQuery._query_dict.update(other)
self.update(other)
return self

def __rand__(
Expand All @@ -94,7 +75,7 @@ def __rand__(
) -> "QueryCondition":
"""Overload & operator."""
if isinstance(other, QueryCondition):
MongoQuery._query_dict.update(other)
self.update(other)
return self

def __bool__(self) -> bool:
Expand All @@ -113,6 +94,7 @@ def __init__(
"""Initialize QueryField instance."""
self.name = name
self.field_type = field_type
self._query_dict: DictStrAny = {}

def __get__(
self,
Expand All @@ -127,17 +109,16 @@ def __eq__( # type: ignore[override]
other: object,
) -> QueryCondition:
"""Overload == operator."""
_query_dict = MongoQuery._query_dict
field = _query_dict.get(self.name)
field = self._query_dict.get(self.name)
if field is None:
_query_dict[self.name] = other
self._query_dict[self.name] = other
else:
_query_dict[self.name] = (
self._query_dict[self.name] = (
[*field, other]
if isinstance(field, Iterable) and not isinstance(field, str)
else [field, other]
)
return QueryCondition(_query_dict)
return QueryCondition(self._query_dict)

def __gt__(
self,
Expand Down Expand Up @@ -172,7 +153,14 @@ def __contains__(
other: T,
) -> QueryCondition:
"""Overload in operator."""
return regex_query(self.name, re.compile(other))
if not issubclass(cast(type, self.field_type), str):
raise TypeError(
f"Cannot check if field {self.name} contains {other} because {self.name} is not a subclass of str but {self.field_type}"
)
if not isinstance(other, str):
raise ValueError("Comparison value must be a valid string.")
return self == {"$regex": other}
# return False


def exists(
Expand Down
13 changes: 8 additions & 5 deletions tests/mongo/test_pymongo_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ def test_integration_with_pymongo() -> None:
for i in range(doc_num)
]
)
found_doc = collection.find_one(MongoQuery(QueryModel.int_field == 1))
query = MongoQuery(QueryModel.int_field == 1)
assert isinstance(query, dict)
assert query
found_doc = collection.find_one(query)
assert found_doc is not None
found_dummy = Dummy(**found_doc)
assert found_dummy.int_field == 1
query = MongoQuery("test" in QueryModel.str_field)
assert isinstance(query, dict)
# query = MongoQuery("test" in QueryModel.str_field)
# assert isinstance(query, dict)
# assert query
found_docs = list(collection.find(query))
assert len(found_docs) == doc_num
# found_docs = list(collection.find(query))
# assert len(found_docs) == doc_num
33 changes: 22 additions & 11 deletions tests/mongo/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,14 @@ class TestNoBaseModel(str, metaclass=MongoFilterMeta):
"""Test class."""


def test_regex_query() -> None:
"""Test regex query."""
condition = regex_query(QueryModel.str_field, re.compile("^a"))
assert isinstance(condition, QueryCondition)
assert condition == {"str_field": {"$regex": "^a"}}
condition = regex_query("str_field", re.compile("^a"))
assert isinstance(condition, QueryCondition)
assert condition == {"str_field": {"$regex": "^a"}}
# def test_regex_query() -> None:
# """Test regex query."""
# condition = regex_query(QueryModel.str_field, re.compile("^a"))
# assert isinstance(condition, QueryCondition)
# assert condition == {"str_field": {"$regex": "^a"}}
# query = MongoQuery(condition)
# assert isinstance(query, MongoQuery)
# assert query == {"str_field": {"$regex": "^a"}}


def test_exists_query() -> None:
Expand All @@ -183,9 +183,20 @@ def test_exists_query() -> None:

def test_query_condition_init() -> None:
"""Test QueryCondition initializer and TypeErrors."""
with pytest.raises(TypeError, match="QueryCondition must receive only one dict as input."):
QueryCondition(1)
QueryCondition(1, 2, 3)
condition = QueryCondition({"field": "value"})
assert "field" in condition
assert condition["field"] == "value"
assert condition == {"field": "value"}

def test_query_contains() -> None:
"""Test query contains override."""
query = MongoQuery("test" in QueryModel.str_field)
assert isinstance(query, MongoQuery)
assert query == {"str_field": {"$regex": "test"}}
with pytest.raises(
TypeError,
match="Cannot check if field int_field contains 1 because int_field is not a subclass of str but <class 'int'>",
):
MongoQuery(1 in QueryModel.int_field) # type: ignore[operator]
with pytest.raises(ValueError, match="Comparison value must be a valid string."):
MongoQuery(1 in QueryModel.str_field) # type: ignore[operator]

0 comments on commit ab36887

Please sign in to comment.