8
8
from django .utils .decorators import method_decorator
9
9
from django .views .decorators .cache import cache_control
10
10
from drf_spectacular .settings import spectacular_settings
11
- from drf_spectacular_jsonapi .schemas .openapi import JsonApiAutoSchema
12
11
from drf_spectacular .utils import (
13
12
OpenApiParameter ,
14
13
OpenApiResponse ,
17
16
extend_schema_view ,
18
17
)
19
18
from drf_spectacular .views import SpectacularAPIView
19
+ from drf_spectacular_jsonapi .schemas .openapi import JsonApiAutoSchema
20
20
from rest_framework import permissions , status
21
21
from rest_framework .decorators import action
22
22
from rest_framework .exceptions import (
26
26
ValidationError ,
27
27
)
28
28
from rest_framework .generics import GenericAPIView , get_object_or_404
29
+ from rest_framework .permissions import SAFE_METHODS
29
30
from rest_framework_json_api .views import RelationshipView , Response
30
31
from rest_framework_simplejwt .exceptions import InvalidToken , TokenError
31
- from rest_framework .permissions import SAFE_METHODS
32
-
33
32
from tasks .beat import schedule_provider_scan
34
33
from tasks .tasks import (
35
34
check_provider_connection_task ,
50
49
ProviderGroupFilter ,
51
50
ProviderSecretFilter ,
52
51
ResourceFilter ,
52
+ RoleFilter ,
53
53
ScanFilter ,
54
54
ScanSummaryFilter ,
55
55
TaskFilter ,
56
56
TenantFilter ,
57
57
UserFilter ,
58
- RoleFilter ,
59
58
)
60
59
from api .models import (
61
- StatusChoices ,
62
- User ,
63
- UserRoleRelationship ,
64
60
ComplianceOverview ,
65
61
Finding ,
66
62
Invitation ,
69
65
ProviderGroup ,
70
66
ProviderGroupMembership ,
71
67
ProviderSecret ,
68
+ Resource ,
72
69
Role ,
73
70
RoleProviderGroupRelationship ,
74
- Resource ,
75
71
Scan ,
76
72
ScanSummary ,
77
73
SeverityChoices ,
78
74
StateChoices ,
75
+ StatusChoices ,
79
76
Task ,
77
+ User ,
78
+ UserRoleRelationship ,
80
79
)
81
80
from api .pagination import ComplianceOverviewPagination
82
81
from api .rbac .permissions import HasPermissions , Permissions
83
82
from api .rls import Tenant
84
83
from api .utils import validate_invitation
85
84
from api .uuid_utils import datetime_to_uuid7
86
85
from api .v1 .serializers import (
87
- TokenSerializer ,
88
- TokenRefreshSerializer ,
89
- UserSerializer ,
90
- UserCreateSerializer ,
91
- UserUpdateSerializer ,
92
- UserRoleRelationshipSerializer ,
93
86
ComplianceOverviewFullSerializer ,
94
87
ComplianceOverviewSerializer ,
95
88
FindingDynamicFilterSerializer ,
106
99
ProviderGroupMembershipSerializer ,
107
100
ProviderGroupSerializer ,
108
101
ProviderGroupUpdateSerializer ,
109
- RoleProviderGroupRelationshipSerializer ,
102
+ ProviderSecretCreateSerializer ,
103
+ ProviderSecretSerializer ,
104
+ ProviderSecretUpdateSerializer ,
110
105
ProviderSerializer ,
111
106
ProviderUpdateSerializer ,
112
- TenantSerializer ,
113
- TaskSerializer ,
114
- ScanSerializer ,
115
- ScanCreateSerializer ,
116
- ScanUpdateSerializer ,
117
107
ResourceSerializer ,
118
- ProviderSecretSerializer ,
119
- ProviderSecretUpdateSerializer ,
120
- ProviderSecretCreateSerializer ,
121
- RoleSerializer ,
122
108
RoleCreateSerializer ,
109
+ RoleProviderGroupRelationshipSerializer ,
110
+ RoleSerializer ,
123
111
RoleUpdateSerializer ,
112
+ ScanCreateSerializer ,
113
+ ScanSerializer ,
114
+ ScanUpdateSerializer ,
124
115
ScheduleDailyCreateSerializer ,
116
+ TaskSerializer ,
117
+ TenantSerializer ,
118
+ TokenRefreshSerializer ,
119
+ TokenSerializer ,
120
+ UserCreateSerializer ,
121
+ UserRoleRelationshipSerializer ,
122
+ UserSerializer ,
123
+ UserUpdateSerializer ,
125
124
)
126
125
127
-
128
126
CACHE_DECORATOR = cache_control (
129
127
max_age = django_settings .CACHE_MAX_AGE ,
130
128
stale_while_revalidate = django_settings .CACHE_STALE_WHILE_REVALIDATE ,
@@ -456,7 +454,7 @@ class UserRoleRelationshipView(RelationshipView, BaseRLSViewSet):
456
454
schema = RelationshipViewSchema ()
457
455
458
456
def get_queryset (self ):
459
- return User .objects .all ( )
457
+ return User .objects .filter ( tenant_id = self . request . tenant_id )
460
458
461
459
def create (self , request , * args , ** kwargs ):
462
460
user = self .get_object ()
@@ -740,7 +738,7 @@ def get_required_permissions(self):
740
738
741
739
def get_queryset (self ):
742
740
user = self .request .user
743
- user_roles = user .roles .all ( )
741
+ user_roles = user .roles .filter ( tenant_id = self . request . tenant_id )
744
742
745
743
# Check if any of the user's roles have UNLIMITED_VISIBILITY
746
744
if getattr (user_roles [0 ], Permissions .UNLIMITED_VISIBILITY .value , False ):
@@ -801,7 +799,7 @@ class ProviderGroupProvidersRelationshipView(RelationshipView, BaseRLSViewSet):
801
799
schema = RelationshipViewSchema ()
802
800
803
801
def get_queryset (self ):
804
- return ProviderGroup .objects .all ( )
802
+ return ProviderGroup .objects .filter ( tenant_id = self . request . tenant_id )
805
803
806
804
def create (self , request , * args , ** kwargs ):
807
805
provider_group = self .get_object ()
@@ -921,14 +919,15 @@ def get_required_permissions(self):
921
919
def get_queryset (self ):
922
920
user = self .request .user
923
921
user_roles = user .roles .all ()
922
+ tenant_id = self .request .tenant_id
924
923
if getattr (user_roles [0 ], Permissions .UNLIMITED_VISIBILITY .value , False ):
925
924
# User has unlimited visibility, return all providers
926
- return Provider .objects .all ( )
925
+ return Provider .objects .filter ( tenant_id = tenant_id )
927
926
928
927
# User lacks permission, filter providers based on provider groups associated with the role
929
928
provider_groups = user_roles [0 ].provider_groups .all ()
930
929
providers = Provider .objects .filter (
931
- provider_groups__in = provider_groups
930
+ provider_groups__in = provider_groups , tenant_id = tenant_id
932
931
).distinct ()
933
932
934
933
return providers
@@ -1075,14 +1074,15 @@ def get_required_permissions(self):
1075
1074
def get_queryset (self ):
1076
1075
user = self .request .user
1077
1076
user_roles = user .roles .all ()
1077
+ tenant_id = self .request .tenant_id
1078
1078
if getattr (user_roles [0 ], Permissions .UNLIMITED_VISIBILITY .value , False ):
1079
1079
# User has unlimited visibility, return all scans
1080
- return Scan .objects .all ( )
1080
+ return Scan .objects .filter ( tenant_id = tenant_id )
1081
1081
1082
1082
# User lacks permission, filter providers based on provider groups associated with the role
1083
1083
provider_groups = user_roles [0 ].provider_groups .all ()
1084
1084
providers = Provider .objects .filter (
1085
- provider_groups__in = provider_groups
1085
+ provider_groups__in = provider_groups , tenant_id = tenant_id
1086
1086
).distinct ()
1087
1087
return Scan .objects .filter (provider__in = providers ).distinct ()
1088
1088
@@ -1180,6 +1180,7 @@ class TaskViewSet(BaseRLSViewSet):
1180
1180
def get_queryset (self ):
1181
1181
user = self .request .user
1182
1182
user_roles = user .roles .all ()
1183
+ tenant_id = self .request .tenant_id
1183
1184
if getattr (user_roles [0 ], Permissions .UNLIMITED_VISIBILITY .value , False ):
1184
1185
# User has unlimited visibility, return all tasks
1185
1186
return Task .objects .annotate (
@@ -1190,9 +1191,11 @@ def get_queryset(self):
1190
1191
# User lacks permission, filter tasks based on provider groups associated with the role
1191
1192
provider_groups = user_roles [0 ].provider_groups .all ()
1192
1193
providers = Provider .objects .filter (
1193
- provider_groups__in = provider_groups
1194
+ provider_groups__in = provider_groups , tenant_id = tenant_id
1195
+ ).distinct ()
1196
+ scans = Scan .objects .filter (
1197
+ provider__in = providers , tenant_id = tenant_id
1194
1198
).distinct ()
1195
- scans = Scan .objects .filter (provider__in = providers ).distinct ()
1196
1199
return Task .objects .filter (scan__in = scans ).distinct ()
1197
1200
1198
1201
def destroy (self , request , * args , pk = None , ** kwargs ):
@@ -1267,16 +1270,19 @@ def initial(self, request, *args, **kwargs):
1267
1270
def get_queryset (self ):
1268
1271
user = self .request .user
1269
1272
user_roles = user .roles .all ()
1273
+ tenant_id = self .request .tenant_id
1270
1274
if getattr (user_roles [0 ], Permissions .UNLIMITED_VISIBILITY .value , False ):
1271
1275
# User has unlimited visibility, return all scans
1272
- queryset = Resource .objects .all ()
1276
+ queryset = Resource .objects .all (). filter ( tenant_id = tenant_id )
1273
1277
else :
1274
1278
# User lacks permission, filter providers based on provider groups associated with the role
1275
1279
provider_groups = user_roles [0 ].provider_groups .all ()
1276
1280
providers = Provider .objects .filter (
1277
- provider_groups__in = provider_groups
1281
+ provider_groups__in = provider_groups , tenant_id = tenant_id
1282
+ ).distinct ()
1283
+ queryset = Resource .objects .filter (
1284
+ provider__in = providers , tenant_id = tenant_id
1278
1285
).distinct ()
1279
- queryset = Resource .objects .filter (provider__in = providers ).distinct ()
1280
1286
1281
1287
search_value = self .request .query_params .get ("filter[search]" , None )
1282
1288
if search_value :
@@ -1368,17 +1374,22 @@ def get_serializer_class(self):
1368
1374
def get_queryset (self ):
1369
1375
user = self .request .user
1370
1376
user_roles = user .roles .all ()
1377
+ tenant_id = self .request .tenant_id
1371
1378
if getattr (user_roles [0 ], Permissions .UNLIMITED_VISIBILITY .value , False ):
1372
1379
# User has unlimited visibility, return all scans
1373
- queryset = Finding .objects .all ()
1380
+ queryset = Finding .objects .all (). filter ( tenant_id = tenant_id )
1374
1381
else :
1375
1382
# User lacks permission, filter providers based on provider groups associated with the role
1376
1383
provider_groups = user_roles [0 ].provider_groups .all ()
1377
1384
providers = Provider .objects .filter (
1378
- provider_groups__in = provider_groups
1385
+ provider_groups__in = provider_groups , tenant_id = tenant_id
1386
+ ).distinct ()
1387
+ scans = Scan .objects .filter (
1388
+ provider__in = providers , tenant_id = tenant_id
1389
+ ).distinct ()
1390
+ queryset = Finding .objects .filter (
1391
+ scan__in = scans , tenant_id = tenant_id
1379
1392
).distinct ()
1380
- scans = Scan .objects .filter (provider__in = providers ).distinct ()
1381
- queryset = Finding .objects .filter (scan__in = scans ).distinct ()
1382
1393
1383
1394
search_value = self .request .query_params .get ("filter[search]" , None )
1384
1395
if search_value :
@@ -1478,7 +1489,7 @@ class ProviderSecretViewSet(BaseRLSViewSet):
1478
1489
]
1479
1490
1480
1491
def get_queryset (self ):
1481
- return ProviderSecret .objects .all ()
1492
+ return ProviderSecret .objects .all (). filter ( tenant_id = self . request . tenant_id )
1482
1493
1483
1494
def get_serializer_class (self ):
1484
1495
if self .action == "create" :
@@ -1537,7 +1548,7 @@ class InvitationViewSet(BaseRLSViewSet):
1537
1548
permission_classes = BaseRLSViewSet .permission_classes + [HasPermissions ]
1538
1549
1539
1550
def get_queryset (self ):
1540
- return Invitation .objects .all ()
1551
+ return Invitation .objects .all (). filter ( tenant_id = self . request . tenant_id )
1541
1552
1542
1553
def get_serializer_class (self ):
1543
1554
if self .action == "create" :
@@ -1584,7 +1595,7 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
1584
1595
http_method_names = ["post" ]
1585
1596
1586
1597
def get_queryset (self ):
1587
- return Invitation .objects .all ()
1598
+ return Invitation .objects .all (). filter ( tenant_id = self . request . tenant_id )
1588
1599
1589
1600
def get_serializer_class (self ):
1590
1601
if hasattr (self , "response_serializer_class" ):
@@ -1676,7 +1687,7 @@ class RoleViewSet(BaseRLSViewSet):
1676
1687
permission_classes = BaseRLSViewSet .permission_classes + [HasPermissions ]
1677
1688
1678
1689
def get_queryset (self ):
1679
- return Role .objects .all ()
1690
+ return Role .objects .all (). filter ( tenant_id = self . request . tenant_id )
1680
1691
1681
1692
def get_serializer_class (self ):
1682
1693
if self .action == "create" :
@@ -1735,7 +1746,7 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
1735
1746
schema = RelationshipViewSchema ()
1736
1747
1737
1748
def get_queryset (self ):
1738
- return Role .objects .all ()
1749
+ return Role .objects .all (). filter ( tenant_id = self . request . tenant_id )
1739
1750
1740
1751
def create (self , request , * args , ** kwargs ):
1741
1752
role = self .get_object ()
@@ -1821,9 +1832,13 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):
1821
1832
1822
1833
def get_queryset (self ):
1823
1834
if self .action == "retrieve" :
1824
- return ComplianceOverview .objects .all ()
1835
+ return ComplianceOverview .objects .all ().filter (
1836
+ tenant_id = self .request .tenant_id
1837
+ )
1825
1838
1826
- base_queryset = self .filter_queryset (ComplianceOverview .objects .all ())
1839
+ base_queryset = self .filter_queryset (
1840
+ ComplianceOverview .objects .all ().filter (tenant_id = self .request .tenant_id )
1841
+ )
1827
1842
1828
1843
max_failed_ids = (
1829
1844
base_queryset .filter (compliance_id = OuterRef ("compliance_id" ))
@@ -1897,11 +1912,11 @@ class OverviewViewSet(BaseRLSViewSet):
1897
1912
1898
1913
def get_queryset (self ):
1899
1914
if self .action == "providers" :
1900
- return Finding .objects .all ()
1915
+ return Finding .objects .all (). filter ( tenant_id = self . request . tenant_id )
1901
1916
elif self .action == "findings" :
1902
- return ScanSummary .objects .all ()
1917
+ return ScanSummary .objects .all (). filter ( tenant_id = self . request . tenant_id )
1903
1918
elif self .action == "findings_severity" :
1904
- return ScanSummary .objects .all ()
1919
+ return ScanSummary .objects .all (). filter ( tenant_id = self . request . tenant_id )
1905
1920
else :
1906
1921
return super ().get_queryset ()
1907
1922
0 commit comments