Skip to content

Commit

Permalink
feat: Support cookie authentication (#4662)
Browse files Browse the repository at this point in the history
Co-authored-by: Kyle Johnson <[email protected]>
  • Loading branch information
khvn26 and kyle-ssg authored Oct 3, 2024
1 parent d6c6004 commit e65c8da
Show file tree
Hide file tree
Showing 25 changed files with 452 additions and 67 deletions.
2 changes: 1 addition & 1 deletion api/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ POETRY_VERSION ?= 1.8.3

GUNICORN_LOGGER_CLASS ?= util.logging.GunicornJsonCapableLogger

SAML_REVISION ?= v1.6.3
SAML_REVISION ?= v1.6.4
RBAC_REVISION ?= v0.8.0

-include .env-local
Expand Down
45 changes: 30 additions & 15 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
"rest_framework.authtoken",
# Used for managing api keys
"rest_framework_api_key",
"rest_framework_simplejwt.token_blacklist",
"djoser",
"django.contrib.sites",
"custom_auth",
Expand Down Expand Up @@ -254,6 +255,7 @@
REST_FRAMEWORK = {
"DEFAULT_PERMISSION_CLASSES": ["rest_framework.permissions.IsAuthenticated"],
"DEFAULT_AUTHENTICATION_CLASSES": (
"custom_auth.jwt_cookie.authentication.JWTCookieAuthentication",
"rest_framework.authentication.TokenAuthentication",
"api_keys.authentication.MasterAPIKeyAuthentication",
),
Expand Down Expand Up @@ -416,19 +418,6 @@

MEDIA_URL = "/media/" # unused but needs to be different from STATIC_URL in django 3

# CORS settings

CORS_ORIGIN_ALLOW_ALL = True
FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS = env.list(
"FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS", default=["sentry-trace"]
)
CORS_ALLOW_HEADERS = [
*default_headers,
*FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS,
"X-Environment-Key",
"X-E2E-Test-Auth-Token",
]

DEFAULT_FROM_EMAIL = env("SENDER_EMAIL", default="[email protected]")
EMAIL_CONFIGURATION = {
# Invitations with name is anticipated to take two arguments. The persons name and the
Expand Down Expand Up @@ -826,6 +815,16 @@
"user_create": USER_CREATE_PERMISSIONS,
},
}
SIMPLE_JWT = {
"AUTH_TOKEN_CLASSES": ["rest_framework_simplejwt.tokens.SlidingToken"],
"SLIDING_TOKEN_LIFETIME": timedelta(
minutes=env.int(
"COOKIE_AUTH_JWT_ACCESS_TOKEN_LIFETIME_MINUTES",
default=10 * 60,
)
),
"SIGNING_KEY": env.str("COOKIE_AUTH_JWT_SIGNING_KEY", default=SECRET_KEY),
}

# Github OAuth credentials
GITHUB_CLIENT_ID = env.str("GITHUB_CLIENT_ID", default="")
Expand Down Expand Up @@ -907,8 +906,6 @@
SENTRY_API_KEY = env("SENTRY_API_KEY", default=None)
AMPLITUDE_API_KEY = env("AMPLITUDE_API_KEY", default=None)
ENABLE_FLAGSMITH_REALTIME = env.bool("ENABLE_FLAGSMITH_REALTIME", default=False)
USE_SECURE_COOKIES = env.bool("USE_SECURE_COOKIES", default=True)
COOKIE_SAME_SITE = env.str("COOKIE_SAME_SITE", default="none")

# Set this to enable create organisation for only superusers
RESTRICT_ORG_CREATE_TO_SUPERUSERS = env.bool("RESTRICT_ORG_CREATE_TO_SUPERUSERS", False)
Expand Down Expand Up @@ -1038,6 +1035,24 @@

DISABLE_INVITE_LINKS = env.bool("DISABLE_INVITE_LINKS", False)
PREVENT_SIGNUP = env.bool("PREVENT_SIGNUP", default=False)
COOKIE_AUTH_ENABLED = env.bool("COOKIE_AUTH_ENABLED", default=False)
USE_SECURE_COOKIES = env.bool("USE_SECURE_COOKIES", default=True)
COOKIE_SAME_SITE = env.str("COOKIE_SAME_SITE", default="none")

# CORS settings

CORS_ORIGIN_ALLOW_ALL = env.bool("CORS_ORIGIN_ALLOW_ALL", not COOKIE_AUTH_ENABLED)
CORS_ALLOW_CREDENTIALS = env.bool("CORS_ALLOW_CREDENTIALS", COOKIE_AUTH_ENABLED)
FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS = env.list(
"FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS", default=["sentry-trace"]
)
CORS_ALLOWED_ORIGINS = env.list("CORS_ALLOWED_ORIGINS", default=[])
CORS_ALLOW_HEADERS = [
*default_headers,
*FLAGSMITH_CORS_EXTRA_ALLOW_HEADERS,
"X-Environment-Key",
"X-E2E-Test-Auth-Token",
]

# use a separate boolean setting so that we add it to the API containers in environments
# where we're running the task processor, so we avoid creating unnecessary tasks
Expand Down
9 changes: 5 additions & 4 deletions api/app/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@ def project_overrides(request):
"amplitude": "AMPLITUDE_API_KEY",
"api": "API_URL",
"assetURL": "ASSET_URL",
"cookieAuthEnabled": "COOKIE_AUTH_ENABLED",
"cookieSameSite": "COOKIE_SAME_SITE",
"crispChat": "CRISP_CHAT_API_KEY",
"disableAnalytics": "DISABLE_ANALYTICS_FEATURES",
"flagsmith": "FLAGSMITH_ON_FLAGSMITH_API_KEY",
"flagsmithAnalytics": "FLAGSMITH_ANALYTICS",
"flagsmithRealtime": "ENABLE_FLAGSMITH_REALTIME",
"flagsmithClientAPI": "FLAGSMITH_ON_FLAGSMITH_API_URL",
"ga": "GOOGLE_ANALYTICS_API_KEY",
"flagsmithRealtime": "ENABLE_FLAGSMITH_REALTIME",
"fpr": "FIRST_PROMOTER_ID",
"ga": "GOOGLE_ANALYTICS_API_KEY",
"githubAppURL": "GITHUB_APP_URL",
"headway": "HEADWAY_API_KEY",
"hideInviteLinks": "DISABLE_INVITE_LINKS",
"linkedinPartnerTracking": "LINKEDIN_PARTNER_TRACKING",
Expand All @@ -54,8 +57,6 @@ def project_overrides(request):
"preventSignup": "PREVENT_SIGNUP",
"sentry": "SENTRY_API_KEY",
"useSecureCookies": "USE_SECURE_COOKIES",
"cookieSameSite": "COOKIE_SAME_SITE",
"githubAppURL": "GITHUB_APP_URL",
}

