
Commit a575483

Revert "Refactor Sqlalchemy queries to 2.0 style (Part 3) (apache#32177)"
This reverts commit 1065687.
1 parent 1c1dbd8 commit a575483

3 files changed: +244 -265 lines


airflow/utils/db.py (+18 -20)
@@ -92,7 +92,7 @@ def _format_airflow_moved_table_name(source_table, version, category):
 @provide_session
 def merge_conn(conn, session: Session = NEW_SESSION):
     """Add new Connection."""
-    if not session.scalar(select(conn.__class__).filter_by(conn_id=conn.conn_id).limit(1)):
+    if not session.query(conn.__class__).filter_by(conn_id=conn.conn_id).first():
         session.add(conn)
         session.commit()
 
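The two spellings above are equivalent existence checks: the 2.0-style `session.scalar(select(...).limit(1))` and the 1.x-style `session.query(...).first()` both return the first matching row or `None`. A minimal sketch of the equivalence, using a hypothetical `Conn` model rather than Airflow's real `Connection`:

    # Toy model standing in for airflow.models.Connection (hypothetical).
    from sqlalchemy import Column, Integer, String, create_engine, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Conn(Base):
        __tablename__ = "conn"
        id = Column(Integer, primary_key=True)
        conn_id = Column(String, unique=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # 1.x style (restored by this revert): first matching row or None.
        row_1x = session.query(Conn).filter_by(conn_id="x").first()
        # 2.0 style (removed by this revert): same result via select().
        row_20 = session.scalar(select(Conn).filter_by(conn_id="x").limit(1))
        assert row_1x is None and row_20 is None  # table is empty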

@@ -959,9 +959,7 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:
 
     dups = []
     try:
-        dups = session.execute(
-            select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1)
-        ).all()
+        dups = session.query(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1).all()
     except (exc.OperationalError, exc.ProgrammingError):
         # fallback if tables hasn't been created yet
         session.rollback()
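
This hunk (and the `check_username_duplicates` hunk below) restores the 1.x spelling of the same GROUP BY ... HAVING COUNT(*) > 1 duplicate probe. A rough sketch of both forms against a toy table, not Airflow's actual schema:

    from sqlalchemy import Column, Integer, String, create_engine, func, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Conn(Base):  # hypothetical stand-in
        __tablename__ = "conn"
        id = Column(Integer, primary_key=True)
        conn_id = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Conn(conn_id="a"), Conn(conn_id="a"), Conn(conn_id="b")])
        session.commit()
        # 1.x style (restored): the Query executes itself with .all().
        dups_1x = session.query(Conn.conn_id).group_by(Conn.conn_id).having(func.count() > 1).all()
        # 2.0 style (removed): the statement is handed to session.execute().
        dups_20 = session.execute(
            select(Conn.conn_id).group_by(Conn.conn_id).having(func.count() > 1)
        ).all()
        assert [r.conn_id for r in dups_1x] == [r.conn_id for r in dups_20] == ["a"]
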
@@ -986,11 +984,12 @@ def check_username_duplicates(session: Session) -> Iterable[str]:
     for model in [User, RegisterUser]:
         dups = []
         try:
-            dups = session.execute(
-                select(model.username)  # type: ignore[attr-defined]
+            dups = (
+                session.query(model.username)  # type: ignore[attr-defined]
                 .group_by(model.username)  # type: ignore[attr-defined]
                 .having(func.count() > 1)
-            ).all()
+                .all()
+            )
         except (exc.OperationalError, exc.ProgrammingError):
             # fallback if tables hasn't been created yet
             session.rollback()
@@ -1059,13 +1058,13 @@ def check_task_fail_for_duplicates(session):
     """
     minimal_table_obj = table(table_name, *[column(x) for x in uniqueness])
     try:
-        subquery = session.execute(
-            select(minimal_table_obj, func.count().label("dupe_count"))
+        subquery = (
+            session.query(minimal_table_obj, func.count().label("dupe_count"))
             .group_by(*[text(x) for x in uniqueness])
             .having(func.count() > text("1"))
             .subquery()
         )
-        dupe_count = session.scalar(select(func.sum(subquery.c.dupe_count)))
+        dupe_count = session.query(func.sum(subquery.c.dupe_count)).scalar()
         if not dupe_count:
             # there are no duplicates; nothing to do.
             return
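
Both versions here build a grouped subquery of duplicate keys, then sum its `dupe_count` column into a single scalar; only the entry point differs (`session.query(...).scalar()` vs `session.scalar(select(...))`). A hedged sketch with a hypothetical table standing in for `task_fail`:

    from sqlalchemy import Column, Integer, String, create_engine, func, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class TaskFail(Base):  # hypothetical stand-in
        __tablename__ = "task_fail"
        id = Column(Integer, primary_key=True)
        task_id = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([TaskFail(task_id="t"), TaskFail(task_id="t"), TaskFail(task_id="u")])
        session.commit()
        # 1.x style (restored):
        subq = (
            session.query(TaskFail.task_id, func.count().label("dupe_count"))
            .group_by(TaskFail.task_id)
            .having(func.count() > 1)
            .subquery()
        )
        total_1x = session.query(func.sum(subq.c.dupe_count)).scalar()
        # 2.0 style (removed): same shape built from select().
        subq_20 = (
            select(TaskFail.task_id, func.count().label("dupe_count"))
            .group_by(TaskFail.task_id)
            .having(func.count() > 1)
            .subquery()
        )
        total_20 = session.scalar(select(func.sum(subq_20.c.dupe_count)))
        assert total_1x == total_20 == 2
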
@@ -1102,7 +1101,7 @@ def check_conn_type_null(session: Session) -> Iterable[str]:
 
     n_nulls = []
     try:
-        n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all()
+        n_nulls = session.query(Connection.conn_id).filter(Connection.conn_type.is_(None)).all()
     except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
         # fallback if tables hasn't been created yet
         session.rollback()
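
One behavioral nuance worth noting: `session.scalars(...).all()` returns a flat list of values, while `session.query(col).all()` returns one-element `Row` objects. The calling code only counts the results, so both work; a small sketch of the difference, again with a hypothetical model:

    from sqlalchemy import Column, Integer, String, create_engine, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Conn(Base):  # hypothetical stand-in
        __tablename__ = "conn"
        id = Column(Integer, primary_key=True)
        conn_id = Column(String)
        conn_type = Column(String, nullable=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Conn(conn_id="a", conn_type=None))
        session.commit()
        # 2.0 style (removed): scalars() flattens to plain values.
        flat = session.scalars(select(Conn.conn_id).where(Conn.conn_type.is_(None))).all()
        # 1.x style (restored): each result is a one-element Row.
        rows = session.query(Conn.conn_id).filter(Conn.conn_type.is_(None)).all()
        assert flat == ["a"]
        assert [r[0] for r in rows] == ["a"]
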
@@ -1144,7 +1143,7 @@ def check_run_id_null(session: Session) -> Iterable[str]:
         dagrun_table.c.run_id.is_(None),
         dagrun_table.c.execution_date.is_(None),
     )
-    invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).where(invalid_dagrun_filter))
+    invalid_dagrun_count = session.query(func.count(dagrun_table.c.id)).filter(invalid_dagrun_filter).scalar()
     if invalid_dagrun_count > 0:
         dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling")
         if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names():
@@ -1241,7 +1240,7 @@ def _move_dangling_data_to_new_table(
         pk_cols = source_table.primary_key.columns
 
         delete = source_table.delete().where(
-            tuple_(*pk_cols).in_(session.select(*target_table.primary_key.columns).subquery())
+            tuple_(*pk_cols).in_(session.query(*target_table.primary_key.columns).subquery())
         )
     else:
         delete = source_table.delete().where(
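
Note that the 2.0-style line being removed calls `session.select(...)`, which is not a `Session` method; the 2.0 construct is the module-level `sqlalchemy.select()`, so this branch would raise `AttributeError` when it ran. A hedged sketch of the DELETE-where-key-in-subquery pattern with toy Core tables, assuming SQLAlchemy 1.4 (the series Airflow used at the time):

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select, tuple_
    from sqlalchemy.orm import Session

    metadata = MetaData()
    source = Table("source", metadata, Column("a", Integer, primary_key=True))
    target = Table("target", metadata, Column("a", Integer, primary_key=True))
    engine = create_engine("sqlite://")
    metadata.create_all(engine)

    with Session(engine) as session:
        pk_cols = source.primary_key.columns
        # 1.x style (restored): a Query turned into a subquery.
        # (1.4 emits a deprecation warning for coercing a Subquery in in_().)
        delete_1x = source.delete().where(
            tuple_(*pk_cols).in_(session.query(*target.primary_key.columns).subquery())
        )
        # Correct 2.0 spelling: module-level select(), not session.select().
        delete_20 = source.delete().where(
            tuple_(*pk_cols).in_(select(*target.primary_key.columns))
        )
        session.execute(delete_1x)
        session.execute(delete_20)
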
@@ -1263,11 +1262,10 @@ def _dangling_against_dag_run(session, source_table, dag_run):
         source_table.c.dag_id == dag_run.c.dag_id,
         source_table.c.execution_date == dag_run.c.execution_date,
     )
-
     return (
-        select(*[c.label(c.name) for c in source_table.c])
+        session.query(*[c.label(c.name) for c in source_table.c])
         .join(dag_run, source_to_dag_run_join_cond, isouter=True)
-        .where(dag_run.c.dag_id.is_(None))
+        .filter(dag_run.c.dag_id.is_(None))
     )
 
 
@@ -1306,10 +1304,10 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
     )
 
     return (
-        select(*[c.label(c.name) for c in source_table.c])
+        session.query(*[c.label(c.name) for c in source_table.c])
         .join(dag_run, dr_join_cond, isouter=True)
         .join(task_instance, ti_join_cond, isouter=True)
-        .where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
+        .filter(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
     )
 
 
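Both `_dangling_against_*` helpers implement the same anti-join: LEFT OUTER JOIN the reference table and keep rows whose join key comes back NULL. A rough sketch of the restored 1.x form with toy Core tables (names are illustrative, not Airflow's schema):

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
    from sqlalchemy.orm import Session

    metadata = MetaData()
    src = Table(
        "src", metadata,
        Column("id", Integer, primary_key=True),
        Column("dag_id", String),
    )
    dag_run = Table(
        "dag_run", metadata,
        Column("id", Integer, primary_key=True),
        Column("dag_id", String),
    )
    engine = create_engine("sqlite://")
    metadata.create_all(engine)

    with Session(engine) as session:
        session.execute(src.insert(), [{"dag_id": "keep"}, {"dag_id": "orphan"}])
        session.execute(dag_run.insert(), [{"dag_id": "keep"}])
        session.commit()
        dangling = (
            session.query(*[c.label(c.name) for c in src.c])
            .join(dag_run, src.c.dag_id == dag_run.c.dag_id, isouter=True)
            .filter(dag_run.c.dag_id.is_(None))
        )
        assert [row.dag_id for row in dangling] == ["orphan"]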

@@ -1333,9 +1331,9 @@ def _move_duplicate_data_to_new_table(
     """
     bind = session.get_bind()
     dialect_name = bind.dialect.name
-
     query = (
-        select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns])
+        session.query(source_table)
+        .with_entities(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns])
         .select_from(source_table)
         .join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness]))
     )
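
The restored form starts from `session.query(source_table)` and then swaps the SELECT list with `.with_entities(...)`, ending up selecting the same labeled columns as the removed `select(...)`. A small sketch of what `with_entities` does, with a hypothetical model:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Task(Base):  # hypothetical
        __tablename__ = "task"
        id = Column(Integer, primary_key=True)
        name = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Task(name="a"), Task(name="b")])
        session.commit()
        q = session.query(Task)                   # SELECT task.id, task.name ...
        names = q.with_entities(Task.name).all()  # SELECT task.name only
        assert [n for (n,) in names] == ["a", "b"]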

airflow/www/utils.py (+4 -7)
@@ -39,7 +39,6 @@
 from pygments.lexer import Lexer
 from sqlalchemy import delete, func, types
 from sqlalchemy.ext.associationproxy import AssociationProxy
-from sqlalchemy.sql import Select
 
 from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.models import errors
@@ -54,6 +53,7 @@
 from airflow.www.widgets import AirflowDateTimePickerWidget
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm.query import Query
     from sqlalchemy.orm.session import Session
     from sqlalchemy.sql.operators import ColumnOperators
 
@@ -518,21 +518,18 @@ def _get_run_ordering_expr(name: str) -> ColumnOperators:
     return expr.desc()
 
 
-def sorted_dag_runs(
-    query: Select, *, ordering: Sequence[str], limit: int, session: Session
-) -> Sequence[DagRun]:
+def sorted_dag_runs(query: Query, *, ordering: Sequence[str], limit: int) -> Sequence[DagRun]:
     """Produce DAG runs sorted by specified columns.
 
-    :param query: An ORM select object against *DagRun*.
+    :param query: An ORM query object against *DagRun*.
     :param ordering: Column names to sort the runs. should generally come from a
         timetable's ``run_ordering``.
     :param limit: Number of runs to limit to.
-    :param session: SQLAlchemy ORM session object
     :return: A list of DagRun objects ordered by the specified columns. The list
         contains only the *last* objects, but in *ascending* order.
     """
     ordering_exprs = (_get_run_ordering_expr(name) for name in ordering)
-    runs = session.scalars(query.order_by(*ordering_exprs, DagRun.id.desc()).limit(limit)).all()
+    runs = query.order_by(*ordering_exprs, DagRun.id.desc()).limit(limit).all()
     runs.reverse()
     return runs
 
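The signature change is the user-visible difference between the two styles: a 1.x `Query` is bound to its session and runs itself via `.all()`, whereas a 2.0 `Select` is an inert statement that must be handed to `session.scalars(...)`, which is why the `session` parameter disappears in the restored version. A hedged sketch of the last-N-in-ascending-order pattern the function implements, with a toy model standing in for `DagRun`:

    from sqlalchemy import Column, Integer, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Run(Base):  # hypothetical stand-in for DagRun
        __tablename__ = "run"
        id = Column(Integer, primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Run() for _ in range(5)])
        session.commit()
        query = session.query(Run)  # the Query-valued argument the revert restores
        runs = query.order_by(Run.id.desc()).limit(2).all()  # newest two, descending
        runs.reverse()  # back to ascending, as the docstring promises
        assert [r.id for r in runs] == [4, 5]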
