Skip to content

Commit

Permalink
Refactor Sqlalchemy queries to 2.0 style (Part 3) (#32350)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Ephraim Anierobi <[email protected]>
  • Loading branch information
phanikumv and ephraimbuddy authored Jul 5, 2023
1 parent 7722b6f commit 61f3330
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 244 deletions.
38 changes: 20 additions & 18 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _format_airflow_moved_table_name(source_table, version, category):
@provide_session
def merge_conn(conn, session: Session = NEW_SESSION):
"""Add new Connection."""
if not session.query(conn.__class__).filter_by(conn_id=conn.conn_id).first():
if not session.scalar(select(conn.__class__).filter_by(conn_id=conn.conn_id).limit(1)):
session.add(conn)
session.commit()

Expand Down Expand Up @@ -959,7 +959,9 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:

dups = []
try:
dups = session.query(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1).all()
dups = session.execute(
select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1)
).all()
except (exc.OperationalError, exc.ProgrammingError):
# fallback if tables hasn't been created yet
session.rollback()
Expand All @@ -984,12 +986,11 @@ def check_username_duplicates(session: Session) -> Iterable[str]:
for model in [User, RegisterUser]:
dups = []
try:
dups = (
session.query(model.username) # type: ignore[attr-defined]
dups = session.execute(
select(model.username) # type: ignore[attr-defined]
.group_by(model.username) # type: ignore[attr-defined]
.having(func.count() > 1)
.all()
)
).all()
except (exc.OperationalError, exc.ProgrammingError):
# fallback if tables hasn't been created yet
session.rollback()
Expand Down Expand Up @@ -1058,13 +1059,13 @@ def check_task_fail_for_duplicates(session):
"""
minimal_table_obj = table(table_name, *[column(x) for x in uniqueness])
try:
subquery = (
session.query(minimal_table_obj, func.count().label("dupe_count"))
subquery = session.execute(
select(minimal_table_obj, func.count().label("dupe_count"))
.group_by(*[text(x) for x in uniqueness])
.having(func.count() > text("1"))
.subquery()
)
dupe_count = session.query(func.sum(subquery.c.dupe_count)).scalar()
dupe_count = session.scalar(select(func.sum(subquery.c.dupe_count)))
if not dupe_count:
# there are no duplicates; nothing to do.
return
Expand Down Expand Up @@ -1101,7 +1102,7 @@ def check_conn_type_null(session: Session) -> Iterable[str]:

n_nulls = []
try:
n_nulls = session.query(Connection.conn_id).filter(Connection.conn_type.is_(None)).all()
n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all()
except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
# fallback if tables hasn't been created yet
session.rollback()
Expand Down Expand Up @@ -1143,7 +1144,7 @@ def check_run_id_null(session: Session) -> Iterable[str]:
dagrun_table.c.run_id.is_(None),
dagrun_table.c.execution_date.is_(None),
)
invalid_dagrun_count = session.query(func.count(dagrun_table.c.id)).filter(invalid_dagrun_filter).scalar()
invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).where(invalid_dagrun_filter))
if invalid_dagrun_count > 0:
dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling")
if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names():
Expand Down Expand Up @@ -1240,7 +1241,7 @@ def _move_dangling_data_to_new_table(
pk_cols = source_table.primary_key.columns

delete = source_table.delete().where(
tuple_(*pk_cols).in_(session.query(*target_table.primary_key.columns).subquery())
tuple_(*pk_cols).in_(session.select(*target_table.primary_key.columns).subquery())
)
else:
delete = source_table.delete().where(
Expand All @@ -1262,10 +1263,11 @@ def _dangling_against_dag_run(session, source_table, dag_run):
source_table.c.dag_id == dag_run.c.dag_id,
source_table.c.execution_date == dag_run.c.execution_date,
)

return (
session.query(*[c.label(c.name) for c in source_table.c])
select(*[c.label(c.name) for c in source_table.c])
.join(dag_run, source_to_dag_run_join_cond, isouter=True)
.filter(dag_run.c.dag_id.is_(None))
.where(dag_run.c.dag_id.is_(None))
)


Expand Down Expand Up @@ -1304,10 +1306,10 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instanc
)

return (
session.query(*[c.label(c.name) for c in source_table.c])
select(*[c.label(c.name) for c in source_table.c])
.join(dag_run, dr_join_cond, isouter=True)
.join(task_instance, ti_join_cond, isouter=True)
.filter(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
.where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
)


Expand All @@ -1331,9 +1333,9 @@ def _move_duplicate_data_to_new_table(
"""
bind = session.get_bind()
dialect_name = bind.dialect.name

query = (
session.query(source_table)
.with_entities(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns])
select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns])
.select_from(source_table)
.join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness]))
)
Expand Down
11 changes: 7 additions & 4 deletions airflow/www/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from pygments.lexer import Lexer
from sqlalchemy import delete, func, types
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.sql import Select

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models import errors
Expand All @@ -53,7 +54,6 @@
from airflow.www.widgets import AirflowDateTimePickerWidget

if TYPE_CHECKING:
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.operators import ColumnOperators

Expand Down Expand Up @@ -518,18 +518,21 @@ def _get_run_ordering_expr(name: str) -> ColumnOperators:
return expr.desc()


def sorted_dag_runs(query: Query, *, ordering: Sequence[str], limit: int) -> Sequence[DagRun]:
def sorted_dag_runs(
query: Select, *, ordering: Sequence[str], limit: int, session: Session
) -> Sequence[DagRun]:
"""Produce DAG runs sorted by specified columns.
:param query: An ORM query object against *DagRun*.
:param query: An ORM select object against *DagRun*.
:param ordering: Column names to sort the runs. should generally come from a
timetable's ``run_ordering``.
:param limit: Number of runs to limit to.
:param session: SQLAlchemy ORM session object
:return: A list of DagRun objects ordered by the specified columns. The list
contains only the *last* objects, but in *ascending* order.
"""
ordering_exprs = (_get_run_ordering_expr(name) for name in ordering)
runs = query.order_by(*ordering_exprs, DagRun.id.desc()).limit(limit).all()
runs = session.scalars(query.order_by(*ordering_exprs, DagRun.id.desc()).limit(limit)).all()
runs.reverse()
return runs

Expand Down
Loading

0 comments on commit 61f3330

Please sign in to comment.