Skip to content

Commit

Permalink
fix(RLS): enforce config security
Browse files Browse the repository at this point in the history
  • Loading branch information
jfagoagas committed Dec 9, 2024
1 parent 493fe2d commit d9df45a
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 9 deletions.
8 changes: 4 additions & 4 deletions api/src/backend/api/base_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def initial(self, request, *args, **kwargs):
raise ValidationError("Tenant ID must be a valid UUID")

with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
cursor.execute("SELECT set_config('api.tenant_id', %s, TRUE);", [tenant_id])
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)

Expand All @@ -76,7 +76,7 @@ def initial(self, request, *args, **kwargs):
user_id = str(request.user.id)

with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.user_id', '{user_id}', TRUE);")
cursor.execute("SELECT set_config('api.user_id', %s, TRUE);", [user_id])
return super().initial(request, *args, **kwargs)

# TODO: DRY this when we have time
Expand All @@ -93,7 +93,7 @@ def initial(self, request, *args, **kwargs):
raise ValidationError("Tenant ID must be a valid UUID")

with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
cursor.execute("SELECT set_config('api.tenant_id', %s, TRUE);", [tenant_id])
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)

Expand All @@ -120,6 +120,6 @@ def initial(self, request, *args, **kwargs):
raise ValidationError("Tenant ID must be a valid UUID")

with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
cursor.execute("SELECT set_config('api.tenant_id', %s, TRUE);", [tenant_id])
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
7 changes: 6 additions & 1 deletion api/src/backend/api/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def psycopg_connection(database_alias: str):
def tenant_transaction(tenant_id: str):
with transaction.atomic():
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
# TODO
# try:
# uuid.UUID(tenant_id)
# except ValueError:
# raise ValidationError("Tenant ID must be a valid UUID")
cursor.execute("SELECT set_config('api.tenant_id', %s, TRUE);", [tenant_id])
yield cursor


Expand Down
8 changes: 6 additions & 2 deletions api/src/backend/api/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ def wrapper(*args, **kwargs):
tenant_id = kwargs.pop("tenant_id")
except KeyError:
raise KeyError("This task requires the tenant_id")

# TODO
# try:
# uuid.UUID(tenant_id)
# except ValueError:
# raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
cursor.execute("SELECT set_config('api.tenant_id', %s, TRUE);", [tenant_id])

return func(*args, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions api/src/backend/api/tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import patch, call
from unittest.mock import call, patch

import pytest

Expand All @@ -20,7 +20,7 @@ def random_func(arg):
result = random_func("test_arg", tenant_id=tenant_id)

assert (
call(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
call("SELECT set_config('api.tenant_id', %s, TRUE);", [tenant_id])
in mock_cursor.execute.mock_calls
)
assert result == "test_arg"
Expand Down

0 comments on commit d9df45a

Please sign in to comment.