Skip to content

Commit

Permalink
Add cors decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
jaap3 committed Oct 5, 2023
1 parent 75cc53c commit 3844504
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 14 deletions.
47 changes: 47 additions & 0 deletions src/corsheaders/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import asyncio
import functools
from typing import Any
from typing import Callable
from typing import cast
from typing import Optional
from typing import TypeVar

from django.http import HttpRequest
from django.http import HttpResponseBase

from corsheaders.conf import conf as _conf
from corsheaders.conf import Settings
from corsheaders.middleware import CorsMiddleware

F = TypeVar("F", bound=Callable[..., HttpResponseBase])


def cors(func: Optional[F] = None, *, conf: Settings = _conf) -> F | Callable[[F], F]:
if func is None:
return cast(Callable[[F], F], functools.partial(cors, conf=conf))

assert callable(func)

if asyncio.iscoroutinefunction(func):

async def inner(
_request: HttpRequest, *args: Any, **kwargs: Any
) -> HttpResponseBase:
async def get_response(request: HttpRequest) -> HttpResponseBase:
return await func(request, *args, **kwargs)

return await CorsMiddleware(get_response, conf=conf)(_request)

else:

def inner(_request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase:
def get_response(request: HttpRequest) -> HttpResponseBase:
return func(request, *args, **kwargs)

return CorsMiddleware(get_response, conf=conf)(_request)

wrapper = functools.wraps(func)(inner)
wrapper._skip_cors_middleware = True # type: ignore [attr-defined]
return cast(F, wrapper)
48 changes: 34 additions & 14 deletions src/corsheaders/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import re
from typing import Any
from typing import Awaitable
from typing import Callable
from urllib.parse import SplitResult
Expand Down Expand Up @@ -54,22 +55,40 @@ def __call__(
) -> HttpResponseBase | Awaitable[HttpResponseBase]:
if self._is_coroutine:
return self.__acall__(request)
response: HttpResponseBase | None = self.check_preflight(request)
if response is None:
result = self.get_response(request)
assert isinstance(result, HttpResponseBase)
response = result
self.add_response_headers(request, response)
return response
result = self.get_response(request)
assert isinstance(result, HttpResponseBase)
response = result
if getattr(request, "_cors_preflight_done", False):
return response
else:
# Request wasn't processed (e.g. because of a 404)
return self.add_response_headers(
request, self.check_preflight(request) or response
)

async def __acall__(self, request: HttpRequest) -> HttpResponseBase:
response = self.check_preflight(request)
if response is None:
result = self.get_response(request)
assert not isinstance(result, HttpResponseBase)
response = await result
self.add_response_headers(request, response)
return response
result = self.get_response(request)
assert not isinstance(result, HttpResponseBase)
response = await result
if getattr(response, "_cors_processing_done", False):
return response
else:
# View wasn't processed (e.g. because of a 404)
return self.add_response_headers(
request, self.check_preflight(request) or response
)

def process_view(
self,
request: HttpRequest,
callback: Callable[[HttpRequest], HttpResponseBase],
callback_args: Any,
callback_kwargs: Any,
) -> HttpResponseBase | None:
if getattr(callback, "_skip_cors_middleware", False):
# View is decorated and will add CORS headers itself
return None
return self.check_preflight(request)

def check_preflight(self, request: HttpRequest) -> HttpResponseBase | None:
"""
Expand All @@ -90,6 +109,7 @@ def add_response_headers(
"""
Add the respective CORS headers
"""
response._cors_processing_done = True
enabled = getattr(request, "_cors_enabled", None)
if enabled is None:
enabled = self.is_enabled(request)
Expand Down
126 changes: 126 additions & 0 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from __future__ import annotations

from django.test import TestCase
from django.test.utils import modify_settings
from django.test.utils import override_settings

from corsheaders.middleware import ACCESS_CONTROL_ALLOW_ORIGIN


@modify_settings(
MIDDLEWARE={
"remove": "corsheaders.middleware.CorsMiddleware",
}
)
@override_settings(CORS_ALLOWED_ORIGINS=["https://example.com"])
class CorsDecoratorsTestCase(TestCase):
def test_get_no_origin(self):
resp = self.client.get("/decorated/hello/")
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
assert resp.content == b"Decorated: hello"

def test_get_not_in_allowed_origins(self):
resp = self.client.get(
"/decorated/hello/",
HTTP_ORIGIN="https://example.net",
)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
assert resp.content == b"Decorated: hello"

def test_get_in_allowed_origins_preflight(self):
resp = self.client.options(
"/decorated/hello/",
HTTP_ORIGIN="https://example.com",
HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET",
)
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert resp.content == b""

def test_get_in_allowed_origins(self):
resp = self.client.get(
"/decorated/hello/",
HTTP_ORIGIN="https://example.com",
)
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert resp.content == b"Decorated: hello"

async def test_async_get_not_in_allowed_origins(self):
resp = await self.async_client.get(
"/async-decorated/hello/",
origin="https://example.org",
)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
assert resp.content == b"Async Decorated: hello"

async def test_async_get_in_allowed_origins_preflight(self):
resp = await self.async_client.options(
"/async-decorated/hello/",
origin="https://example.com",
access_control_request_method="GET",
)
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert resp.content == b""

async def test_async_get_in_allowed_origins(self):
resp = await self.async_client.get(
"/async-decorated/hello/",
origin="https://example.com",
)
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert resp.content == b"Async Decorated: hello"


class CorsDecoratorsWithConfTestCase(TestCase):
def test_get_no_origin(self):
resp = self.client.get("/decorated-with-conf/hello/")
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
assert resp.content == b"Decorated (with conf): hello"

def test_get_not_in_allowed_origins(self):
resp = self.client.get(
"/decorated-with-conf/hello/", HTTP_ORIGIN="https://example.net"
)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
assert resp.content == b"Decorated (with conf): hello"

def test_get_in_allowed_origins_preflight(self):
resp = self.client.options(
"/decorated-with-conf/hello/",
HTTP_ORIGIN="https://example.com",
HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET",
)
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert resp.content == b"Decorated (with conf): hello"

def test_get_in_allowed_origins(self):
resp = self.client.get(
"/decorated-with-conf/hello/",
HTTP_ORIGIN="https://example.com",
)
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert resp.content == b"Decorated (with conf): hello"

async def test_async_get_not_in_allowed_origins(self):
resp = await self.async_client.get(
"/async-decorated-with-conf/hello/",
origin="https://example.org",
)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
assert resp.content == b"Async Decorated (with conf): hello"

async def test_async_get_in_allowed_origins_preflight(self):
resp = await self.async_client.options(
"/async-decorated-with-conf/hello/",
origin="https://example.com",
access_control_request_method="GET",
)
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert resp.content == b""

async def test_async_get_in_allowed_origins(self):
resp = await self.async_client.get(
"/async-decorated-with-conf/hello/",
origin="https://example.com",
)
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert resp.content == b"Async Decorated (with conf): hello"
4 changes: 4 additions & 0 deletions tests/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
urlpatterns = [
path("", views.index),
path("async/", views.async_),
path("decorated/<slug:slug>/", views.decorated),
path("decorated-with-conf/<slug:slug>/", views.decorated_with_conf),
path("async-decorated/<slug:slug>/", views.async_decorated),
path("async-decorated-with-conf/<slug:slug>/", views.async_decorated_with_conf),
path("unauthorized/", views.unauthorized),
path("delete-enabled/", views.delete_enabled_attribute),
]
23 changes: 23 additions & 0 deletions tests/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from django.http import HttpResponse
from django.views.decorators.http import require_GET

from corsheaders.decorators import cors
from corsheaders.conf import Settings


@require_GET
def index(request):
Expand All @@ -15,6 +18,26 @@ async def async_(request):
return HttpResponse("Asynchronous")


@cors
def decorated(request, slug):
return HttpResponse(f"Decorated: {slug}")


@cors(conf=Settings(CORS_ALLOWED_ORIGINS=["https://example.com"]))
def decorated_with_conf(request, slug):
return HttpResponse(f"Decorated (with conf): {slug}")


@cors
async def async_decorated(request, slug):
return HttpResponse(f"Async Decorated: {slug}")


@cors(conf=Settings(CORS_ALLOWED_ORIGINS=["https://example.com"]))
async def async_decorated_with_conf(request, slug):
return HttpResponse(f"Async Decorated (with conf): {slug}")


def unauthorized(request):
return HttpResponse("Unauthorized", status=HTTPStatus.UNAUTHORIZED)

Expand Down

0 comments on commit 3844504

Please sign in to comment.