Skip to content

Commit f7e510b

Browse files
authored
fix(db-utils): fix batch_delete function (#6283)
1 parent 4472b80 commit f7e510b

File tree

5 files changed

+149
-8
lines changed

5 files changed

+149
-8
lines changed

api/src/backend/api/db_utils.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from django.conf import settings
77
from django.contrib.auth.models import BaseUserManager
8-
from django.core.paginator import Paginator
98
from django.db import connection, models, transaction
109
from psycopg2 import connect as psycopg2_connect
1110
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
@@ -120,15 +119,18 @@ def batch_delete(queryset, batch_size=5000):
120119
total_deleted = 0
121120
deletion_summary = {}
122121

123-
paginator = Paginator(queryset.order_by("id").only("id"), batch_size)
124-
125-
for page_num in paginator.page_range:
126-
batch_ids = [obj.id for obj in paginator.page(page_num).object_list]
122+
while True:
123+
# Get a batch of IDs to delete
124+
batch_ids = set(
125+
queryset.values_list("id", flat=True).order_by("id")[:batch_size]
126+
)
127+
if not batch_ids:
128+
# No more objects to delete
129+
break
127130

128131
deleted_count, deleted_info = queryset.filter(id__in=batch_ids).delete()
129132

130133
total_deleted += deleted_count
131-
132134
for model_label, count in deleted_info.items():
133135
deletion_summary[model_label] = deletion_summary.get(model_label, 0) + count
134136

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Generated by Django 5.1.1 on 2024-12-20 13:16
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("api", "0004_rbac_missing_admin_roles"),
9+
]
10+
11+
operations = [
12+
migrations.RemoveConstraint(
13+
model_name="provider",
14+
name="unique_provider_uids",
15+
),
16+
migrations.AddConstraint(
17+
model_name="provider",
18+
constraint=models.UniqueConstraint(
19+
fields=("tenant_id", "provider", "uid", "is_deleted"),
20+
name="unique_provider_uids",
21+
),
22+
),
23+
]

api/src/backend/api/models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ class Meta(RowLevelSecurityProtectedModel.Meta):
271271

272272
constraints = [
273273
models.UniqueConstraint(
274-
fields=("tenant_id", "provider", "uid"),
274+
fields=("tenant_id", "provider", "uid", "is_deleted"),
275275
name="unique_provider_uids",
276276
),
277277
RowLevelSecurityConstraint(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from unittest.mock import Mock, patch
2+
3+
import pytest
4+
from conftest import get_api_tokens, get_authorization_header
5+
from django.urls import reverse
6+
from rest_framework.test import APIClient
7+
8+
from api.models import Provider
9+
10+
11+
@patch("api.v1.views.Task.objects.get")
12+
@patch("api.v1.views.delete_provider_task.delay")
13+
@pytest.mark.django_db
14+
def test_delete_provider_without_executing_task(
15+
mock_delete_task, mock_task_get, create_test_user, tenants_fixture, tasks_fixture
16+
):
17+
client = APIClient()
18+
19+
test_user = "test_email@prowler.com"
20+
test_password = "test_password"
21+
22+
prowler_task = tasks_fixture[0]
23+
task_mock = Mock()
24+
task_mock.id = prowler_task.id
25+
mock_delete_task.return_value = task_mock
26+
mock_task_get.return_value = prowler_task
27+
28+
user_creation_response = client.post(
29+
reverse("user-list"),
30+
data={
31+
"data": {
32+
"type": "users",
33+
"attributes": {
34+
"name": "test",
35+
"email": test_user,
36+
"password": test_password,
37+
},
38+
}
39+
},
40+
format="vnd.api+json",
41+
)
42+
assert user_creation_response.status_code == 201
43+
44+
access_token, _ = get_api_tokens(client, test_user, test_password)
45+
auth_headers = get_authorization_header(access_token)
46+
47+
create_provider_response = client.post(
48+
reverse("provider-list"),
49+
data={
50+
"data": {
51+
"type": "providers",
52+
"attributes": {
53+
"provider": Provider.ProviderChoices.AWS,
54+
"uid": "123456789012",
55+
},
56+
}
57+
},
58+
format="vnd.api+json",
59+
headers=auth_headers,
60+
)
61+
assert create_provider_response.status_code == 201
62+
provider_id = create_provider_response.json()["data"]["id"]
63+
provider_uid = create_provider_response.json()["data"]["attributes"]["uid"]
64+
65+
remove_provider = client.delete(
66+
reverse("provider-detail", kwargs={"pk": provider_id}),
67+
headers=auth_headers,
68+
)
69+
assert remove_provider.status_code == 202
70+
71+
recreate_provider_response = client.post(
72+
reverse("provider-list"),
73+
data={
74+
"data": {
75+
"type": "providers",
76+
"attributes": {
77+
"provider": Provider.ProviderChoices.AWS,
78+
"uid": provider_uid,
79+
},
80+
}
81+
},
82+
format="vnd.api+json",
83+
headers=auth_headers,
84+
)
85+
assert recreate_provider_response.status_code == 201

api/src/backend/api/tests/test_db_utils.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22
from enum import Enum
33
from unittest.mock import patch
44

5-
from api.db_utils import enum_to_choices, one_week_from_now, generate_random_token
5+
import pytest
6+
7+
from api.db_utils import (
8+
batch_delete,
9+
enum_to_choices,
10+
generate_random_token,
11+
one_week_from_now,
12+
)
13+
from api.models import Provider
614

715

816
class TestEnumToChoices:
@@ -106,3 +114,26 @@ def test_generate_random_token_no_symbols_provided(self):
106114
token = generate_random_token(length=5, symbols="")
107115
# Default symbols
108116
assert len(token) == 5
117+
118+
119+
class TestBatchDelete:
120+
@pytest.fixture
121+
def create_test_providers(self, tenants_fixture):
122+
tenant = tenants_fixture[0]
123+
provider_id = 123456789012
124+
provider_count = 10
125+
for i in range(provider_count):
126+
Provider.objects.create(
127+
tenant=tenant,
128+
uid=f"{provider_id + i}",
129+
provider=Provider.ProviderChoices.AWS,
130+
)
131+
return provider_count
132+
133+
@pytest.mark.django_db
134+
def test_batch_delete(self, create_test_providers):
135+
_, summary = batch_delete(
136+
Provider.objects.all(), batch_size=create_test_providers // 2
137+
)
138+
assert Provider.objects.all().count() == 0
139+
assert summary == {"api.Provider": create_test_providers}

0 commit comments

Comments
 (0)