-
-
Notifications
You must be signed in to change notification settings - Fork 536
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
234 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import functools | ||
from typing import Any | ||
from typing import Callable | ||
from typing import cast | ||
from typing import Optional | ||
from typing import TypeVar | ||
|
||
from django.http import HttpRequest | ||
from django.http import HttpResponseBase | ||
|
||
from corsheaders.conf import conf as _conf | ||
from corsheaders.conf import Settings | ||
from corsheaders.middleware import CorsMiddleware | ||
|
||
F = TypeVar("F", bound=Callable[..., HttpResponseBase]) | ||
|
||
|
||
def cors(func: Optional[F] = None, *, conf: Settings = _conf) -> F | Callable[[F], F]: | ||
if func is None: | ||
return cast(Callable[[F], F], functools.partial(cors, conf=conf)) | ||
|
||
assert callable(func) | ||
|
||
if asyncio.iscoroutinefunction(func): | ||
|
||
async def inner( | ||
_request: HttpRequest, *args: Any, **kwargs: Any | ||
) -> HttpResponseBase: | ||
async def get_response(request: HttpRequest) -> HttpResponseBase: | ||
return await func(request, *args, **kwargs) | ||
|
||
return await CorsMiddleware(get_response, conf=conf)(_request) | ||
|
||
else: | ||
|
||
def inner(_request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase: | ||
def get_response(request: HttpRequest) -> HttpResponseBase: | ||
return func(request, *args, **kwargs) | ||
|
||
return CorsMiddleware(get_response, conf=conf)(_request) | ||
|
||
wrapper = functools.wraps(func)(inner) | ||
wrapper._skip_cors_middleware = True # type: ignore [attr-defined] | ||
return cast(F, wrapper) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from __future__ import annotations | ||
|
||
from django.test import TestCase | ||
from django.test.utils import modify_settings | ||
from django.test.utils import override_settings | ||
|
||
from corsheaders.middleware import ACCESS_CONTROL_ALLOW_ORIGIN | ||
|
||
|
||
@modify_settings( | ||
MIDDLEWARE={ | ||
"remove": "corsheaders.middleware.CorsMiddleware", | ||
} | ||
) | ||
@override_settings(CORS_ALLOWED_ORIGINS=["https://example.com"]) | ||
class CorsDecoratorsTestCase(TestCase): | ||
def test_get_no_origin(self): | ||
resp = self.client.get("/decorated/hello/") | ||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp | ||
assert resp.content == b"Decorated: hello" | ||
|
||
def test_get_not_in_allowed_origins(self): | ||
resp = self.client.get( | ||
"/decorated/hello/", | ||
HTTP_ORIGIN="https://example.net", | ||
) | ||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp | ||
assert resp.content == b"Decorated: hello" | ||
|
||
def test_get_in_allowed_origins_preflight(self): | ||
resp = self.client.options( | ||
"/decorated/hello/", | ||
HTTP_ORIGIN="https://example.com", | ||
HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET", | ||
) | ||
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" | ||
assert resp.content == b"" | ||
|
||
def test_get_in_allowed_origins(self): | ||
resp = self.client.get( | ||
"/decorated/hello/", | ||
HTTP_ORIGIN="https://example.com", | ||
) | ||
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" | ||
assert resp.content == b"Decorated: hello" | ||
|
||
async def test_async_get_not_in_allowed_origins(self): | ||
resp = await self.async_client.get( | ||
"/async-decorated/hello/", | ||
origin="https://example.org", | ||
) | ||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp | ||
assert resp.content == b"Async Decorated: hello" | ||
|
||
async def test_async_get_in_allowed_origins_preflight(self): | ||
resp = await self.async_client.options( | ||
"/async-decorated/hello/", | ||
origin="https://example.com", | ||
access_control_request_method="GET", | ||
) | ||
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" | ||
assert resp.content == b"" | ||
|
||
async def test_async_get_in_allowed_origins(self): | ||
resp = await self.async_client.get( | ||
"/async-decorated/hello/", | ||
origin="https://example.com", | ||
) | ||
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" | ||
assert resp.content == b"Async Decorated: hello" | ||
|
||
|
||
class CorsDecoratorsWithConfTestCase(TestCase): | ||
def test_get_no_origin(self): | ||
resp = self.client.get("/decorated-with-conf/hello/") | ||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp | ||
assert resp.content == b"Decorated (with conf): hello" | ||
|
||
def test_get_not_in_allowed_origins(self): | ||
resp = self.client.get( | ||
"/decorated-with-conf/hello/", HTTP_ORIGIN="https://example.net" | ||
) | ||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp | ||
assert resp.content == b"Decorated (with conf): hello" | ||
|
||
def test_get_in_allowed_origins_preflight(self): | ||
resp = self.client.options( | ||
"/decorated-with-conf/hello/", | ||
HTTP_ORIGIN="https://example.com", | ||
HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET", | ||
) | ||
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" | ||
assert resp.content == b"Decorated (with conf): hello" | ||
|
||
def test_get_in_allowed_origins(self): | ||
resp = self.client.get( | ||
"/decorated-with-conf/hello/", | ||
HTTP_ORIGIN="https://example.com", | ||
) | ||
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" | ||
assert resp.content == b"Decorated (with conf): hello" | ||
|
||
async def test_async_get_not_in_allowed_origins(self): | ||
resp = await self.async_client.get( | ||
"/async-decorated-with-conf/hello/", | ||
origin="https://example.org", | ||
) | ||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp | ||
assert resp.content == b"Async Decorated (with conf): hello" | ||
|
||
async def test_async_get_in_allowed_origins_preflight(self): | ||
resp = await self.async_client.options( | ||
"/async-decorated-with-conf/hello/", | ||
origin="https://example.com", | ||
access_control_request_method="GET", | ||
) | ||
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" | ||
assert resp.content == b"" | ||
|
||
async def test_async_get_in_allowed_origins(self): | ||
resp = await self.async_client.get( | ||
"/async-decorated-with-conf/hello/", | ||
origin="https://example.com", | ||
) | ||
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" | ||
assert resp.content == b"Async Decorated (with conf): hello" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters