From 856922bd5452cdf00e4e63f3e8ba9ee0db7f373a Mon Sep 17 00:00:00 2001 From: Bruno Alla Date: Thu, 10 Oct 2024 16:49:05 +0100 Subject: [PATCH 1/2] Increase test coverage for mixin and renderer --- tests/test_viewset_mixin.py | 47 ++++++++++++++++++++++++++++++++++-- tests/testapp/models.py | 23 ++++++++++++++++++ tests/testapp/serializers.py | 18 +++++++++++++- tests/testapp/views.py | 11 +++++++-- tests/urls.py | 3 ++- tox.ini | 1 + 6 files changed, 97 insertions(+), 6 deletions(-) diff --git a/tests/test_viewset_mixin.py b/tests/test_viewset_mixin.py index bc15090..ad1553f 100644 --- a/tests/test_viewset_mixin.py +++ b/tests/test_viewset_mixin.py @@ -1,10 +1,15 @@ import io +import datetime as dt import pytest from openpyxl.reader.excel import load_workbook from rest_framework.test import APIClient +from time_machine import TimeMachineFixture -from tests.testapp.models import ExampleModel +from tests.testapp.models import ExampleModel, AllFieldsModel, Tag + + +pytestmark = pytest.mark.django_db @pytest.fixture @@ -12,7 +17,6 @@ def api_client(): return APIClient() -@pytest.mark.django_db def test_simple_viewset_model(api_client): ExampleModel.objects.create(title="test 1", description="This is a test") ExampleModel.objects.create(title="test 2", description="Another test") @@ -53,3 +57,42 @@ def test_simple_viewset_model(api_client): assert len(r3) == 2 assert r3[0].value == "test 3" assert r3[1].value == "Testing this out" + + +def test_all_fields_viewset(api_client, time_machine: TimeMachineFixture): + time_machine.move_to(dt.datetime(2023, 9, 10, 15, 44, 37)) + instance = AllFieldsModel.objects.create(title="Hello", age=36, is_active=True) + instance.tags.set( + [ + Tag.objects.create(name="test"), + Tag.objects.create(name="example"), + ] + ) + response = api_client.get("/all-fields/") + assert response.status_code == 200 + + workbook_buffer = io.BytesIO(response.content) + workbook = load_workbook(workbook_buffer, read_only=True) + sheet = workbook.worksheets[0] + rows = list(sheet.rows) + assert len(rows) == 2 + r0, r1 = rows + + assert [col.value for col in r0] == [ + "title", + "created_at", + "updated_date", + "updated_time", + "age", + "is_active", + "tags", + ] + assert [col.value for col in r1] == [ + "Hello", + dt.datetime(2023, 9, 10, 15, 44, 37), + dt.datetime(2023, 9, 10, 0, 0), + dt.time(15, 44, 37), + 36, + True, + "test, example", + ] diff --git a/tests/testapp/models.py b/tests/testapp/models.py index ed3a96f..22fb3d2 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -7,3 +7,26 @@ class ExampleModel(models.Model): def __str__(self): return self.title + + +class Tag(models.Model): + name = models.CharField(max_length=100) + + def __str__(self): + return self.name + + +class AllFieldsModel(models.Model): + title = models.CharField(max_length=100) + created_at = models.DateTimeField(auto_now_add=True) + updated_date = models.DateField(auto_now=True) + updated_time = models.TimeField(auto_now=True) + age = models.IntegerField() + is_active = models.BooleanField(default=True) + tags = models.ManyToManyField(Tag, related_name="all_fields") + + def __str__(self): + return self.title + + def get_tag_names(self): + return [tag.name for tag in self.tags.all()] diff --git a/tests/testapp/serializers.py b/tests/testapp/serializers.py index ceb261c..a138c7e 100644 --- a/tests/testapp/serializers.py +++ b/tests/testapp/serializers.py @@ -1,9 +1,25 @@ from rest_framework import serializers -from .models import ExampleModel +from .models import ExampleModel, AllFieldsModel class ExampleSerializer(serializers.ModelSerializer): class Meta: model = ExampleModel fields = ("title", "description") + + +class AllFieldsSerializer(serializers.ModelSerializer): + tags = serializers.ListField(source="get_tag_names") + + class Meta: + model = AllFieldsModel + fields = ( + "title", + "created_at", + "updated_date", + "updated_time", + "age", + "is_active", + "tags", + ) diff --git a/tests/testapp/views.py b/tests/testapp/views.py index 87eea78..33f5fc7 100644 --- a/tests/testapp/views.py +++ b/tests/testapp/views.py @@ -2,8 +2,8 @@ from drf_excel.mixins import XLSXFileMixin from drf_excel.renderers import XLSXRenderer -from .models import ExampleModel -from .serializers import ExampleSerializer +from .models import ExampleModel, AllFieldsModel +from .serializers import ExampleSerializer, AllFieldsSerializer class ExampleViewSet(XLSXFileMixin, ReadOnlyModelViewSet): @@ -11,3 +11,10 @@ class ExampleViewSet(XLSXFileMixin, ReadOnlyModelViewSet): serializer_class = ExampleSerializer renderer_classes = (XLSXRenderer,) filename = "my_export.xlsx" + + +class AllFieldsViewSet(XLSXFileMixin, ReadOnlyModelViewSet): + queryset = AllFieldsModel.objects.all() + serializer_class = AllFieldsSerializer + renderer_classes = (XLSXRenderer,) + filename = "al_fileds.xlsx" diff --git a/tests/urls.py b/tests/urls.py index feedffb..d2a8ebb 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,8 +1,9 @@ from rest_framework import routers -from .testapp.views import ExampleViewSet +from .testapp.views import ExampleViewSet, AllFieldsViewSet router = routers.SimpleRouter() router.register(r"examples", ExampleViewSet) +router.register(r"all-fields", AllFieldsViewSet) urlpatterns = router.urls diff --git a/tox.ini b/tox.ini index 4dba901..03ea3db 100644 --- a/tox.ini +++ b/tox.ini @@ -18,5 +18,6 @@ deps = pytest-django pytest-cov django-coverage-plugin + time-machine commands = {posargs:python -m pytest} From d2d88ad52a12af79763b594f5191cd6cd7e6a59e Mon Sep 17 00:00:00 2001 From: Bruno Alla Date: Thu, 10 Oct 2024 17:41:54 +0100 Subject: [PATCH 2/2] Add tests for renderers --- drf_excel/renderers.py | 6 ++--- tests/conftest.py | 13 ++++++++++ tests/test_renderers.py | 50 +++++++++++++++++++++++++++++++++++++ tests/test_viewset_mixin.py | 21 +++++++--------- tox.ini | 1 + 5 files changed, 76 insertions(+), 15 deletions(-) create mode 100644 tests/test_renderers.py diff --git a/drf_excel/renderers.py b/drf_excel/renderers.py index 87fe774..5aa0e9e 100644 --- a/drf_excel/renderers.py +++ b/drf_excel/renderers.py @@ -56,12 +56,12 @@ def render(self, data, accepted_media_type=None, renderer_context=None): """ Render `data` into XLSX workbook, returning a workbook. """ - if not self._check_validation_data(data): - return json.dumps(data) - if data is None: return bytes() + if not self._check_validation_data(data): + return json.dumps(data) + wb = Workbook() self.ws = wb.active diff --git a/tests/conftest.py b/tests/conftest.py index e641348..79cb2f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,8 @@ +import io +from typing import Union, Callable + import pytest +from openpyxl.reader.excel import load_workbook from openpyxl.workbook import Workbook from openpyxl.worksheet.worksheet import Worksheet @@ -11,3 +15,12 @@ def workbook() -> Workbook: @pytest.fixture def worksheet(workbook: Workbook) -> Worksheet: return Worksheet(workbook) + + +@pytest.fixture +def workbook_reader() -> Callable[[Union[bytes, str]], Workbook]: + def reader_func(buffer: Union[bytes, str]) -> Workbook: + io_buffer = io.BytesIO(buffer) + return load_workbook(io_buffer, read_only=True) + + return reader_func diff --git a/tests/test_renderers.py b/tests/test_renderers.py new file mode 100644 index 0000000..61cfabd --- /dev/null +++ b/tests/test_renderers.py @@ -0,0 +1,50 @@ +from PIL import Image +from rest_framework import serializers +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response + +from drf_excel.renderers import XLSXRenderer + + +class MySerializer(serializers.Serializer): + title = serializers.CharField() + + +class MyBaseView(GenericAPIView): + serializer_class = MySerializer + + def retrieve(self, request, *args, **kwargs): + return Response({"title": "example"}) + + +class TestXLSXRenderer: + renderer = XLSXRenderer() + + def test_validation_error(self): + assert self.renderer.render({"detail": "invalid"}) == '{"detail": "invalid"}' + + def test_none(self): + assert self.renderer.render(None) == b"" + + def test_with_header_attribute(self, tmp_path, workbook_reader): + image_path = tmp_path / "image.png" + with Image.new(mode="RGB", size=(100, 100), color="blue") as img: + img.save(image_path, format="png") + + class MyView(MyBaseView): + header = { + "use_header": True, + "header_title": "My Header", + "tab_title": "My Tab", + "img": str(image_path), + "style": {"font": {"name": "Arial"}}, + } + + result = self.renderer.render({}, renderer_context={"view": MyView}) + wb = workbook_reader(result) + sheet = wb.worksheets[0] + rows = list(sheet.rows) + assert len(rows) == 1 + row0_col0 = rows[0][0] + assert row0_col0.value == "My Header" + assert row0_col0.font.name == "Arial" diff --git a/tests/test_viewset_mixin.py b/tests/test_viewset_mixin.py index ad1553f..de8801a 100644 --- a/tests/test_viewset_mixin.py +++ b/tests/test_viewset_mixin.py @@ -1,14 +1,11 @@ -import io import datetime as dt import pytest -from openpyxl.reader.excel import load_workbook from rest_framework.test import APIClient from time_machine import TimeMachineFixture from tests.testapp.models import ExampleModel, AllFieldsModel, Tag - pytestmark = pytest.mark.django_db @@ -17,7 +14,7 @@ def api_client(): return APIClient() -def test_simple_viewset_model(api_client): +def test_simple_viewset_model(api_client, workbook_reader): ExampleModel.objects.create(title="test 1", description="This is a test") ExampleModel.objects.create(title="test 2", description="Another test") ExampleModel.objects.create(title="test 3", description="Testing this out") @@ -33,11 +30,10 @@ def test_simple_viewset_model(api_client): response.headers["content-disposition"] == "attachment; filename=my_export.xlsx" ) - workbook_buffer = io.BytesIO(response.content) - workbook = load_workbook(workbook_buffer, read_only=True) + wb = workbook_reader(response.content) - assert len(workbook.worksheets) == 1 - sheet = workbook.worksheets[0] + assert len(wb.worksheets) == 1 + sheet = wb.worksheets[0] rows = list(sheet.rows) assert len(rows) == 4 r0, r1, r2, r3 = rows @@ -59,7 +55,9 @@ def test_simple_viewset_model(api_client): assert r3[1].value == "Testing this out" -def test_all_fields_viewset(api_client, time_machine: TimeMachineFixture): +def test_all_fields_viewset( + api_client, time_machine: TimeMachineFixture, workbook_reader +): time_machine.move_to(dt.datetime(2023, 9, 10, 15, 44, 37)) instance = AllFieldsModel.objects.create(title="Hello", age=36, is_active=True) instance.tags.set( @@ -71,9 +69,8 @@ def test_all_fields_viewset(api_client, time_machine: TimeMachineFixture): response = api_client.get("/all-fields/") assert response.status_code == 200 - workbook_buffer = io.BytesIO(response.content) - workbook = load_workbook(workbook_buffer, read_only=True) - sheet = workbook.worksheets[0] + wb = workbook_reader(response.content) + sheet = wb.worksheets[0] rows = list(sheet.rows) assert len(rows) == 2 r0, r1 = rows diff --git a/tox.ini b/tox.ini index 03ea3db..22f288d 100644 --- a/tox.ini +++ b/tox.ini @@ -13,6 +13,7 @@ deps = djangorestframework openpyxl + Pillow pytest pytest-django