override_data = {
Expand Down
20 changes: 13 additions & 7 deletions api/core/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re

from django.conf import settings
from django.contrib.sites.models import Site
from django.contrib.sites import models as sites_models
from django.http import HttpRequest
from rest_framework.request import Request

Expand All @@ -11,12 +11,18 @@


def get_current_site_url(request: HttpRequest | Request | None = None) -> str:
if settings.DOMAIN_OVERRIDE:
domain = settings.DOMAIN_OVERRIDE
elif current_site := Site.objects.filter(id=settings.SITE_ID).first():
domain = current_site.domain
else:
domain = settings.DEFAULT_DOMAIN
if not (domain := settings.DOMAIN_OVERRIDE):
try:
domain = sites_models.Site.objects.get_current(request).domain
except sites_models.Site.DoesNotExist:
# For the rare case when `DOMAIN_OVERRIDE` was not set and no `Site` object present,
# store a default domain `Site` in the sites cache
# so it's correctly invalidated should the user decide to create own `Site` object.
domain = settings.DEFAULT_DOMAIN
sites_models.SITE_CACHE[settings.SITE_ID] = sites_models.Site(
name="Flagsmith",
domain=domain,
)

if request:
scheme = request.scheme
Expand Down
1 change: 1 addition & 0 deletions api/custom_auth/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ class CustomAuthAppConfig(AppConfig):

def ready(self) -> None:
from custom_auth import tasks # noqa F401
from custom_auth.jwt_cookie import signals # noqa F401
Empty file.
17 changes: 17 additions & 0 deletions api/custom_auth/jwt_cookie/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from rest_framework.request import Request
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.tokens import Token

from custom_auth.jwt_cookie.constants import JWT_SLIDING_COOKIE_KEY
from users.models import FFAdminUser


class JWTCookieAuthentication(JWTAuthentication):
def authenticate_header(self, request: Request) -> str:
return f'Cookie realm="{self.www_authenticate_realm}"'

def authenticate(self, request: Request) -> tuple[FFAdminUser, Token] | None:
if raw_token := request.COOKIES.get(JWT_SLIDING_COOKIE_KEY):
validated_token = self.get_validated_token(raw_token)
return self.get_user(validated_token), validated_token
return None
1 change: 1 addition & 0 deletions api/custom_auth/jwt_cookie/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
JWT_SLIDING_COOKIE_KEY = "jwt"
18 changes: 18 additions & 0 deletions api/custom_auth/jwt_cookie/services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from django.conf import settings
from rest_framework.response import Response
from rest_framework_simplejwt.tokens import SlidingToken

from custom_auth.jwt_cookie.constants import JWT_SLIDING_COOKIE_KEY
from users.models import FFAdminUser


def authorise_response(user: FFAdminUser, response: Response) -> Response:
sliding_token = SlidingToken.for_user(user)
response.set_cookie(
JWT_SLIDING_COOKIE_KEY,
str(sliding_token),
httponly=True,
secure=settings.USE_SECURE_COOKIES,
samesite=settings.COOKIE_SAME_SITE,
)
return response
20 changes: 20 additions & 0 deletions api/custom_auth/jwt_cookie/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any
from urllib.parse import urlparse

from core.helpers import get_current_site_url
from corsheaders.signals import check_request_enabled
from django.dispatch import receiver
from django.http import HttpRequest


@receiver(check_request_enabled)
def cors_allow_current_site(request: HttpRequest, **kwargs: Any) -> bool:
# The signal is expected to only be dispatched:
# - When `settings.CORS_ORIGIN_ALLOW_ALL` is set to `False`.
# - For requests with `HTTP_ORIGIN` set.
origin_url = urlparse(request.META["HTTP_ORIGIN"])
current_site_url = urlparse(get_current_site_url(request))
return (
origin_url.scheme == current_site_url.scheme
and origin_url.netloc == current_site_url.netloc
)
15 changes: 15 additions & 0 deletions api/custom_auth/jwt_cookie/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from djoser.views import TokenDestroyView
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework_simplejwt.tokens import SlidingToken

from custom_auth.jwt_cookie.constants import JWT_SLIDING_COOKIE_KEY


class JWTSlidingTokenLogoutView(TokenDestroyView):
def post(self, request: Request) -> Response:
response = super().post(request)
if isinstance(jwt_token := request.auth, SlidingToken):
jwt_token.blacklist()
response.delete_cookie(JWT_SLIDING_COOKIE_KEY)
return response
16 changes: 13 additions & 3 deletions api/custom_auth/serializers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from django.conf import settings
from djoser.conf import settings as djoser_settings
from djoser.serializers import TokenCreateSerializer, UserCreateSerializer
Expand Down Expand Up @@ -73,13 +75,15 @@ def _validate_registration_invite(self, email: str, sign_up_type: str) -> None:


class CustomUserCreateSerializer(UserCreateSerializer, InviteLinkValidationMixin):
key = serializers.SerializerMethodField()
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
if not settings.COOKIE_AUTH_ENABLED:
self.fields["key"] = serializers.SerializerMethodField()

class Meta(UserCreateSerializer.Meta):
fields = UserCreateSerializer.Meta.fields + (
"is_active",
"marketing_consent_given",
"key",
"uuid",
)
read_only_fields = ("is_active", "uuid")
Expand Down Expand Up @@ -115,8 +119,14 @@ def validate(self, attrs):
attrs["email"] = email.lower()
return attrs

def save(self) -> FFAdminUser:
instance = super().save()
if "view" in self.context:
self.context["view"].user = instance
return instance

@staticmethod
def get_key(instance):
def get_key(instance) -> str:
token, _ = Token.objects.get_or_create(user=instance)
return token.key

Expand Down
8 changes: 6 additions & 2 deletions api/custom_auth/urls.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django.urls import include, path
from djoser.views import TokenDestroyView
from rest_framework.routers import DefaultRouter

from custom_auth.jwt_cookie.views import JWTSlidingTokenLogoutView
from custom_auth.views import (
CustomAuthTokenLoginOrRequestMFACode,
CustomAuthTokenLoginWithMFACode,
Expand All @@ -26,7 +26,11 @@
CustomAuthTokenLoginWithMFACode.as_view(),
name="mfa-authtoken-login-code",
),
path("logout/", TokenDestroyView.as_view(), name="authtoken-logout"),
path(
"logout/",
JWTSlidingTokenLogoutView.as_view(),
name="jwt-logout",
),
path("", include(ffadmin_user_router.urls)),
path("token/", delete_token, name="delete-token"),
# NOTE: endpoints provided by `djoser.urls`
Expand Down
21 changes: 21 additions & 0 deletions api/custom_auth/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from django.conf import settings
from django.contrib.auth import user_logged_out
from django.utils.decorators import method_decorator
Expand All @@ -9,8 +11,10 @@
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.status import HTTP_204_NO_CONTENT
from rest_framework.throttling import ScopedRateThrottle

from custom_auth.jwt_cookie.services import authorise_response
from custom_auth.mfa.backends.application import CustomApplicationBackend
from custom_auth.mfa.trench.command.authenticate_second_factor import (
authenticate_second_step_command,
Expand All @@ -34,6 +38,7 @@ class CustomAuthTokenLoginOrRequestMFACode(TokenCreateView):
Class to handle throttling for login requests
"""

authentication_classes = []
throttle_classes = [ScopedRateThrottle]
throttle_scope = "login"

Expand All @@ -54,6 +59,8 @@ def post(self, request: Request) -> Response:
}
)
except MFAMethodDoesNotExistError:
if settings.COOKIE_AUTH_ENABLED:
return authorise_response(user, Response(status=HTTP_204_NO_CONTENT))
return self._action(serializer)


Expand All @@ -62,6 +69,7 @@ class CustomAuthTokenLoginWithMFACode(TokenCreateView):
Override class to add throttling
"""

authentication_classes = []
throttle_classes = [ScopedRateThrottle]
throttle_scope = "mfa_code"

Expand All @@ -74,6 +82,8 @@ def post(self, request: Request) -> Response:
ephemeral_token=serializer.validated_data["ephemeral_token"],
)
serializer.user = user
if settings.COOKIE_AUTH_ENABLED:
return authorise_response(user, Response(status=HTTP_204_NO_CONTENT))
return self._action(serializer)
except MFAValidationError as cause:
return ErrorResponse(error=cause, status=status.HTTP_401_UNAUTHORIZED)
Expand All @@ -96,6 +106,11 @@ def delete_token(request):
class FFAdminUserViewSet(UserViewSet):
throttle_scope = "signup"

def perform_authentication(self, request: Request) -> None:
if self.action == "create":
return
return super().perform_authentication(request)

def get_throttles(self):
"""
Used for throttling create(signup) action
Expand All @@ -105,6 +120,12 @@ def get_throttles(self):
throttles = [ScopedRateThrottle()]
return throttles

def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
response = super().create(request, *args, **kwargs)
if settings.COOKIE_AUTH_ENABLED:
authorise_response(self.user, response)
return response

def perform_destroy(self, instance):
instance.delete(
delete_orphan_organisations=self.request.data.get(
Expand Down
Loading

0 comments on commit e65c8da

Please sign in to comment.