Skip to content

Commit bbf89d9

Browse files
committed
chore: unify set config
1 parent 6017b20 commit bbf89d9

File tree

3 files changed

+11
-12
lines changed

3 files changed

+11
-12
lines changed

api/src/backend/api/base_views.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import uuid
22

3-
from db_utils import tenant_transaction
4-
from django.db import connection, transaction
3+
from django.db import transaction
54
from rest_framework import permissions
65
from rest_framework.exceptions import NotAuthenticated
76
from rest_framework.filters import SearchFilter
@@ -10,6 +9,7 @@
109
from rest_framework_json_api.views import ModelViewSet
1110
from rest_framework_simplejwt.authentication import JWTAuthentication
1211

12+
from api.db_utils import tenant_transaction
1313
from api.filters import CustomDjangoFilterBackend
1414

1515

@@ -75,8 +75,7 @@ def initial(self, request, *args, **kwargs):
7575
except ValueError:
7676
raise ValidationError("User ID must be a valid UUID")
7777

78-
with connection.cursor() as cursor:
79-
cursor.execute("SELECT set_config('api.user_id', %s, TRUE);", [user_id])
78+
with tenant_transaction(value=user_id, parameter="api.user_id"):
8079
return super().initial(request, *args, **kwargs)
8180

8281
# TODO: DRY this when we have time

api/src/backend/api/db_utils.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
POSTGRES_TENANT_VAR = "api.tenant_id"
2626
POSTGRES_USER_VAR = "api.user_id"
2727

28-
SET_API_TENANT_ID_QUERY = "SELECT set_config('api.tenant_id', %s::text, TRUE);"
28+
SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);"
2929

3030

3131
@contextmanager
@@ -48,15 +48,15 @@ def psycopg_connection(database_alias: str):
4848

4949

5050
@contextmanager
51-
def tenant_transaction(tenant_id: str):
51+
def tenant_transaction(value: str, parameter: str = "api.tenant_id"):
5252
with transaction.atomic():
5353
with connection.cursor() as cursor:
5454
try:
55-
# just in case the tenant_id is an UUID object
56-
uuid.UUID(str(tenant_id))
55+
# just in case the tenant_id|user_id is an UUID object
56+
uuid.UUID(str(value))
5757
except ValueError:
58-
raise ValidationError("Tenant ID must be a valid UUID")
59-
cursor.execute(SET_API_TENANT_ID_QUERY, [tenant_id])
58+
raise ValidationError("Must be a valid UUID")
59+
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
6060
yield cursor
6161

6262

api/src/backend/api/decorators.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from django.db import connection, transaction
55
from rest_framework_json_api.serializers import ValidationError
66

7-
from api.db_utils import SET_API_TENANT_ID_QUERY
7+
from api.db_utils import SET_CONFIG_QUERY
88

99

1010
def set_tenant(func):
@@ -52,7 +52,7 @@ def wrapper(*args, **kwargs):
5252
except ValueError:
5353
raise ValidationError("Tenant ID must be a valid UUID")
5454
with connection.cursor() as cursor:
55-
cursor.execute(SET_API_TENANT_ID_QUERY, [tenant_id])
55+
cursor.execute(SET_CONFIG_QUERY, ["api.tenant_id", tenant_id])
5656

5757
return func(*args, **kwargs)
5858

0 commit comments

Comments
 (0)