diff --git a/.gitignore b/.gitignore index 7da223a7..4e4e89d2 100644 --- a/.gitignore +++ b/.gitignore @@ -110,6 +110,9 @@ ENV/ env.bak/ venv.bak/ +# Jetbrains IDE +.idea/ + # Spyder project settings .spyderproject .spyproject diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 92bb6f9a..4eb34f84 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -760,6 +760,31 @@ async def all(self, batch_size=10): return await query.execute() return await self.execute() + async def count(self) -> int: + args = [ + "FT.AGGREGATE", + self.model.Meta.index_name, + self.query, + "APPLY", + "matched_terms()", + "AS", + "countable", + "GROUPBY", + "1", + "@countable", + "REDUCE", + "COUNT", + "0", + ] + raw_result = await self.model.db().execute_command(*args) + print(raw_result, args) + try: + return sum( + [int(result[3].decode("utf-8", "ignore")) for result in raw_result[1:]] + ) + except IndexError: + return 0 + def sort_by(self, *fields: str): if not fields: return self @@ -792,7 +817,10 @@ async def update(self, use_transaction=True, **field_values): async def delete(self): """Delete all matching records in this query.""" # TODO: Better response type, error detection - return await self.model.db().delete(*[m.key() for m in await self.all()]) + keys_to_delete = [m.key() for m in await self.all()] + if not keys_to_delete: + return 0 + return await self.model.db().delete(*keys_to_delete) async def __aiter__(self): if self._model_cache: diff --git a/tests/conftest.py b/tests/conftest.py index 9f067a38..77685361 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,7 +47,7 @@ def key_prefix(request, redis): def cleanup_keys(request): # Always use the sync Redis connection with finalizer. Setting up an # async finalizer should work, but I'm not suer how yet! - from redis_om.connections import get_redis_connection as get_sync_redis + from aredis_om.connections import get_redis_connection as get_sync_redis # Increment for every pytest-xdist worker conn = get_sync_redis() diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 0a79aa6b..baeb5ed2 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -22,7 +22,7 @@ # We need to run this check as sync code (during tests) even in async mode # because we call it in the top-level module scope. -from redis_om import has_redisearch +from aredis_om import has_redisearch from tests.conftest import py_test_mark_asyncio if not has_redisearch(): @@ -96,6 +96,16 @@ async def members(m): yield member1, member2, member3 +@py_test_mark_asyncio +async def test_all_query(members, m): + + actual = await m.Member.find().all() + assert all([member in actual for member in members]) + + actual_count = await m.Member.find().count() + assert actual_count == len(members) + + @py_test_mark_asyncio async def test_exact_match_queries(members, m): member1, member2, member3 = members @@ -129,6 +139,11 @@ async def test_exact_match_queries(members, m): ).all() assert actual == [member2] + actual_count = await m.Member.find( + m.Member.first_name == "Kim", m.Member.last_name == "Brookins" + ).count() + assert actual_count == 1 + @py_test_mark_asyncio async def test_full_text_search_queries(members, m): @@ -162,16 +177,17 @@ async def test_recursive_query_resolution(members, m): async def test_tag_queries_boolean_logic(members, m): member1, member2, member3 = members - actual = await ( - m.Member.find( - (m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") - | (m.Member.last_name == "Smith") - ) - .sort_by("age") - .all() + find_query = m.Member.find( + (m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") + | (m.Member.last_name == "Smith") ) + + actual = await find_query.sort_by("age").all() assert actual == [member1, member3] + actual_count = await find_query.count() + assert actual_count == 2 + @py_test_mark_asyncio async def test_tag_queries_punctuation(m): diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 8a114f9a..b3a3ba3d 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -24,7 +24,7 @@ # We need to run this check as sync code (during tests) even in async mode # because we call it in the top-level module scope. -from redis_om import has_redis_json +from aredis_om import has_redis_json from tests.conftest import py_test_mark_asyncio if not has_redis_json(): @@ -291,7 +291,7 @@ async def test_saves_many_explicit_transaction(address, m): async with m.Member.db().pipeline(transaction=True) as pipeline: await m.Member.add(members, pipeline=pipeline) assert result == [member1, member2] - assert await pipeline.execute() == ["OK", "OK"] + assert await pipeline.execute() == [b"OK", b"OK"] assert await m.Member.get(pk=member1.pk) == member1 assert await m.Member.get(pk=member2.pk) == member2