Skip to content

Commit 5942978

Browse files
committed
chore(rls): Add tenant_id filters in views
1 parent 57854f2 commit 5942978

File tree

1 file changed

+66
-51
lines changed

1 file changed

+66
-51
lines changed

api/src/backend/api/v1/views.py

+66-51
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from django.utils.decorators import method_decorator
99
from django.views.decorators.cache import cache_control
1010
from drf_spectacular.settings import spectacular_settings
11-
from drf_spectacular_jsonapi.schemas.openapi import JsonApiAutoSchema
1211
from drf_spectacular.utils import (
1312
OpenApiParameter,
1413
OpenApiResponse,
@@ -17,6 +16,7 @@
1716
extend_schema_view,
1817
)
1918
from drf_spectacular.views import SpectacularAPIView
19+
from drf_spectacular_jsonapi.schemas.openapi import JsonApiAutoSchema
2020
from rest_framework import permissions, status
2121
from rest_framework.decorators import action
2222
from rest_framework.exceptions import (
@@ -26,10 +26,9 @@
2626
ValidationError,
2727
)
2828
from rest_framework.generics import GenericAPIView, get_object_or_404
29+
from rest_framework.permissions import SAFE_METHODS
2930
from rest_framework_json_api.views import RelationshipView, Response
3031
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
31-
from rest_framework.permissions import SAFE_METHODS
32-
3332
from tasks.beat import schedule_provider_scan
3433
from tasks.tasks import (
3534
check_provider_connection_task,
@@ -50,17 +49,14 @@
5049
ProviderGroupFilter,
5150
ProviderSecretFilter,
5251
ResourceFilter,
52+
RoleFilter,
5353
ScanFilter,
5454
ScanSummaryFilter,
5555
TaskFilter,
5656
TenantFilter,
5757
UserFilter,
58-
RoleFilter,
5958
)
6059
from api.models import (
61-
StatusChoices,
62-
User,
63-
UserRoleRelationship,
6460
ComplianceOverview,
6561
Finding,
6662
Invitation,
@@ -69,27 +65,24 @@
6965
ProviderGroup,
7066
ProviderGroupMembership,
7167
ProviderSecret,
68+
Resource,
7269
Role,
7370
RoleProviderGroupRelationship,
74-
Resource,
7571
Scan,
7672
ScanSummary,
7773
SeverityChoices,
7874
StateChoices,
75+
StatusChoices,
7976
Task,
77+
User,
78+
UserRoleRelationship,
8079
)
8180
from api.pagination import ComplianceOverviewPagination
8281
from api.rbac.permissions import HasPermissions, Permissions
8382
from api.rls import Tenant
8483
from api.utils import validate_invitation
8584
from api.uuid_utils import datetime_to_uuid7
8685
from api.v1.serializers import (
87-
TokenSerializer,
88-
TokenRefreshSerializer,
89-
UserSerializer,
90-
UserCreateSerializer,
91-
UserUpdateSerializer,
92-
UserRoleRelationshipSerializer,
9386
ComplianceOverviewFullSerializer,
9487
ComplianceOverviewSerializer,
9588
FindingDynamicFilterSerializer,
@@ -106,25 +99,30 @@
10699
ProviderGroupMembershipSerializer,
107100
ProviderGroupSerializer,
108101
ProviderGroupUpdateSerializer,
109-
RoleProviderGroupRelationshipSerializer,
102+
ProviderSecretCreateSerializer,
103+
ProviderSecretSerializer,
104+
ProviderSecretUpdateSerializer,
110105
ProviderSerializer,
111106
ProviderUpdateSerializer,
112-
TenantSerializer,
113-
TaskSerializer,
114-
ScanSerializer,
115-
ScanCreateSerializer,
116-
ScanUpdateSerializer,
117107
ResourceSerializer,
118-
ProviderSecretSerializer,
119-
ProviderSecretUpdateSerializer,
120-
ProviderSecretCreateSerializer,
121-
RoleSerializer,
122108
RoleCreateSerializer,
109+
RoleProviderGroupRelationshipSerializer,
110+
RoleSerializer,
123111
RoleUpdateSerializer,
112+
ScanCreateSerializer,
113+
ScanSerializer,
114+
ScanUpdateSerializer,
124115
ScheduleDailyCreateSerializer,
116+
TaskSerializer,
117+
TenantSerializer,
118+
TokenRefreshSerializer,
119+
TokenSerializer,
120+
UserCreateSerializer,
121+
UserRoleRelationshipSerializer,
122+
UserSerializer,
123+
UserUpdateSerializer,
125124
)
126125

