diff --git a/readme.md b/readme.md index 2ff65399ac6a..59fb364df150 100644 --- a/readme.md +++ b/readme.md @@ -181,6 +181,7 @@ You can also provide individual variables as below. Note that if a `DATABASE_URL * `GA_TABLE_ID`: GA table ID (view) to query when looking for organisation usage * `ALLOWED_ADMIN_IP_ADDRESSES`: restrict access to the django admin console to a comma separated list of IP addresses (e.g. `127.0.0.1,127.0.0.2`) * `USER_CREATE_PERMISSIONS`: set the permissions for creating new users, using a comma separated list of djoser or rest_framework permissions. Use this to turn off public user creation for self hosting. e.g. `'djoser.permissions.CurrentUserOrAdmin'` Defaults to `'rest_framework.permissions.AllowAny'`. +* `ALLOW_REGISTRATION_WITHOUT_INVITE`: Determines whether users can register without an invite. Defaults to True. Set to False or 0 to disable. Note that if disabled, new users must be invited via email. * `ENABLE_EMAIL_ACTIVATION`: new user registration will go via email activation flow, default False * `SENTRY_SDK_DSN`: If using Sentry, set the project DSN here. * `SENTRY_TRACE_SAMPLE_RATE`: Float. If using Sentry, sets the trace sample rate. Defaults to 1.0. diff --git a/src/app/settings/common.py b/src/app/settings/common.py index 7b7a4ab13b47..c21538901174 100644 --- a/src/app/settings/common.py +++ b/src/app/settings/common.py @@ -422,6 +422,11 @@ GITHUB_CLIENT_ID = env.str("GITHUB_CLIENT_ID", default="") GITHUB_CLIENT_SECRET = env.str("GITHUB_CLIENT_SECRET", default="") +# Allow the configuration of registration via OAuth +ALLOW_REGISTRATION_WITHOUT_INVITE = env.bool( + "ALLOW_REGISTRATION_WITHOUT_INVITE", default=True +) + # Django Axes settings ENABLE_AXES = env.bool("ENABLE_AXES", default=False) if ENABLE_AXES: diff --git a/src/custom_auth/constants.py b/src/custom_auth/constants.py new file mode 100644 index 000000000000..2dc2110c75c2 --- /dev/null +++ b/src/custom_auth/constants.py @@ -0,0 +1,3 @@ +USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE = ( + "User registration without an invite is disabled for this installation." +) diff --git a/src/custom_auth/oauth/serializers.py b/src/custom_auth/oauth/serializers.py index 341e56f14398..1b317f7de96b 100644 --- a/src/custom_auth/oauth/serializers.py +++ b/src/custom_auth/oauth/serializers.py @@ -1,9 +1,14 @@ +from django.conf import settings from django.contrib.auth import get_user_model from rest_framework import serializers from rest_framework.authtoken.models import Token +from rest_framework.exceptions import PermissionDenied -from custom_auth.oauth.github import GithubUser -from custom_auth.oauth.google import get_user_info +from organisations.invites.models import Invite + +from ..constants import USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE +from .github import GithubUser +from .google import get_user_info GOOGLE_URL = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json&" UserModel = get_user_model() @@ -19,11 +24,25 @@ class Meta: abstract = True def create(self, validated_data): - user_data = self.get_user_info() - email = user_data.pop("email") - user, _ = UserModel.objects.get_or_create(email=email, defaults=user_data) + user = self._get_user() return Token.objects.get_or_create(user=user)[0] + def _get_user(self): + user_data = self.get_user_info() + email = user_data.get("email") + existing_user = UserModel.objects.filter(email=email).first() + + if not existing_user: + if not ( + settings.ALLOW_REGISTRATION_WITHOUT_INVITE + or Invite.objects.filter(email=email).exists() + ): + raise PermissionDenied(USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE) + + return UserModel.objects.create(**user_data) + + return existing_user + def get_user_info(self): raise NotImplementedError("`get_user_info()` must be implemented.") diff --git a/src/custom_auth/oauth/tests/test_oauth_views.py b/src/custom_auth/oauth/tests/test_oauth_views.py new file mode 100644 index 000000000000..1ca6dbf027bf --- /dev/null +++ b/src/custom_auth/oauth/tests/test_oauth_views.py @@ -0,0 +1,137 @@ +from unittest import mock + +from django.test import override_settings +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from organisations.invites.models import Invite +from organisations.models import Organisation + + +@mock.patch("custom_auth.oauth.serializers.get_user_info") +@override_settings(ALLOW_REGISTRATION_WITHOUT_INVITE=False) +def test_cannot_register_with_google_without_invite_if_registration_disabled( + mock_get_user_info, db +): + # Given + url = reverse(f"api-v1:custom_auth:oauth:google-oauth-login") + client = APIClient() + + email = "test@example.com" + mock_get_user_info.return_value = {"email": email} + + # When + response = client.post(url, data={"access_token": "some-token"}) + + # Then + assert response.status_code == status.HTTP_403_FORBIDDEN + + +@mock.patch("custom_auth.oauth.serializers.GithubUser") +@override_settings(ALLOW_REGISTRATION_WITHOUT_INVITE=False) +def test_cannot_register_with_github_without_invite_if_registration_disabled( + MockGithubUser, db +): + # Given + url = reverse(f"api-v1:custom_auth:oauth:github-oauth-login") + client = APIClient() + + email = "test@example.com" + mock_github_user = mock.MagicMock() + MockGithubUser.return_value = mock_github_user + mock_github_user.get_user_info.return_value = {"email": email} + + # When + response = client.post(url, data={"access_token": "some-token"}) + + # Then + assert response.status_code == status.HTTP_403_FORBIDDEN + + +@mock.patch("custom_auth.oauth.serializers.get_user_info") +@override_settings(ALLOW_REGISTRATION_WITHOUT_INVITE=False) +def test_can_register_with_google_with_invite_if_registration_disabled( + mock_get_user_info, db +): + # Given + url = reverse(f"api-v1:custom_auth:oauth:google-oauth-login") + client = APIClient() + + email = "test@example.com" + mock_get_user_info.return_value = {"email": email} + organisation = Organisation.objects.create(name="Test Org") + Invite.objects.create(organisation=organisation, email=email) + + # When + response = client.post(url, data={"access_token": "some-token"}) + + # Then + assert response.status_code == status.HTTP_200_OK + + +@mock.patch("custom_auth.oauth.serializers.GithubUser") +@override_settings(ALLOW_REGISTRATION_WITHOUT_INVITE=False) +def test_can_register_with_github_with_invite_if_registration_disabled( + MockGithubUser, db +): + # Given + url = reverse(f"api-v1:custom_auth:oauth:github-oauth-login") + client = APIClient() + + email = "test@example.com" + mock_github_user = mock.MagicMock() + MockGithubUser.return_value = mock_github_user + mock_github_user.get_user_info.return_value = {"email": email} + organisation = Organisation.objects.create(name="Test Org") + Invite.objects.create(organisation=organisation, email=email) + + # When + response = client.post(url, data={"access_token": "some-token"}) + + # Then + assert response.status_code == status.HTTP_200_OK + + +@mock.patch("custom_auth.oauth.serializers.get_user_info") +@override_settings(ALLOW_OAUTH_REGISTRATION_WITHOUT_INVITE=False) +def test_can_login_with_google_if_registration_disabled( + mock_get_user_info, db, django_user_model +): + # Given + url = reverse(f"api-v1:custom_auth:oauth:google-oauth-login") + client = APIClient() + + email = "test@example.com" + mock_get_user_info.return_value = {"email": email} + django_user_model.objects.create(email=email) + + # When + response = client.post(url, data={"access_token": "some-token"}) + + # Then + assert response.status_code == status.HTTP_200_OK + assert "key" in response.json() + + +@mock.patch("custom_auth.oauth.serializers.GithubUser") +@override_settings(ALLOW_OAUTH_REGISTRATION_WITHOUT_INVITE=False) +def test_can_login_with_github_if_registration_disabled( + MockGithubUser, db, django_user_model +): + # Given + url = reverse(f"api-v1:custom_auth:oauth:github-oauth-login") + client = APIClient() + + email = "test@example.com" + mock_github_user = mock.MagicMock() + MockGithubUser.return_value = mock_github_user + mock_github_user.get_user_info.return_value = {"email": email} + django_user_model.objects.create(email=email) + + # When + response = client.post(url, data={"access_token": "some-token"}) + + # Then + assert response.status_code == status.HTTP_200_OK + assert "key" in response.json() diff --git a/src/custom_auth/oauth/urls.py b/src/custom_auth/oauth/urls.py index dfb3f029c7e0..179f178c0a21 100644 --- a/src/custom_auth/oauth/urls.py +++ b/src/custom_auth/oauth/urls.py @@ -4,4 +4,7 @@ app_name = "oauth" -urlpatterns = [path("google/", login_with_google), path("github/", login_with_github)] +urlpatterns = [ + path("google/", login_with_google, name="google-oauth-login"), + path("github/", login_with_github, name="github-oauth-login"), +] diff --git a/src/custom_auth/oauth/views.py b/src/custom_auth/oauth/views.py index 94d2ea5b9f6f..e0b2b7741059 100644 --- a/src/custom_auth/oauth/views.py +++ b/src/custom_auth/oauth/views.py @@ -1,3 +1,5 @@ +import logging + from drf_yasg2.utils import swagger_auto_schema from rest_framework import status from rest_framework.decorators import api_view, permission_classes @@ -6,9 +8,11 @@ from api.serializers import ErrorSerializer from custom_auth.oauth.exceptions import GithubError, GoogleError -from custom_auth.oauth.serializers import GoogleLoginSerializer, GithubLoginSerializer +from custom_auth.oauth.serializers import ( + GithubLoginSerializer, + GoogleLoginSerializer, +) from custom_auth.serializers import CustomTokenSerializer -import logging logger = logging.getLogger(__name__) @@ -20,7 +24,7 @@ @swagger_auto_schema( method="post", request_body=GoogleLoginSerializer, - responses={200: CustomTokenSerializer, 502: ErrorSerializer}, + responses={200: CustomTokenSerializer(), 502: ErrorSerializer()}, ) @api_view(["POST"]) @permission_classes([AllowAny]) @@ -41,7 +45,7 @@ def login_with_google(request): @swagger_auto_schema( method="post", request_body=GithubLoginSerializer, - responses={200: CustomTokenSerializer, 502: ErrorSerializer}, + responses={200: CustomTokenSerializer(), 502: ErrorSerializer()}, ) @api_view(["POST"]) @permission_classes([AllowAny]) diff --git a/src/custom_auth/serializers.py b/src/custom_auth/serializers.py index 3471944da9e9..f3b735921f4e 100644 --- a/src/custom_auth/serializers.py +++ b/src/custom_auth/serializers.py @@ -1,6 +1,12 @@ +from django.conf import settings from djoser.serializers import UserCreateSerializer from rest_framework import serializers from rest_framework.authtoken.models import Token +from rest_framework.exceptions import PermissionDenied + +from organisations.invites.models import Invite + +from .constants import USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE class CustomTokenSerializer(serializers.ModelSerializer): @@ -22,3 +28,12 @@ class Meta(UserCreateSerializer.Meta): def get_key(instance): token, _ = Token.objects.get_or_create(user=instance) return token.key + + def save(self, **kwargs): + if not ( + settings.ALLOW_REGISTRATION_WITHOUT_INVITE + or Invite.objects.filter(email=self.validated_data.get("email")) + ): + raise PermissionDenied(USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE) + + return super(CustomUserCreateSerializer, self).save(**kwargs) diff --git a/src/custom_auth/tests/end_to_end/test_custom_auth_integration.py b/src/custom_auth/tests/end_to_end/test_custom_auth_integration.py index 37510feaa594..216c338dc0be 100644 --- a/src/custom_auth/tests/end_to_end/test_custom_auth_integration.py +++ b/src/custom_auth/tests/end_to_end/test_custom_auth_integration.py @@ -1,14 +1,17 @@ import re - import time from collections import ChainMap +from unittest import mock import pyotp from django.conf import settings from django.core import mail from django.urls import reverse from rest_framework import status -from rest_framework.test import APITestCase, override_settings +from rest_framework.test import APIClient, APITestCase, override_settings + +from organisations.invites.models import Invite +from organisations.models import Organisation from users.models import FFAdminUser @@ -16,6 +19,9 @@ class AuthIntegrationTestCase(APITestCase): test_email = "test@example.com" password = FFAdminUser.objects.make_random_password() + def setUp(self) -> None: + self.organisation = Organisation.objects.create(name="Test Organisation") + def tearDown(self) -> None: FFAdminUser.objects.all().delete() @@ -93,6 +99,41 @@ def test_register_and_login_workflows(self): assert new_login_response.status_code == status.HTTP_200_OK assert new_login_response.json()["key"] + @override_settings(ALLOW_REGISTRATION_WITHOUT_INVITE=False) + def test_cannot_register_without_invite_if_disabled(self): + # Given + register_data = { + "email": self.test_email, + "password": self.password, + "first_name": "test", + "last_name": "register", + } + + # When + url = reverse("api-v1:custom_auth:ffadminuser-list") + response = self.client.post(url, data=register_data) + + # Then + assert response.status_code == status.HTTP_403_FORBIDDEN + + @override_settings(ALLOW_REGISTRATION_WITHOUT_INVITE=False) + def test_can_register_with_invite_if_registration_disabled_without_invite(self): + # Given + register_data = { + "email": self.test_email, + "password": self.password, + "first_name": "test", + "last_name": "register", + } + Invite.objects.create(email=self.test_email, organisation=self.organisation) + + # When + url = reverse("api-v1:custom_auth:ffadminuser-list") + response = self.client.post(url, data=register_data) + + # Then + assert response.status_code == status.HTTP_201_CREATED + @override_settings( DJOSER=ChainMap( {"SEND_ACTIVATION_EMAIL": True, "SEND_CONFIRMATION_EMAIL": False},