Skip to content

Commit d19542f

Browse files
committed
Migrate Xcom model queries to sqlalchemy v2
1 parent bca4ac3 commit d19542f

File tree

4 files changed

+117
-118
lines changed

4 files changed

+117
-118
lines changed

airflow/models/taskinstance.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -3120,19 +3120,13 @@ def xcom_pull(
31203120
session=session,
31213121
)
31223122

3123-
# NOTE: Since we're only fetching the value field and not the whole
3124-
# class, the @recreate annotation does not kick in. Therefore we need to
3125-
# call XCom.deserialize_value() manually.
3126-
31273123
# We are only pulling one single task.
31283124
if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable):
3129-
first = query.with_entities(
3130-
XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value
3131-
).first()
3125+
first = query.one_or_none()
31323126
if first is None: # No matching XCom at all.
31333127
return default
31343128
if map_indexes is not None or first.map_index < 0:
3135-
return XCom.deserialize_value(first)
3129+
return first.value
31363130
query = query.order_by(None).order_by(XCom.map_index.asc())
31373131
return LazyXComAccess.build_from_xcom_query(query)
31383132

airflow/models/xcom.py

+88-68
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
PrimaryKeyConstraint,
4040
String,
4141
delete,
42+
select,
4243
text,
4344
)
4445
from sqlalchemy.ext.associationproxy import association_proxy
@@ -71,6 +72,7 @@
7172

7273
import pendulum
7374
from sqlalchemy.orm import Session
75+
from sqlalchemy.sql import Select
7476

7577
from airflow.models.taskinstancekey import TaskInstanceKey
7678

