|
39 | 39 | PrimaryKeyConstraint,
|
40 | 40 | String,
|
41 | 41 | delete,
|
| 42 | + select, |
42 | 43 | text,
|
43 | 44 | )
|
44 | 45 | from sqlalchemy.ext.associationproxy import association_proxy
|
|
71 | 72 |
|
72 | 73 | import pendulum
|
73 | 74 | from sqlalchemy.orm import Session
|
| 75 | + from sqlalchemy.sql import Select |
74 | 76 |
|
75 | 77 | from airflow.models.taskinstancekey import TaskInstanceKey
|
76 | 78 |
|
@@ -210,15 +212,15 @@ def set(
|
210 | 212 | message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead."
|
211 | 213 | warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
|
212 | 214 | try:
|
213 |
| - dag_run_id, run_id = ( |
214 |
| - session.query(DagRun.id, DagRun.run_id) |
215 |
| - .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date) |
216 |
| - .one() |
217 |
| - ) |
| 215 | + dag_run_id, run_id = session.execute( |
| 216 | + select(DagRun.id, DagRun.run_id).where( |
| 217 | + DagRun.dag_id == dag_id, DagRun.execution_date == execution_date |
| 218 | + ) |
| 219 | + ).one() |
218 | 220 | except NoResultFound:
|
219 | 221 | raise ValueError(f"DAG run not found on DAG {dag_id!r} at {execution_date}") from None
|
220 | 222 | else:
|
221 |
| - dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar() |
| 223 | + dag_run_id = session.scalar(select(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id)) |
222 | 224 | if dag_run_id is None:
|
223 | 225 | raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")
|
224 | 226 |
|
@@ -389,40 +391,99 @@ def get_one(
|
389 | 391 | raise ValueError("Exactly one of run_id or execution_date must be passed")
|
390 | 392 |
|
391 | 393 | if run_id:
|
392 |
| - query = BaseXCom.get_many( |
| 394 | + stmt = BaseXCom._get_many_statement( |
393 | 395 | run_id=run_id,
|
394 | 396 | key=key,
|
395 | 397 | task_ids=task_id,
|
396 | 398 | dag_ids=dag_id,
|
397 | 399 | map_indexes=map_index,
|
398 | 400 | include_prior_dates=include_prior_dates,
|
399 | 401 | limit=1,
|
400 |
| - session=session, |
401 | 402 | )
|
402 | 403 | elif execution_date is not None:
|
403 | 404 | message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead."
|
404 | 405 | warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
|
405 | 406 |
|
406 | 407 | with warnings.catch_warnings():
|
407 | 408 | warnings.simplefilter("ignore", RemovedInAirflow3Warning)
|
408 |
| - query = BaseXCom.get_many( |
| 409 | + stmt = BaseXCom._get_many_statement( |
409 | 410 | execution_date=execution_date,
|
410 | 411 | key=key,
|
411 | 412 | task_ids=task_id,
|
412 | 413 | dag_ids=dag_id,
|
413 | 414 | map_indexes=map_index,
|
414 | 415 | include_prior_dates=include_prior_dates,
|
415 | 416 | limit=1,
|
416 |
| - session=session, |
417 | 417 | )
|
418 | 418 | else:
|
419 | 419 | raise RuntimeError("Should not happen?")
|
420 | 420 |
|
421 |
| - result = query.with_entities(BaseXCom.value).first() |
| 421 | + result = session.execute(stmt.with_only_columns(BaseXCom.value)).first() |
422 | 422 | if result:
|
423 | 423 | return XCom.deserialize_value(result)
|
424 | 424 | return None
|
425 | 425 |
|
| 426 | + @staticmethod |
| 427 | + def _get_many_statement( |
| 428 | + execution_date: datetime.datetime | None = None, |
| 429 | + key: str | None = None, |
| 430 | + task_ids: str | Iterable[str] | None = None, |
| 431 | + dag_ids: str | Iterable[str] | None = None, |
| 432 | + map_indexes: int | Iterable[int] | None = None, |
| 433 | + include_prior_dates: bool = False, |
| 434 | + limit: int | None = None, |
| 435 | + *, |
| 436 | + run_id: str | None = None, |
| 437 | + ) -> Select: |
| 438 | + from airflow.models.dagrun import DagRun |
| 439 | + |
| 440 | + if not exactly_one(execution_date is not None, run_id is not None): |
| 441 | + raise ValueError( |
| 442 | + f"Exactly one of run_id or execution_date must be passed. " |
| 443 | + f"Passed execution_date={execution_date}, run_id={run_id}" |
| 444 | + ) |
| 445 | + if execution_date is not None: |
| 446 | + message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead." |
| 447 | + warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) |
| 448 | + |
| 449 | + stmt = select(BaseXCom).join(BaseXCom.dag_run) |
| 450 | + |
| 451 | + if key: |
| 452 | + stmt = stmt.where(BaseXCom.key == key) |
| 453 | + |
| 454 | + if is_container(task_ids): |
| 455 | + stmt = stmt.where(BaseXCom.task_id.in_(task_ids)) |
| 456 | + elif task_ids is not None: |
| 457 | + stmt = stmt.where(BaseXCom.task_id == task_ids) |
| 458 | + |
| 459 | + if is_container(dag_ids): |
| 460 | + stmt = stmt.where(BaseXCom.dag_id.in_(dag_ids)) |
| 461 | + elif dag_ids is not None: |
| 462 | + stmt = stmt.where(BaseXCom.dag_id == dag_ids) |
| 463 | + |
| 464 | + if isinstance(map_indexes, range) and map_indexes.step == 1: |
| 465 | + stmt = stmt.where(BaseXCom.map_index >= map_indexes.start, BaseXCom.map_index < map_indexes.stop) |
| 466 | + elif is_container(map_indexes): |
| 467 | + stmt = stmt.where(BaseXCom.map_index.in_(map_indexes)) |
| 468 | + elif map_indexes is not None: |
| 469 | + stmt = stmt.where(BaseXCom.map_index == map_indexes) |
| 470 | + |
| 471 | + if include_prior_dates: |
| 472 | + if execution_date is not None: |
| 473 | + stmt = stmt.where(DagRun.execution_date <= execution_date) |
| 474 | + else: |
| 475 | + dr = select(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery() |
| 476 | + stmt = stmt.where(BaseXCom.execution_date <= dr.c.execution_date) |
| 477 | + elif execution_date is not None: |
| 478 | + stmt = stmt.where(DagRun.execution_date == execution_date) |
| 479 | + else: |
| 480 | + stmt = stmt.where(BaseXCom.run_id == run_id) |
| 481 | + |
| 482 | + stmt = stmt.order_by(DagRun.execution_date.desc(), BaseXCom.timestamp.desc()) |
| 483 | + if limit: |
| 484 | + return stmt.limit(limit) |
| 485 | + return stmt |
| 486 | + |
426 | 487 | @overload
|
427 | 488 | @staticmethod
|
428 | 489 | def get_many(
|
@@ -498,56 +559,17 @@ def get_many(
|
498 | 559 |
|
499 | 560 | :sphinx-autoapi-skip:
|
500 | 561 | """
|
501 |
| - from airflow.models.dagrun import DagRun |
502 |
| - |
503 |
| - if not exactly_one(execution_date is not None, run_id is not None): |
504 |
| - raise ValueError( |
505 |
| - f"Exactly one of run_id or execution_date must be passed. " |
506 |
| - f"Passed execution_date={execution_date}, run_id={run_id}" |
507 |
| - ) |
508 |
| - if execution_date is not None: |
509 |
| - message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead." |
510 |
| - warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3) |
511 |
| - |
512 |
| - query = session.query(BaseXCom).join(BaseXCom.dag_run) |
513 |
| - |
514 |
| - if key: |
515 |
| - query = query.filter(BaseXCom.key == key) |
516 |
| - |
517 |
| - if is_container(task_ids): |
518 |
| - query = query.filter(BaseXCom.task_id.in_(task_ids)) |
519 |
| - elif task_ids is not None: |
520 |
| - query = query.filter(BaseXCom.task_id == task_ids) |
521 |
| - |
522 |
| - if is_container(dag_ids): |
523 |
| - query = query.filter(BaseXCom.dag_id.in_(dag_ids)) |
524 |
| - elif dag_ids is not None: |
525 |
| - query = query.filter(BaseXCom.dag_id == dag_ids) |
526 |
| - |
527 |
| - if isinstance(map_indexes, range) and map_indexes.step == 1: |
528 |
| - query = query.filter( |
529 |
| - BaseXCom.map_index >= map_indexes.start, BaseXCom.map_index < map_indexes.stop |
530 |
| - ) |
531 |
| - elif is_container(map_indexes): |
532 |
| - query = query.filter(BaseXCom.map_index.in_(map_indexes)) |
533 |
| - elif map_indexes is not None: |
534 |
| - query = query.filter(BaseXCom.map_index == map_indexes) |
535 |
| - |
536 |
| - if include_prior_dates: |
537 |
| - if execution_date is not None: |
538 |
| - query = query.filter(DagRun.execution_date <= execution_date) |
539 |
| - else: |
540 |
| - dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery() |
541 |
| - query = query.filter(BaseXCom.execution_date <= dr.c.execution_date) |
542 |
| - elif execution_date is not None: |
543 |
| - query = query.filter(DagRun.execution_date == execution_date) |
544 |
| - else: |
545 |
| - query = query.filter(BaseXCom.run_id == run_id) |
546 |
| - |
547 |
| - query = query.order_by(DagRun.execution_date.desc(), BaseXCom.timestamp.desc()) |
548 |
| - if limit: |
549 |
| - return query.limit(limit) |
550 |
| - return query |
| 562 | + stmt = BaseXCom._get_many_statement( |
| 563 | + execution_date=execution_date, |
| 564 | + key=key, |
| 565 | + task_ids=task_ids, |
| 566 | + dag_ids=dag_ids, |
| 567 | + map_indexes=map_indexes, |
| 568 | + include_prior_dates=include_prior_dates, |
| 569 | + limit=limit, |
| 570 | + run_id=run_id, |
| 571 | + ) |
| 572 | + return session.scalars(stmt) |
551 | 573 |
|
552 | 574 | @classmethod
|
553 | 575 | @provide_session
|
@@ -640,17 +662,15 @@ def clear(
|
640 | 662 | if execution_date is not None:
|
641 | 663 | message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead."
|
642 | 664 | warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
|
643 |
| - run_id = ( |
644 |
| - session.query(DagRun.run_id) |
645 |
| - .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date) |
646 |
| - .scalar() |
| 665 | + run_id = session.scalar( |
| 666 | + select(DagRun.run_id).where(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date) |
647 | 667 | )
|
648 | 668 |
|
649 |
| - query = session.query(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id) |
| 669 | + stmt = select(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id) |
650 | 670 | if map_index is not None:
|
651 |
| - query = query.filter_by(map_index=map_index) |
| 671 | + stmt = stmt.filter_by(map_index=map_index) |
652 | 672 |
|
653 |
| - for xcom in query: |
| 673 | + for xcom in session.scalars(stmt): |
654 | 674 | # print(f"Clearing XCOM {xcom} with value {xcom.value}")
|
655 | 675 | XCom.purge(xcom, session)
|
656 | 676 | session.delete(xcom)
|
|
0 commit comments