127-
128126
CACHE_DECORATOR = cache_control(
129127
max_age=django_settings.CACHE_MAX_AGE,
130128
stale_while_revalidate=django_settings.CACHE_STALE_WHILE_REVALIDATE,
@@ -456,7 +454,7 @@ class UserRoleRelationshipView(RelationshipView, BaseRLSViewSet):
456454
schema = RelationshipViewSchema()
457455

458456
def get_queryset(self):
459-
return User.objects.all()
457+
return User.objects.filter(tenant_id=self.request.tenant_id)
460458

461459
def create(self, request, *args, **kwargs):
462460
user = self.get_object()
@@ -740,7 +738,7 @@ def get_required_permissions(self):
740738

741739
def get_queryset(self):
742740
user = self.request.user
743-
user_roles = user.roles.all()
741+
user_roles = user.roles.filter(tenant_id=self.request.tenant_id)
744742

745743
# Check if any of the user's roles have UNLIMITED_VISIBILITY
746744
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
@@ -801,7 +799,7 @@ class ProviderGroupProvidersRelationshipView(RelationshipView, BaseRLSViewSet):
801799
schema = RelationshipViewSchema()
802800

803801
def get_queryset(self):
804-
return ProviderGroup.objects.all()
802+
return ProviderGroup.objects.filter(tenant_id=self.request.tenant_id)
805803

806804
def create(self, request, *args, **kwargs):
807805
provider_group = self.get_object()
@@ -921,14 +919,15 @@ def get_required_permissions(self):
921919
def get_queryset(self):
922920
user = self.request.user
923921
user_roles = user.roles.all()
922+
tenant_id = self.request.tenant_id
924923
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
925924
# User has unlimited visibility, return all providers
926-
return Provider.objects.all()
925+
return Provider.objects.filter(tenant_id=tenant_id)
927926

928927
# User lacks permission, filter providers based on provider groups associated with the role
929928
provider_groups = user_roles[0].provider_groups.all()
930929
providers = Provider.objects.filter(
931-
provider_groups__in=provider_groups
930+
provider_groups__in=provider_groups, tenant_id=tenant_id
932931
).distinct()
933932

934933
return providers
@@ -1075,14 +1074,15 @@ def get_required_permissions(self):
10751074
def get_queryset(self):
10761075
user = self.request.user
10771076
user_roles = user.roles.all()
1077+
tenant_id = self.request.tenant_id
10781078
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
10791079
# User has unlimited visibility, return all scans
1080-
return Scan.objects.all()
1080+
return Scan.objects.filter(tenant_id=tenant_id)
10811081

10821082
# User lacks permission, filter providers based on provider groups associated with the role
10831083
provider_groups = user_roles[0].provider_groups.all()
10841084
providers = Provider.objects.filter(
1085-
provider_groups__in=provider_groups
1085+
provider_groups__in=provider_groups, tenant_id=tenant_id
10861086
).distinct()
10871087
return Scan.objects.filter(provider__in=providers).distinct()
10881088

@@ -1180,6 +1180,7 @@ class TaskViewSet(BaseRLSViewSet):
11801180
def get_queryset(self):
11811181
user = self.request.user
11821182
user_roles = user.roles.all()
1183+
tenant_id = self.request.tenant_id
11831184
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
11841185
# User has unlimited visibility, return all tasks
11851186
return Task.objects.annotate(
@@ -1190,9 +1191,11 @@ def get_queryset(self):
11901191
# User lacks permission, filter tasks based on provider groups associated with the role
11911192
provider_groups = user_roles[0].provider_groups.all()
11921193
providers = Provider.objects.filter(
1193-
provider_groups__in=provider_groups
1194+
provider_groups__in=provider_groups, tenant_id=tenant_id
1195+
).distinct()
1196+
scans = Scan.objects.filter(
1197+
provider__in=providers, tenant_id=tenant_id
11941198
).distinct()
1195-
scans = Scan.objects.filter(provider__in=providers).distinct()
11961199
return Task.objects.filter(scan__in=scans).distinct()
11971200

11981201
def destroy(self, request, *args, pk=None, **kwargs):
@@ -1267,16 +1270,19 @@ def initial(self, request, *args, **kwargs):
12671270
def get_queryset(self):
12681271
user = self.request.user
12691272
user_roles = user.roles.all()
1273+
tenant_id = self.request.tenant_id
12701274
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
12711275
# User has unlimited visibility, return all scans
1272-
queryset = Resource.objects.all()
1276+
queryset = Resource.objects.all().filter(tenant_id=tenant_id)
12731277
else:
12741278
# User lacks permission, filter providers based on provider groups associated with the role
12751279
provider_groups = user_roles[0].provider_groups.all()
12761280
providers = Provider.objects.filter(
1277-
provider_groups__in=provider_groups
1281+
provider_groups__in=provider_groups, tenant_id=tenant_id
1282+
).distinct()
1283+
queryset = Resource.objects.filter(
1284+
provider__in=providers, tenant_id=tenant_id
12781285
).distinct()
1279-
queryset = Resource.objects.filter(provider__in=providers).distinct()
12801286

12811287
search_value = self.request.query_params.get("filter[search]", None)
12821288
if search_value:
@@ -1368,17 +1374,22 @@ def get_serializer_class(self):
13681374
def get_queryset(self):
13691375
user = self.request.user
13701376
user_roles = user.roles.all()
1377+
tenant_id = self.request.tenant_id
13711378
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
13721379
# User has unlimited visibility, return all scans
1373-
queryset = Finding.objects.all()
1380+
queryset = Finding.objects.all().filter(tenant_id=tenant_id)
13741381
else:
13751382
# User lacks permission, filter providers based on provider groups associated with the role
13761383
provider_groups = user_roles[0].provider_groups.all()
13771384
providers = Provider.objects.filter(
1378-
provider_groups__in=provider_groups
1385+
provider_groups__in=provider_groups, tenant_id=tenant_id
1386+
).distinct()
1387+
scans = Scan.objects.filter(
1388+
provider__in=providers, tenant_id=tenant_id
1389+
).distinct()
1390+
queryset = Finding.objects.filter(
1391+
scan__in=scans, tenant_id=tenant_id
13791392
).distinct()
1380-
scans = Scan.objects.filter(provider__in=providers).distinct()
1381-
queryset = Finding.objects.filter(scan__in=scans).distinct()
13821393

13831394
search_value = self.request.query_params.get("filter[search]", None)
13841395
if search_value:
@@ -1478,7 +1489,7 @@ class ProviderSecretViewSet(BaseRLSViewSet):
14781489
]
14791490

14801491
def get_queryset(self):
1481-
return ProviderSecret.objects.all()
1492+
return ProviderSecret.objects.all().filter(tenant_id=self.request.tenant_id)
14821493

14831494
def get_serializer_class(self):
14841495
if self.action == "create":
@@ -1537,7 +1548,7 @@ class InvitationViewSet(BaseRLSViewSet):
15371548
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
15381549

15391550
def get_queryset(self):
1540-
return Invitation.objects.all()
1551+
return Invitation.objects.all().filter(tenant_id=self.request.tenant_id)
15411552

15421553
def get_serializer_class(self):
15431554
if self.action == "create":
@@ -1584,7 +1595,7 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
15841595
http_method_names = ["post"]
15851596

15861597
def get_queryset(self):
1587-
return Invitation.objects.all()
1598+
return Invitation.objects.all().filter(tenant_id=self.request.tenant_id)
15881599

15891600
def get_serializer_class(self):
15901601
if hasattr(self, "response_serializer_class"):
@@ -1676,7 +1687,7 @@ class RoleViewSet(BaseRLSViewSet):
16761687
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]
16771688

