Skip to content

Commit 0e68629

Browse files
committed
Check and drop if stale connection when saving task result in db
1 parent b15adf4 commit 0e68629

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

django_celery_results/backends/database.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from celery.result import GroupResult, allow_join_result, result_from_tuple
88
from celery.utils.log import get_logger
99
from celery.utils.serialization import b64decode, b64encode
10-
from django.db import connection, router, transaction
10+
from django.db import connection, connections, router, transaction
1111
from django.db.models.functions import Now
1212
from django.db.utils import InterfaceError
1313
from kombu.exceptions import DecodeError
@@ -120,6 +120,17 @@ def _store_result(
120120
using=None
121121
):
122122
"""Store return value and status of an executed task."""
123+
124+
# If a task has been running long, it may have exceeded
125+
# the max db age and/or the database connection
126+
# may have been ended due to being idle for too long.
127+
# As a safety, before we submit the result,
128+
# we ensure it still has a valid connection, just like
129+
# Django does after a request to ensure a
130+
# clean connection for the next request.
131+
(connections[self.TaskModel._default_manager.db]
132+
.close_if_unusable_or_obsolete())
133+
123134
content_type, content_encoding, result = self.encode_content(result)
124135

125136
meta = {
@@ -147,7 +158,6 @@ def _store_result(
147158

148159
if status == states.STARTED:
149160
task_props['date_started'] = Now()
150-
151161
self.TaskModel._default_manager.store_result(**task_props)
152162
return result
153163

t/proj/settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
'PASSWORD': os.getenv('DB_POSTGRES_PASSWORD', 'postgres'),
3636
'OPTIONS': {
3737
'connect_timeout': 1000,
38+
3839
},
40+
'CONN_MAX_AGE': None,
3941
},
4042
'secondary': {
4143
'ENGINE': 'django.db.backends.postgresql',
@@ -50,6 +52,7 @@
5052
'TEST': {
5153
'MIRROR': 'default',
5254
},
55+
'CONN_MAX_AGE': None,
5356
},
5457
'read-only': {
5558
'ENGINE': 'django.db.backends.postgresql',
@@ -65,6 +68,7 @@
6568
'TEST': {
6669
'MIRROR': 'default',
6770
},
71+
'CONN_MAX_AGE': None,
6872
},
6973
}
7074
except ImportError:

t/unit/backends/test_database.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import time
23
import json
34
import pickle
45
import re
@@ -16,6 +17,7 @@
1617

1718
from django_celery_results.backends.database import DatabaseBackend
1819
from django_celery_results.models import ChordCounter, TaskResult
20+
from django.db import connections
1921

2022

2123
class SomeClass:
@@ -24,7 +26,7 @@ def __init__(self, data):
2426
self.data = data
2527

2628

27-
@pytest.mark.django_db()
29+
@pytest.mark.django_db(transaction=True)
2830
@pytest.mark.usefixtures('depends_on_current_app')
2931
class test_DatabaseBackend:
3032

@@ -550,6 +552,25 @@ def test_backend__task_result_meta_injection(self):
550552
tr = TaskResult.objects.get(task_id=tid2)
551553
assert json.loads(tr.meta) == {'key': 'value', 'children': []}
552554

555+
def test_backend__task_result_closes_stale_connection(self):
556+
tid = uuid()
557+
request = self._create_request(
558+
task_id=tid,
559+
name='my_task',
560+
args=[],
561+
kwargs={},
562+
task_protocol=1,
563+
)
564+
# simulate a stale connection by setting the close time
565+
# to the current time
566+
db_conn_wrapper = connections[self.b.TaskModel.objects.db]
567+
db_conn_wrapper.close_at = time.monotonic()
568+
current_db_connection = db_conn_wrapper.connection
569+
self.b.mark_as_done(tid, None, request=request)
570+
# Validate the connection was replaced in the process
571+
# of saving the task
572+
assert current_db_connection is not db_conn_wrapper.connection
573+
553574
def test_backend__task_result_date(self):
554575
tid2 = uuid()
555576

0 commit comments

Comments
 (0)