From 38445041d080a79772421ce9ff454a746de42673 Mon Sep 17 00:00:00 2001 From: Jaap Roes Date: Thu, 5 Oct 2023 12:59:05 +0200 Subject: [PATCH] Add cors decorator --- src/corsheaders/decorators.py | 47 +++++++++++++ src/corsheaders/middleware.py | 48 +++++++++---- tests/test_decorators.py | 126 ++++++++++++++++++++++++++++++++++ tests/urls.py | 4 ++ tests/views.py | 23 +++++++ 5 files changed, 234 insertions(+), 14 deletions(-) create mode 100644 src/corsheaders/decorators.py create mode 100644 tests/test_decorators.py diff --git a/src/corsheaders/decorators.py b/src/corsheaders/decorators.py new file mode 100644 index 00000000..65ec2b43 --- /dev/null +++ b/src/corsheaders/decorators.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import asyncio +import functools +from typing import Any +from typing import Callable +from typing import cast +from typing import Optional +from typing import TypeVar + +from django.http import HttpRequest +from django.http import HttpResponseBase + +from corsheaders.conf import conf as _conf +from corsheaders.conf import Settings +from corsheaders.middleware import CorsMiddleware + +F = TypeVar("F", bound=Callable[..., HttpResponseBase]) + + +def cors(func: Optional[F] = None, *, conf: Settings = _conf) -> F | Callable[[F], F]: + if func is None: + return cast(Callable[[F], F], functools.partial(cors, conf=conf)) + + assert callable(func) + + if asyncio.iscoroutinefunction(func): + + async def inner( + _request: HttpRequest, *args: Any, **kwargs: Any + ) -> HttpResponseBase: + async def get_response(request: HttpRequest) -> HttpResponseBase: + return await func(request, *args, **kwargs) + + return await CorsMiddleware(get_response, conf=conf)(_request) + + else: + + def inner(_request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase: + def get_response(request: HttpRequest) -> HttpResponseBase: + return func(request, *args, **kwargs) + + return CorsMiddleware(get_response, conf=conf)(_request) + + wrapper = functools.wraps(func)(inner) + wrapper._skip_cors_middleware = True # type: ignore [attr-defined] + return cast(F, wrapper) diff --git a/src/corsheaders/middleware.py b/src/corsheaders/middleware.py index 500a5096..f307e04d 100644 --- a/src/corsheaders/middleware.py +++ b/src/corsheaders/middleware.py @@ -2,6 +2,7 @@ import asyncio import re +from typing import Any from typing import Awaitable from typing import Callable from urllib.parse import SplitResult @@ -54,22 +55,40 @@ def __call__( ) -> HttpResponseBase | Awaitable[HttpResponseBase]: if self._is_coroutine: return self.__acall__(request) - response: HttpResponseBase | None = self.check_preflight(request) - if response is None: - result = self.get_response(request) - assert isinstance(result, HttpResponseBase) - response = result - self.add_response_headers(request, response) - return response + result = self.get_response(request) + assert isinstance(result, HttpResponseBase) + response = result + if getattr(request, "_cors_preflight_done", False): + return response + else: + # Request wasn't processed (e.g. because of a 404) + return self.add_response_headers( + request, self.check_preflight(request) or response + ) async def __acall__(self, request: HttpRequest) -> HttpResponseBase: - response = self.check_preflight(request) - if response is None: - result = self.get_response(request) - assert not isinstance(result, HttpResponseBase) - response = await result - self.add_response_headers(request, response) - return response + result = self.get_response(request) + assert not isinstance(result, HttpResponseBase) + response = await result + if getattr(response, "_cors_processing_done", False): + return response + else: + # View wasn't processed (e.g. because of a 404) + return self.add_response_headers( + request, self.check_preflight(request) or response + ) + + def process_view( + self, + request: HttpRequest, + callback: Callable[[HttpRequest], HttpResponseBase], + callback_args: Any, + callback_kwargs: Any, + ) -> HttpResponseBase | None: + if getattr(callback, "_skip_cors_middleware", False): + # View is decorated and will add CORS headers itself + return None + return self.check_preflight(request) def check_preflight(self, request: HttpRequest) -> HttpResponseBase | None: """ @@ -90,6 +109,7 @@ def add_response_headers( """ Add the respective CORS headers """ + response._cors_processing_done = True enabled = getattr(request, "_cors_enabled", None) if enabled is None: enabled = self.is_enabled(request) diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 00000000..04c44cf7 --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from django.test import TestCase +from django.test.utils import modify_settings +from django.test.utils import override_settings + +from corsheaders.middleware import ACCESS_CONTROL_ALLOW_ORIGIN + + +@modify_settings( + MIDDLEWARE={ + "remove": "corsheaders.middleware.CorsMiddleware", + } +) +@override_settings(CORS_ALLOWED_ORIGINS=["https://example.com"]) +class CorsDecoratorsTestCase(TestCase): + def test_get_no_origin(self): + resp = self.client.get("/decorated/hello/") + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Decorated: hello" + + def test_get_not_in_allowed_origins(self): + resp = self.client.get( + "/decorated/hello/", + HTTP_ORIGIN="https://example.net", + ) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Decorated: hello" + + def test_get_in_allowed_origins_preflight(self): + resp = self.client.options( + "/decorated/hello/", + HTTP_ORIGIN="https://example.com", + HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"" + + def test_get_in_allowed_origins(self): + resp = self.client.get( + "/decorated/hello/", + HTTP_ORIGIN="https://example.com", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"Decorated: hello" + + async def test_async_get_not_in_allowed_origins(self): + resp = await self.async_client.get( + "/async-decorated/hello/", + origin="https://example.org", + ) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Async Decorated: hello" + + async def test_async_get_in_allowed_origins_preflight(self): + resp = await self.async_client.options( + "/async-decorated/hello/", + origin="https://example.com", + access_control_request_method="GET", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"" + + async def test_async_get_in_allowed_origins(self): + resp = await self.async_client.get( + "/async-decorated/hello/", + origin="https://example.com", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"Async Decorated: hello" + + +class CorsDecoratorsWithConfTestCase(TestCase): + def test_get_no_origin(self): + resp = self.client.get("/decorated-with-conf/hello/") + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Decorated (with conf): hello" + + def test_get_not_in_allowed_origins(self): + resp = self.client.get( + "/decorated-with-conf/hello/", HTTP_ORIGIN="https://example.net" + ) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Decorated (with conf): hello" + + def test_get_in_allowed_origins_preflight(self): + resp = self.client.options( + "/decorated-with-conf/hello/", + HTTP_ORIGIN="https://example.com", + HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"Decorated (with conf): hello" + + def test_get_in_allowed_origins(self): + resp = self.client.get( + "/decorated-with-conf/hello/", + HTTP_ORIGIN="https://example.com", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"Decorated (with conf): hello" + + async def test_async_get_not_in_allowed_origins(self): + resp = await self.async_client.get( + "/async-decorated-with-conf/hello/", + origin="https://example.org", + ) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Async Decorated (with conf): hello" + + async def test_async_get_in_allowed_origins_preflight(self): + resp = await self.async_client.options( + "/async-decorated-with-conf/hello/", + origin="https://example.com", + access_control_request_method="GET", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"" + + async def test_async_get_in_allowed_origins(self): + resp = await self.async_client.get( + "/async-decorated-with-conf/hello/", + origin="https://example.com", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"Async Decorated (with conf): hello" diff --git a/tests/urls.py b/tests/urls.py index a790fb1d..c3f64407 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -7,6 +7,10 @@ urlpatterns = [ path("", views.index), path("async/", views.async_), + path("decorated//", views.decorated), + path("decorated-with-conf//", views.decorated_with_conf), + path("async-decorated//", views.async_decorated), + path("async-decorated-with-conf//", views.async_decorated_with_conf), path("unauthorized/", views.unauthorized), path("delete-enabled/", views.delete_enabled_attribute), ] diff --git a/tests/views.py b/tests/views.py index 06e257a5..afd5ccf4 100644 --- a/tests/views.py +++ b/tests/views.py @@ -5,6 +5,9 @@ from django.http import HttpResponse from django.views.decorators.http import require_GET +from corsheaders.decorators import cors +from corsheaders.conf import Settings + @require_GET def index(request): @@ -15,6 +18,26 @@ async def async_(request): return HttpResponse("Asynchronous") +@cors +def decorated(request, slug): + return HttpResponse(f"Decorated: {slug}") + + +@cors(conf=Settings(CORS_ALLOWED_ORIGINS=["https://example.com"])) +def decorated_with_conf(request, slug): + return HttpResponse(f"Decorated (with conf): {slug}") + + +@cors +async def async_decorated(request, slug): + return HttpResponse(f"Async Decorated: {slug}") + + +@cors(conf=Settings(CORS_ALLOWED_ORIGINS=["https://example.com"])) +async def async_decorated_with_conf(request, slug): + return HttpResponse(f"Async Decorated (with conf): {slug}") + + def unauthorized(request): return HttpResponse("Unauthorized", status=HTTPStatus.UNAUTHORIZED)