diff --git a/drf_excel/renderers.py b/drf_excel/renderers.py index 87fe774..5aa0e9e 100644 --- a/drf_excel/renderers.py +++ b/drf_excel/renderers.py @@ -56,12 +56,12 @@ def render(self, data, accepted_media_type=None, renderer_context=None): """ Render `data` into XLSX workbook, returning a workbook. """ - if not self._check_validation_data(data): - return json.dumps(data) - if data is None: return bytes() + if not self._check_validation_data(data): + return json.dumps(data) + wb = Workbook() self.ws = wb.active diff --git a/tests/conftest.py b/tests/conftest.py index e641348..79cb2f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,8 @@ +import io +from typing import Union, Callable + import pytest +from openpyxl.reader.excel import load_workbook from openpyxl.workbook import Workbook from openpyxl.worksheet.worksheet import Worksheet @@ -11,3 +15,12 @@ def workbook() -> Workbook: @pytest.fixture def worksheet(workbook: Workbook) -> Worksheet: return Worksheet(workbook) + + +@pytest.fixture +def workbook_reader() -> Callable[[Union[bytes, str]], Workbook]: + def reader_func(buffer: Union[bytes, str]) -> Workbook: + io_buffer = io.BytesIO(buffer) + return load_workbook(io_buffer, read_only=True) + + return reader_func diff --git a/tests/test_renderers.py b/tests/test_renderers.py new file mode 100644 index 0000000..61cfabd --- /dev/null +++ b/tests/test_renderers.py @@ -0,0 +1,50 @@ +from PIL import Image +from rest_framework import serializers +from rest_framework.generics import GenericAPIView +from rest_framework.response import Response + +from drf_excel.renderers import XLSXRenderer + + +class MySerializer(serializers.Serializer): + title = serializers.CharField() + + +class MyBaseView(GenericAPIView): + serializer_class = MySerializer + + def retrieve(self, request, *args, **kwargs): + return Response({"title": "example"}) + + +class TestXLSXRenderer: + renderer = XLSXRenderer() + + def test_validation_error(self): + assert self.renderer.render({"detail": "invalid"}) == '{"detail": "invalid"}' + + def test_none(self): + assert self.renderer.render(None) == b"" + + def test_with_header_attribute(self, tmp_path, workbook_reader): + image_path = tmp_path / "image.png" + with Image.new(mode="RGB", size=(100, 100), color="blue") as img: + img.save(image_path, format="png") + + class MyView(MyBaseView): + header = { + "use_header": True, + "header_title": "My Header", + "tab_title": "My Tab", + "img": str(image_path), + "style": {"font": {"name": "Arial"}}, + } + + result = self.renderer.render({}, renderer_context={"view": MyView}) + wb = workbook_reader(result) + sheet = wb.worksheets[0] + rows = list(sheet.rows) + assert len(rows) == 1 + row0_col0 = rows[0][0] + assert row0_col0.value == "My Header" + assert row0_col0.font.name == "Arial" diff --git a/tests/test_viewset_mixin.py b/tests/test_viewset_mixin.py index bc15090..de8801a 100644 --- a/tests/test_viewset_mixin.py +++ b/tests/test_viewset_mixin.py @@ -1,10 +1,12 @@ -import io +import datetime as dt import pytest -from openpyxl.reader.excel import load_workbook from rest_framework.test import APIClient +from time_machine import TimeMachineFixture -from tests.testapp.models import ExampleModel +from tests.testapp.models import ExampleModel, AllFieldsModel, Tag + +pytestmark = pytest.mark.django_db @pytest.fixture @@ -12,8 +14,7 @@ def api_client(): return APIClient() -@pytest.mark.django_db -def test_simple_viewset_model(api_client): +def test_simple_viewset_model(api_client, workbook_reader): ExampleModel.objects.create(title="test 1", description="This is a test") ExampleModel.objects.create(title="test 2", description="Another test") ExampleModel.objects.create(title="test 3", description="Testing this out") @@ -29,11 +30,10 @@ def test_simple_viewset_model(api_client): response.headers["content-disposition"] == "attachment; filename=my_export.xlsx" ) - workbook_buffer = io.BytesIO(response.content) - workbook = load_workbook(workbook_buffer, read_only=True) + wb = workbook_reader(response.content) - assert len(workbook.worksheets) == 1 - sheet = workbook.worksheets[0] + assert len(wb.worksheets) == 1 + sheet = wb.worksheets[0] rows = list(sheet.rows) assert len(rows) == 4 r0, r1, r2, r3 = rows @@ -53,3 +53,43 @@ def test_simple_viewset_model(api_client): assert len(r3) == 2 assert r3[0].value == "test 3" assert r3[1].value == "Testing this out" + + +def test_all_fields_viewset( + api_client, time_machine: TimeMachineFixture, workbook_reader +): + time_machine.move_to(dt.datetime(2023, 9, 10, 15, 44, 37)) + instance = AllFieldsModel.objects.create(title="Hello", age=36, is_active=True) + instance.tags.set( + [ + Tag.objects.create(name="test"), + Tag.objects.create(name="example"), + ] + ) + response = api_client.get("/all-fields/") + assert response.status_code == 200 + + wb = workbook_reader(response.content) + sheet = wb.worksheets[0] + rows = list(sheet.rows) + assert len(rows) == 2 + r0, r1 = rows + + assert [col.value for col in r0] == [ + "title", + "created_at", + "updated_date", + "updated_time", + "age", + "is_active", + "tags", + ] + assert [col.value for col in r1] == [ + "Hello", + dt.datetime(2023, 9, 10, 15, 44, 37), + dt.datetime(2023, 9, 10, 0, 0), + dt.time(15, 44, 37), + 36, + True, + "test, example", + ] diff --git a/tests/testapp/models.py b/tests/testapp/models.py index ed3a96f..22fb3d2 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -7,3 +7,26 @@ class ExampleModel(models.Model): def __str__(self): return self.title + + +class Tag(models.Model): + name = models.CharField(max_length=100) + + def __str__(self): + return self.name + + +class AllFieldsModel(models.Model): + title = models.CharField(max_length=100) + created_at = models.DateTimeField(auto_now_add=True) + updated_date = models.DateField(auto_now=True) + updated_time = models.TimeField(auto_now=True) + age = models.IntegerField() + is_active = models.BooleanField(default=True) + tags = models.ManyToManyField(Tag, related_name="all_fields") + + def __str__(self): + return self.title + + def get_tag_names(self): + return [tag.name for tag in self.tags.all()] diff --git a/tests/testapp/serializers.py b/tests/testapp/serializers.py index ceb261c..a138c7e 100644 --- a/tests/testapp/serializers.py +++ b/tests/testapp/serializers.py @@ -1,9 +1,25 @@ from rest_framework import serializers -from .models import ExampleModel +from .models import ExampleModel, AllFieldsModel class ExampleSerializer(serializers.ModelSerializer): class Meta: model = ExampleModel fields = ("title", "description") + + +class AllFieldsSerializer(serializers.ModelSerializer): + tags = serializers.ListField(source="get_tag_names") + + class Meta: + model = AllFieldsModel + fields = ( + "title", + "created_at", + "updated_date", + "updated_time", + "age", + "is_active", + "tags", + ) diff --git a/tests/testapp/views.py b/tests/testapp/views.py index 87eea78..33f5fc7 100644 --- a/tests/testapp/views.py +++ b/tests/testapp/views.py @@ -2,8 +2,8 @@ from drf_excel.mixins import XLSXFileMixin from drf_excel.renderers import XLSXRenderer -from .models import ExampleModel -from .serializers import ExampleSerializer +from .models import ExampleModel, AllFieldsModel +from .serializers import ExampleSerializer, AllFieldsSerializer class ExampleViewSet(XLSXFileMixin, ReadOnlyModelViewSet): @@ -11,3 +11,10 @@ class ExampleViewSet(XLSXFileMixin, ReadOnlyModelViewSet): serializer_class = ExampleSerializer renderer_classes = (XLSXRenderer,) filename = "my_export.xlsx" + + +class AllFieldsViewSet(XLSXFileMixin, ReadOnlyModelViewSet): + queryset = AllFieldsModel.objects.all() + serializer_class = AllFieldsSerializer + renderer_classes = (XLSXRenderer,) + filename = "al_fileds.xlsx" diff --git a/tests/urls.py b/tests/urls.py index feedffb..d2a8ebb 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,8 +1,9 @@ from rest_framework import routers -from .testapp.views import ExampleViewSet +from .testapp.views import ExampleViewSet, AllFieldsViewSet router = routers.SimpleRouter() router.register(r"examples", ExampleViewSet) +router.register(r"all-fields", AllFieldsViewSet) urlpatterns = router.urls diff --git a/tox.ini b/tox.ini index 4dba901..22f288d 100644 --- a/tox.ini +++ b/tox.ini @@ -13,10 +13,12 @@ deps = djangorestframework openpyxl + Pillow pytest pytest-django pytest-cov django-coverage-plugin + time-machine commands = {posargs:python -m pytest}