@@ -210,15 +212,15 @@ def set(
210212
message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead."
211213
warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
212214
try:
213-
dag_run_id, run_id = (
214-
session.query(DagRun.id, DagRun.run_id)
215-
.filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
216-
.one()
217-
)
215+
dag_run_id, run_id = session.execute(
216+
select(DagRun.id, DagRun.run_id).where(
217+
DagRun.dag_id == dag_id, DagRun.execution_date == execution_date
218+
)
219+
).one()
218220
except NoResultFound:
219221
raise ValueError(f"DAG run not found on DAG {dag_id!r} at {execution_date}") from None
220222
else:
221-
dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
223+
dag_run_id = session.scalar(select(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id))
222224
if dag_run_id is None:
223225
raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")
224226

@@ -389,40 +391,99 @@ def get_one(
389391
raise ValueError("Exactly one of run_id or execution_date must be passed")
390392

391393
if run_id:
392-
query = BaseXCom.get_many(
394+
stmt = BaseXCom._get_many_statement(
393395
run_id=run_id,
394396
key=key,
395397
task_ids=task_id,
396398
dag_ids=dag_id,
397399
map_indexes=map_index,
398400
include_prior_dates=include_prior_dates,
399401
limit=1,
400-
session=session,
401402
)
402403
elif execution_date is not None:
403404
message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead."
404405
warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
405406

406407
with warnings.catch_warnings():
407408
warnings.simplefilter("ignore", RemovedInAirflow3Warning)
408-
query = BaseXCom.get_many(
409+
stmt = BaseXCom._get_many_statement(
409410
execution_date=execution_date,
410411
key=key,
411412
task_ids=task_id,
412413
dag_ids=dag_id,
413414
map_indexes=map_index,
414415
include_prior_dates=include_prior_dates,
415416
limit=1,
416-
session=session,
417417
)
418418
else:
419419
raise RuntimeError("Should not happen?")
420420

421-
result = query.with_entities(BaseXCom.value).first()
421+
result = session.execute(stmt.with_only_columns(BaseXCom.value)).first()
422422
if result:
423423
return XCom.deserialize_value(result)
424424
return None
425425

426+
@staticmethod
427+
def _get_many_statement(
428+
execution_date: datetime.datetime | None = None,
429+
key: str | None = None,
430+
task_ids: str | Iterable[str] | None = None,
431+
dag_ids: str | Iterable[str] | None = None,
432+
map_indexes: int | Iterable[int] | None = None,
433+
include_prior_dates: bool = False,
434+
limit: int | None = None,
435+
*,
436+
run_id: str | None = None,
437+
) -> Select:
438+
from airflow.models.dagrun import DagRun
439+
440+
if not exactly_one(execution_date is not None, run_id is not None):
441+
raise ValueError(
442+
f"Exactly one of run_id or execution_date must be passed. "
443+
f"Passed execution_date={execution_date}, run_id={run_id}"
444+
)
445+
if execution_date is not None:
446+
message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead."
447+
warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
448+
449+
stmt = select(BaseXCom).join(BaseXCom.dag_run)
450+
451+
if key:
452+
stmt = stmt.where(BaseXCom.key == key)
453+
454+
if is_container(task_ids):
455+
stmt = stmt.where(BaseXCom.task_id.in_(task_ids))
456+
elif task_ids is not None:
457+
stmt = stmt.where(BaseXCom.task_id == task_ids)
458+
459+
if is_container(dag_ids):
460+
stmt = stmt.where(BaseXCom.dag_id.in_(dag_ids))
461+
elif dag_ids is not None:
462+
stmt = stmt.where(BaseXCom.dag_id == dag_ids)
463+
464+
if isinstance(map_indexes, range) and map_indexes.step == 1:
465+
stmt = stmt.where(BaseXCom.map_index >= map_indexes.start, BaseXCom.map_index < map_indexes.stop)
466+
elif is_container(map_indexes):
467+
stmt = stmt.where(BaseXCom.map_index.in_(map_indexes))
468+
elif map_indexes is not None:
469+
stmt = stmt.where(BaseXCom.map_index == map_indexes)
470+
471+
if include_prior_dates:
472+
if execution_date is not None:
473+
stmt = stmt.where(DagRun.execution_date <= execution_date)
474+
else:
475+
dr = select(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
476+
stmt = stmt.where(BaseXCom.execution_date <= dr.c.execution_date)
477+
elif execution_date is not None:
478+
stmt = stmt.where(DagRun.execution_date == execution_date)
479+
else:
480+
stmt = stmt.where(BaseXCom.run_id == run_id)
481+
482+
stmt = stmt.order_by(DagRun.execution_date.desc(), BaseXCom.timestamp.desc())
483+
if limit:
484+
return stmt.limit(limit)
485+
return stmt
486+
426487
@overload
427488
@staticmethod
428489
def get_many(
@@ -498,56 +559,17 @@ def get_many(
498559
499560
:sphinx-autoapi-skip:
500561
"""
501-
from airflow.models.dagrun import DagRun
502-
503-
if not exactly_one(execution_date is not None, run_id is not None):
504-
raise ValueError(
505-
f"Exactly one of run_id or execution_date must be passed. "
506-
f"Passed execution_date={execution_date}, run_id={run_id}"
507-
)
508-
if execution_date is not None:
509-
message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead."
510-
warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
511-
512-
query = session.query(BaseXCom).join(BaseXCom.dag_run)
513-
514-
if key:
515-
query = query.filter(BaseXCom.key == key)
516-
517-
if is_container(task_ids):
518-
query = query.filter(BaseXCom.task_id.in_(task_ids))
519-
elif task_ids is not None:
520-
query = query.filter(BaseXCom.task_id == task_ids)
521-
522-
if is_container(dag_ids):
523-
query = query.filter(BaseXCom.dag_id.in_(dag_ids))
524-
elif dag_ids is not None:
525-
query = query.filter(BaseXCom.dag_id == dag_ids)
526-
527-
if isinstance(map_indexes, range) and map_indexes.step == 1:
528-
query = query.filter(
529-
BaseXCom.map_index >= map_indexes.start, BaseXCom.map_index < map_indexes.stop
530-
)
531-
elif is_container(map_indexes):
532-
query = query.filter(BaseXCom.map_index.in_(map_indexes))
533-
elif map_indexes is not None:
534-
query = query.filter(BaseXCom.map_index == map_indexes)
535-
536-
if include_prior_dates:
537-
if execution_date is not None:
538-
query = query.filter(DagRun.execution_date <= execution_date)
539-
else:
540-
dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
541-
query = query.filter(BaseXCom.execution_date <= dr.c.execution_date)
542-
elif execution_date is not None:
543-
query = query.filter(DagRun.execution_date == execution_date)
544-
else:
545-
query = query.filter(BaseXCom.run_id == run_id)
546-
547-
query = query.order_by(DagRun.execution_date.desc(), BaseXCom.timestamp.desc())
548-
if limit:
549-
return query.limit(limit)
550-
return query
562+
stmt = BaseXCom._get_many_statement(
563+
execution_date=execution_date,
564+
key=key,
565+
task_ids=task_ids,
566+
dag_ids=dag_ids,
567+
map_indexes=map_indexes,
568+
include_prior_dates=include_prior_dates,
569+
limit=limit,
570+
run_id=run_id,
571+
)
572+
return session.scalars(stmt)
551573

552574
@classmethod
553575
@provide_session
@@ -640,17 +662,15 @@ def clear(
640662
if execution_date is not None:
641663
message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead."
642664
warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
643-
run_id = (
644-
session.query(DagRun.run_id)
645-
.filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
646-
.scalar()
665+
run_id = session.scalar(
666+
select(DagRun.run_id).where(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
647667
)
648668

649-
query = session.query(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
669+
stmt = select(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
650670
if map_index is not None:
651-
query = query.filter_by(map_index=map_index)
671+
stmt = stmt.filter_by(map_index=map_index)
652672

653-
for xcom in query:
673+
for xcom in session.scalars(stmt):
654674
# print(f"Clearing XCOM {xcom} with value {xcom.value}")
655675
XCom.purge(xcom, session)
656676
session.delete(xcom)

tests/models/test_xcom.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_fro
468468
task_ids="task_1",
469469
include_prior_dates=True,
470470
session=session,
471-
)
471+
).all()
472472

473473
# The retrieved XComs should be ordered by logical date, latest first.
474474
assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}]
@@ -488,7 +488,7 @@ def test_xcom_get_many_from_prior_dates_with_execution_date(
488488
task_ids="task_1",
489489
include_prior_dates=True,
490490
session=session,
491-
)
491+
).all()
492492

493493
# The retrieved XComs should be ordered by logical date, latest first.
494494
assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}]

tests/providers/common/io/xcom/test_backend.py

+25-40
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from airflow.io.path import ObjectStoragePath
2929
from airflow.models.dagrun import DagRun
3030
from airflow.models.taskinstance import TaskInstance
31-
from airflow.models.xcom import BaseXCom, resolve_xcom_backend
31+
from airflow.models.xcom import resolve_xcom_backend
3232
from airflow.operators.empty import EmptyOperator
3333
from airflow.providers.common.io.xcom.backend import XComObjectStoreBackend
3434
from airflow.utils import timezone
@@ -151,20 +151,15 @@ def test_value_storage(self, task_instance, session):
151151
session=session,
152152
)
153153

154-
res = (
155-
XCom.get_many(
156-
key=XCOM_RETURN_KEY,
157-
dag_ids=task_instance.dag_id,
158-
task_ids=task_instance.task_id,
159-
run_id=task_instance.run_id,
160-
session=session,
161-
)
162-
.with_entities(BaseXCom.value)
163-
.first()
164-
)
154+
res = XCom.get_many(
155+
key=XCOM_RETURN_KEY,
156+
dag_ids=task_instance.dag_id,
157+
task_ids=task_instance.task_id,
158+
run_id=task_instance.run_id,
159+
session=session,
160+
).first()
165161

166-
data = BaseXCom.deserialize_value(res)
167-
p = ObjectStoragePath(self.path) / XComObjectStoreBackend._get_key(data)
162+
p = ObjectStoragePath(self.path) / XComObjectStoreBackend._get_key(res.value)
168163
assert p.exists() is True
169164

170165
value = XCom.get_value(
@@ -197,20 +192,15 @@ def test_clear(self, task_instance, session):
197192
session=session,
198193
)
199194

200-
res = (
201-
XCom.get_many(
202-
key=XCOM_RETURN_KEY,
203-
dag_ids=task_instance.dag_id,
204-
task_ids=task_instance.task_id,
205-
run_id=task_instance.run_id,
206-
session=session,
207-
)
208-
.with_entities(BaseXCom.value)
209-
.first()
210-
)
195+
res = XCom.get_many(
196+
key=XCOM_RETURN_KEY,
197+
dag_ids=task_instance.dag_id,
198+
task_ids=task_instance.task_id,
199+
run_id=task_instance.run_id,
200+
session=session,
201+
).first()
211202

212-
data = BaseXCom.deserialize_value(res)
213-
p = ObjectStoragePath(self.path) / XComObjectStoreBackend._get_key(data)
203+
p = ObjectStoragePath(self.path) / XComObjectStoreBackend._get_key(res.value)
214204
assert p.exists() is True
215205

216206
XCom.clear(
@@ -237,20 +227,15 @@ def test_compression(self, task_instance, session):
237227
session=session,
238228
)
239229

240-
res = (
241-
XCom.get_many(
242-
key=XCOM_RETURN_KEY,
243-
dag_ids=task_instance.dag_id,
244-
task_ids=task_instance.task_id,
245-
run_id=task_instance.run_id,
246-
session=session,
247-
)
248-
.with_entities(BaseXCom.value)
249-
.first()
250-
)
230+
res = XCom.get_many(
231+
key=XCOM_RETURN_KEY,
232+
dag_ids=task_instance.dag_id,
233+
task_ids=task_instance.task_id,
234+
run_id=task_instance.run_id,
235+
session=session,
236+
).first()
251237

252-
data = BaseXCom.deserialize_value(res)
253-
p = ObjectStoragePath(self.path) / XComObjectStoreBackend._get_key(data)
238+
p = ObjectStoragePath(self.path) / XComObjectStoreBackend._get_key(res.value)
254239
assert p.exists() is True
255240
assert p.suffix == ".gz"
256241

0 commit comments

Comments
 (0)