diff --git a/api/environments/identities/traits/views.py b/api/environments/identities/traits/views.py index 201512643e65..a295c260ac1d 100644 --- a/api/environments/identities/traits/views.py +++ b/api/environments/identities/traits/views.py @@ -1,3 +1,5 @@ +from typing import Any, Set + from django.conf import settings from django.core.exceptions import BadRequest from django.db.models import Q @@ -6,6 +8,7 @@ from rest_framework import mixins, status, viewsets from rest_framework.decorators import action from rest_framework.permissions import IsAuthenticated +from rest_framework.request import Request from rest_framework.response import Response from edge_api.identities.edge_request_forwarder import ( @@ -238,32 +241,47 @@ def increment_value(self, request): return Response(serializer.data, status=200) - @swagger_auto_schema(request_body=SDKCreateUpdateTraitSerializer(many=True)) - @action(detail=False, methods=["PUT"], url_path="bulk") - def bulk_create(self, request): - try: - if not request.environment.trait_persistence_allowed(request): - raise BadRequest("Unable to set traits with client key.") + def _update_traits(self, request: Request) -> Set[str]: + + identities = {trait["identity"]["identifier"] for trait in request.data} + + existing_traits = Trait.objects.filter( + identity__identifier__in=identities, + identity__environment=request.environment, + ) + + # Map to easily access existing traits + existing_traits_map = { + (trait.identity.identifier, trait.trait_key): trait + for trait in existing_traits + } + + updated_traits = [] + delete_filter_query = Q() + + for trait in request.data: + trait_key = trait.get("trait_key") + identifier = trait["identity"]["identifier"] # endpoint allows users to delete existing traits by sending null values # for the trait value so we need to filter those out here - traits = [] - delete_filter_query = Q() - - for trait in request.data: - if trait.get("trait_value") is None: - delete_filter_query = delete_filter_query | Q( - trait_key=trait.get("trait_key"), - identity__identifier=trait["identity"]["identifier"], - identity__environment=request.environment, - ) - else: - traits.append(trait) + if trait.get("trait_value") is None: + delete_filter_query = delete_filter_query | Q( + trait_key=trait_key, + identity__identifier=identifier, + identity__environment=request.environment, + ) + continue - if delete_filter_query: - Trait.objects.filter(delete_filter_query).delete() + existing_trait = existing_traits_map.get((identifier, trait_key)) + if not existing_trait or existing_trait.trait_value != trait["trait_value"]: + updated_traits.append(trait) - serializer = self.get_serializer(data=traits, many=True) + if delete_filter_query: + Trait.objects.filter(delete_filter_query).delete() + + if len(updated_traits) > 0: + serializer = self.get_serializer(data=updated_traits, many=True) serializer.is_valid(raise_exception=True) serializer.save() @@ -273,11 +291,29 @@ def bulk_create(self, request): request.method, dict(request.headers), request.environment.project.id, - request.data, + updated_traits, ) ) - return Response(serializer.data, status=200) + return identities + + @swagger_auto_schema(request_body=SDKCreateUpdateTraitSerializer(many=True)) + @action(detail=False, methods=["PUT"], url_path="bulk") + def bulk_create(self, request): + try: + if not request.environment.trait_persistence_allowed(request): + raise BadRequest("Unable to set traits with client key.") + + identities = self._update_traits(request) + + all_traits = Trait.objects.filter( + identity__identifier__in=identities, + identity__environment=request.environment, + ) + + return Response( + SDKCreateUpdateTraitSerializer(instance=all_traits, many=True).data, + ) except (TypeError, AttributeError) as excinfo: logger.error("Invalid request data: %s" % str(excinfo))