Skip to content

Commit

Permalink
support group_by with join (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
long2ice authored Apr 30, 2020
1 parent 1569649 commit d53516d
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Changelog
-------
- Fix bad import of ``basestring``
- Better handling of NULL characters in strings. Fixes SQLite, raises better error for PostgreSQL.
- Support ``.group_by()`` with join now

0.16.9
------
Expand Down
9 changes: 9 additions & 0 deletions examples/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ async def run():
print(ret)
# >>> [(1, 10), (2, 5)]

# group by with join
ret = (
await Book.annotate(count=Count("id"))
.group_by("author__name")
.values("author__name", "count")
)
print(ret)
# >>> [{"author__name": "author1", "count": 10}, {"author__name": "author2", "count": 5}]


if __name__ == "__main__":
run_async(run())
55 changes: 55 additions & 0 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ async def test_count_group_by(self):
elif author_id == self.a2.pk:
self.assertEqual(count, 5)

async def test_count_group_by_with_join(self):
ret = (
await Book.annotate(count=Count("id"))
.group_by("author__name")
.values("author__name", "count")
)
self.assertEqual(
ret, [{"author__name": "author1", "count": 10}, {"author__name": "author2", "count": 5}]
)

async def test_count_filter_group_by(self):
ret = (
await Book.annotate(count=Count("id"))
Expand All @@ -49,6 +59,17 @@ async def test_sum_group_by(self):
elif author_id == self.a2.pk:
self.assertEqual(sum_, 10.0)

async def test_sum_group_by_with_join(self):
ret = (
await Book.annotate(sum=Sum("rating"))
.group_by("author__name")
.values("author__name", "sum")
)
self.assertEqual(
ret,
[{"author__name": "author1", "sum": 45.0}, {"author__name": "author2", "sum": 10.0}],
)

async def test_sum_filter_group_by(self):
ret = (
await Book.annotate(sum=Sum("rating"))
Expand All @@ -72,6 +93,16 @@ async def test_avg_group_by(self):
elif author_id == self.a2.pk:
self.assertEqual(avg, 2.0)

async def test_avg_group_by_with_join(self):
ret = (
await Book.annotate(avg=Avg("rating"))
.group_by("author__name")
.values("author__name", "avg")
)
self.assertEqual(
ret, [{"author__name": "author1", "avg": 4.5}, {"author__name": "author2", "avg": 2}]
)

async def test_avg_filter_group_by(self):
ret = (
await Book.annotate(avg=Avg("rating"))
Expand All @@ -97,6 +128,14 @@ async def test_count_values_list_group_by(self):
elif author_id == self.a2.pk:
self.assertEqual(count, 5)

async def test_count_values_list_group_by_with_join(self):
ret = (
await Book.annotate(count=Count("id"))
.group_by("author__name")
.values_list("author__name", "count")
)
self.assertEqual(ret, [("author1", 10), ("author2", 5)])

async def test_count_values_list_filter_group_by(self):
ret = (
await Book.annotate(count=Count("id"))
Expand All @@ -121,6 +160,14 @@ async def test_sum_values_list_group_by(self):
elif author_id == self.a2.pk:
self.assertEqual(sum_, 10.0)

async def test_sum_values_list_group_by_with_join(self):
ret = (
await Book.annotate(sum=Sum("rating"))
.group_by("author__name")
.values_list("author__name", "sum")
)
self.assertEqual(ret, [("author1", 45.0), ("author2", 10.0)])

async def test_sum_values_list_filter_group_by(self):
ret = (
await Book.annotate(sum=Sum("rating"))
Expand All @@ -146,6 +193,14 @@ async def test_avg_values_list_group_by(self):
elif author_id == self.a2.pk:
self.assertEqual(avg, 2.0)

async def test_avg_values_list_group_by_with_join(self):
ret = (
await Book.annotate(avg=Avg("rating"))
.group_by("author__name")
.values_list("author__name", "avg")
)
self.assertEqual(ret, [("author1", 4.5), ("author2", 2.0)])

async def test_avg_values_list_filter_group_by(self):
ret = (
await Book.annotate(avg=Avg("rating"))
Expand Down
20 changes: 16 additions & 4 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,20 @@ def resolve_to_python_value(self, model: Type[MODEL], field: str) -> Callable:

raise FieldError(f'Unknown field "{field}" for model "{model}"')

def _resolve_group_bys(self, *field_names: str):
group_bys = []
for field_name in field_names:
field_split = field_name.split("__")
related_table, related_db_field = self._join_table_with_forwarded_fields(
model=self.model,
table=self.model._meta.basetable,
field=field_split[0],
forwarded_fields="__".join(field_split[1:]) if len(field_split) > 1 else "",
)
field = related_table[related_db_field].as_(field_name)
group_bys.append(field)
return group_bys


class ValuesListQuery(FieldSelectQuery):
__slots__ = (
Expand Down Expand Up @@ -1017,8 +1031,7 @@ def _make_query(self) -> None:
if self.distinct:
self.query._distinct = True
if self.group_bys:
self.query._groupbys = []
self.query = self.query.groupby(*self.group_bys)
self.query._groupbys = self._resolve_group_bys(*self.group_bys)

def __await__(self) -> Generator[Any, None, List[Any]]:
if self._db is None:
Expand Down Expand Up @@ -1104,8 +1117,7 @@ def _make_query(self) -> None:
if self.distinct:
self.query._distinct = True
if self.group_bys:
self.query._groupbys = []
self.query = self.query.groupby(*self.group_bys)
self.query._groupbys = self._resolve_group_bys(*self.group_bys)

def __await__(self) -> Generator[Any, None, List[dict]]:
if self._db is None:
Expand Down

0 comments on commit d53516d

Please sign in to comment.