Skip to content

Commit

Permalink
chore: add TODOs
Browse files Browse the repository at this point in the history
  • Loading branch information
jfagoagas committed Dec 16, 2024
1 parent 5942978 commit 15f5054
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
1 change: 1 addition & 0 deletions api/src/backend/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ class Resource(RowLevelSecurityProtectedModel):
)

def get_tags(self) -> dict:
# TODO: can we filter by tenant_id here?
return {tag.key: tag.value for tag in self.tags.all()}

def clear_tags(self):
Expand Down
15 changes: 9 additions & 6 deletions api/src/backend/api/v1/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,24 @@
from rest_framework_simplejwt.tokens import RefreshToken

from api.models import (
ComplianceOverview,
Finding,
Invitation,
InvitationRoleRelationship,
Membership,
Provider,
ProviderGroup,
ProviderGroupMembership,
ProviderSecret,
Resource,
ResourceTag,
Finding,
ProviderSecret,
Invitation,
InvitationRoleRelationship,
Role,
RoleProviderGroupRelationship,
UserRoleRelationship,
ComplianceOverview,
Scan,
StateChoices,
Task,
User,
UserRoleRelationship,
)
from api.rls import Tenant

Expand Down Expand Up @@ -1154,6 +1154,7 @@ class InvitationSerializer(RLSSerializer):
Serializer for the Invitation model.
"""

# TODO: can we filter by tenant_id here?
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())

class Meta:
Expand All @@ -1173,6 +1174,7 @@ class Meta:


class InvitationBaseWriteSerializer(BaseWriteSerializer):
# TODO: can we filter by tenant_id here?
roles = serializers.ResourceRelatedField(many=True, queryset=Role.objects.all())

def validate_email(self, value):
Expand Down Expand Up @@ -1274,6 +1276,7 @@ class Meta:


class RoleSerializer(RLSSerializer, BaseWriteSerializer):
# TODO: can we filter by tenant_id here?
provider_groups = serializers.ResourceRelatedField(
many=True, queryset=ProviderGroup.objects.all()
)
Expand Down
26 changes: 12 additions & 14 deletions api/src/backend/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,7 +1273,7 @@ def get_queryset(self):
tenant_id = self.request.tenant_id
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
# User has unlimited visibility, return all scans
queryset = Resource.objects.all().filter(tenant_id=tenant_id)
queryset = Resource.objects.filter(tenant_id=tenant_id)
else:
# User lacks permission, filter providers based on provider groups associated with the role
provider_groups = user_roles[0].provider_groups.all()
Expand Down Expand Up @@ -1377,7 +1377,7 @@ def get_queryset(self):
tenant_id = self.request.tenant_id
if getattr(user_roles[0], Permissions.UNLIMITED_VISIBILITY.value, False):
# User has unlimited visibility, return all scans
queryset = Finding.objects.all().filter(tenant_id=tenant_id)
queryset = Finding.objects.filter(tenant_id=tenant_id)
else:
# User lacks permission, filter providers based on provider groups associated with the role
provider_groups = user_roles[0].provider_groups.all()
Expand Down Expand Up @@ -1489,7 +1489,7 @@ class ProviderSecretViewSet(BaseRLSViewSet):
]

def get_queryset(self):
return ProviderSecret.objects.all().filter(tenant_id=self.request.tenant_id)
return ProviderSecret.objects.filter(tenant_id=self.request.tenant_id)

def get_serializer_class(self):
if self.action == "create":
Expand Down Expand Up @@ -1548,7 +1548,7 @@ class InvitationViewSet(BaseRLSViewSet):
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]

def get_queryset(self):
return Invitation.objects.all().filter(tenant_id=self.request.tenant_id)
return Invitation.objects.filter(tenant_id=self.request.tenant_id)

def get_serializer_class(self):
if self.action == "create":
Expand Down Expand Up @@ -1595,7 +1595,7 @@ class InvitationAcceptViewSet(BaseRLSViewSet):
http_method_names = ["post"]

def get_queryset(self):
return Invitation.objects.all().filter(tenant_id=self.request.tenant_id)
return Invitation.objects.filter(tenant_id=self.request.tenant_id)

def get_serializer_class(self):
if hasattr(self, "response_serializer_class"):
Expand Down Expand Up @@ -1687,7 +1687,7 @@ class RoleViewSet(BaseRLSViewSet):
permission_classes = BaseRLSViewSet.permission_classes + [HasPermissions]

def get_queryset(self):
return Role.objects.all().filter(tenant_id=self.request.tenant_id)
return Role.objects.filter(tenant_id=self.request.tenant_id)

def get_serializer_class(self):
if self.action == "create":
Expand Down Expand Up @@ -1746,7 +1746,7 @@ class RoleProviderGroupRelationshipView(RelationshipView, BaseRLSViewSet):
schema = RelationshipViewSchema()

def get_queryset(self):
return Role.objects.all().filter(tenant_id=self.request.tenant_id)
return Role.objects.filter(tenant_id=self.request.tenant_id)

def create(self, request, *args, **kwargs):
role = self.get_object()
Expand Down Expand Up @@ -1832,12 +1832,10 @@ class ComplianceOverviewViewSet(BaseRLSViewSet):

def get_queryset(self):
if self.action == "retrieve":
return ComplianceOverview.objects.all().filter(
tenant_id=self.request.tenant_id
)
return ComplianceOverview.objects.filter(tenant_id=self.request.tenant_id)

base_queryset = self.filter_queryset(
ComplianceOverview.objects.all().filter(tenant_id=self.request.tenant_id)
ComplianceOverview.objects.filter(tenant_id=self.request.tenant_id)
)

max_failed_ids = (
Expand Down Expand Up @@ -1912,11 +1910,11 @@ class OverviewViewSet(BaseRLSViewSet):

def get_queryset(self):
if self.action == "providers":
return Finding.objects.all().filter(tenant_id=self.request.tenant_id)
return Finding.objects.filter(tenant_id=self.request.tenant_id)
elif self.action == "findings":
return ScanSummary.objects.all().filter(tenant_id=self.request.tenant_id)
return ScanSummary.objects.filter(tenant_id=self.request.tenant_id)
elif self.action == "findings_severity":
return ScanSummary.objects.all().filter(tenant_id=self.request.tenant_id)
return ScanSummary.objects.filter(tenant_id=self.request.tenant_id)
else:
return super().get_queryset()

Expand Down

0 comments on commit 15f5054

Please sign in to comment.