16781689
def get_queryset(self):
1679-
return Role.objects.all()
1690+
return Role.objects.all().filter(tenant_id=self.request.tenant_id)
16801691

16811692
def get_serializer_class(self):
16821693
if self.action == "create":
@@ -1735,7 +1746,7 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
17351746
schema = RelationshipViewSchema()
17361747

17371748
def get_queryset(self):
1738-
return Role.objects.all()
1749+
return Role.objects.all().filter(tenant_id=self.request.tenant_id)
17391750

17401751
def create(self, request, *args, **kwargs):
17411752
role = self.get_object()
@@ -1821,9 +1832,13 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
18211832

18221833
def get_queryset(self):
18231834
if self.action == "retrieve":
1824-
return ComplianceOverview.objects.all()
1835+
return ComplianceOverview.objects.all().filter(
1836+
tenant_id=self.request.tenant_id
1837+
)
18251838

1826-
base_queryset = self.filter_queryset(ComplianceOverview.objects.all())
1839+
base_queryset = self.filter_queryset(
1840+
ComplianceOverview.objects.all().filter(tenant_id=self.request.tenant_id)
1841+
)
18271842

18281843
max_failed_ids = (
18291844
base_queryset.filter(compliance_id=OuterRef("compliance_id"))
@@ -1897,11 +1912,11 @@ class OverviewViewSet(BaseRLSViewSet):
18971912

18981913
def get_queryset(self):
18991914
if self.action == "providers":
1900-
return Finding.objects.all()
1915+
return Finding.objects.all().filter(tenant_id=self.request.tenant_id)
19011916
elif self.action == "findings":
1902-
return ScanSummary.objects.all()
1917+
return ScanSummary.objects.all().filter(tenant_id=self.request.tenant_id)
19031918
elif self.action == "findings_severity":
1904-
return ScanSummary.objects.all()
1919+
return ScanSummary.objects.all().filter(tenant_id=self.request.tenant_id)
19051920
else:
19061921
return super().get_queryset()
19071922

0 commit comments

Comments
 (0)