Skip to content

Commit afa9ead

Browse files
authored
Update code style for airflow db commands to SQLAlchemy 2.0 style (#31486)
* Update code style for `airflow db` commands to SQLAlchemy 2.0 style This commit introduces changes to the code styles of `airflow db` commands to remove 'RemovedIn20Warning' and ensure compatibility with SQLAlchemy 2.0. To see these warnings, you need to set SQLALCHEMY_WARN_20=True when using the db commands * fixup! Update code style for `airflow db` commands to SQLAlchemy 2.0 style * fixup! fixup! Update code style for `airflow db` commands to SQLAlchemy 2.0 style * Use connection instead of session.get_bind() * remove metadata.bind=bind
1 parent e86f688 commit afa9ead

File tree

4 files changed

+40
-46
lines changed

4 files changed

+40
-46
lines changed

airflow/models/base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Any
2121

2222
from sqlalchemy import MetaData, String
23-
from sqlalchemy.orm import declarative_base
23+
from sqlalchemy.orm import registry
2424

2525
from airflow.configuration import conf
2626

@@ -45,8 +45,9 @@ def _get_schema():
4545

4646

4747
metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention)
48+
mapper_registry = registry(metadata=metadata)
4849

49-
Base: Any = declarative_base(metadata=metadata)
50+
Base: Any = mapper_registry.generate_base()
5051

5152
ID_LEN = 250
5253

airflow/utils/db.py

+35-42
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949

5050
from airflow.models.base import Base
5151

52-
5352
log = logging.getLogger(__name__)
5453

5554
REVISION_HEADS_MAP = {
@@ -686,21 +685,28 @@ def create_default_connections(session: Session = NEW_SESSION):
686685
)
687686

688687

689-
def _create_db_from_orm(session):
690-
from alembic import command
688+
def _get_flask_db(sql_database_uri):
691689
from flask import Flask
692690
from flask_sqlalchemy import SQLAlchemy
693691

692+
from airflow.www.session import AirflowDatabaseSessionInterface
693+
694+
flask_app = Flask(__name__)
695+
flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
696+
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
697+
db = SQLAlchemy(flask_app)
698+
AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
699+
return db
700+
701+
702+
def _create_db_from_orm(session):
703+
from alembic import command
704+
694705
from airflow.models.base import Base
695706
from airflow.www.fab_security.sqla.models import Model
696-
from airflow.www.session import AirflowDatabaseSessionInterface
697707

698708
def _create_flask_session_tbl(sql_database_uri):
699-
flask_app = Flask(__name__)
700-
flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
701-
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
702-
db = SQLAlchemy(flask_app)
703-
AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
709+
db = _get_flask_db(sql_database_uri)
704710
db.create_all()
705711

706712
with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
@@ -1004,15 +1010,16 @@ def reflect_tables(tables: list[Base | str] | None, session):
10041010
"""
10051011
import sqlalchemy.schema
10061012

1007-
metadata = sqlalchemy.schema.MetaData(session.bind)
1013+
bind = session.bind
1014+
metadata = sqlalchemy.schema.MetaData()
10081015

10091016
if tables is None:
1010-
metadata.reflect(resolve_fks=False)
1017+
metadata.reflect(bind=bind, resolve_fks=False)
10111018
else:
10121019
for tbl in tables:
10131020
try:
10141021
table_name = tbl if isinstance(tbl, str) else tbl.__tablename__
1015-
metadata.reflect(only=[table_name], extend_existing=True, resolve_fks=False)
1022+
metadata.reflect(bind=bind, only=[table_name], extend_existing=True, resolve_fks=False)
10161023
except exc.InvalidRequestError:
10171024
continue
10181025
return metadata
@@ -1633,8 +1640,9 @@ def resetdb(session: Session = NEW_SESSION, skip_init: bool = False):
16331640
connection = settings.engine.connect()
16341641

16351642
with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
1636-
drop_airflow_models(connection)
1637-
drop_airflow_moved_tables(session)
1643+
with connection.begin():
1644+
drop_airflow_models(connection)
1645+
drop_airflow_moved_tables(connection)
16381646

16391647
if not skip_init:
16401648
initdb(session=session)
@@ -1701,27 +1709,12 @@ def drop_airflow_models(connection):
17011709
:return: None
17021710
"""
17031711
from airflow.models.base import Base
1704-
1705-
# Drop connection and chart - those tables have been deleted and in case you
1706-
# run resetdb on schema with chart or users table will fail
1707-
chart = Table("chart", Base.metadata)
1708-
chart.drop(settings.engine, checkfirst=True)
1709-
user = Table("user", Base.metadata)
1710-
user.drop(settings.engine, checkfirst=True)
1711-
users = Table("users", Base.metadata)
1712-
users.drop(settings.engine, checkfirst=True)
1713-
dag_stats = Table("dag_stats", Base.metadata)
1714-
dag_stats.drop(settings.engine, checkfirst=True)
1715-
session = Table("session", Base.metadata)
1716-
session.drop(settings.engine, checkfirst=True)
1712+
from airflow.www.fab_security.sqla.models import Model
17171713

17181714
Base.metadata.drop_all(connection)
1719-
# we remove the Tables here so that if resetdb is run metadata does not keep the old tables.
1720-
Base.metadata.remove(session)
1721-
Base.metadata.remove(dag_stats)
1722-
Base.metadata.remove(users)
1723-
Base.metadata.remove(user)
1724-
Base.metadata.remove(chart)
1715+
Model.metadata.drop_all(connection)
1716+
db = _get_flask_db(connection.engine.url)
1717+
db.drop_all()
17251718
# alembic adds significant import time, so we import it lazily
17261719
from alembic.migration import MigrationContext
17271720

@@ -1731,11 +1724,11 @@ def drop_airflow_models(connection):
17311724
version.drop(connection)
17321725

17331726

1734-
def drop_airflow_moved_tables(session):
1727+
def drop_airflow_moved_tables(connection):
17351728
from airflow.models.base import Base
17361729
from airflow.settings import AIRFLOW_MOVED_TABLE_PREFIX
17371730

1738-
tables = set(inspect(session.get_bind()).get_table_names())
1731+
tables = set(inspect(connection).get_table_names())
17391732
to_delete = [Table(x, Base.metadata) for x in tables if x.startswith(AIRFLOW_MOVED_TABLE_PREFIX)]
17401733
for tbl in to_delete:
17411734
tbl.drop(settings.engine, checkfirst=False)
@@ -1749,7 +1742,7 @@ def check(session: Session = NEW_SESSION):
17491742
17501743
:param session: session of the sqlalchemy
17511744
"""
1752-
session.execute("select 1 as is_alive;")
1745+
session.execute(text("select 1 as is_alive;"))
17531746
log.info("Connection successful.")
17541747

17551748

@@ -1780,23 +1773,23 @@ def create_global_lock(
17801773
dialect = conn.dialect
17811774
try:
17821775
if dialect.name == "postgresql":
1783-
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), timeout=lock_timeout)
1784-
conn.execute(text("SELECT pg_advisory_lock(:id)"), id=lock.value)
1776+
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
1777+
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
17851778
elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
1786-
conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), id=str(lock), timeout=lock_timeout)
1779+
conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout})
17871780
elif dialect.name == "mssql":
17881781
# TODO: make locking work for MSSQL
17891782
pass
17901783

17911784
yield
17921785
finally:
17931786
if dialect.name == "postgresql":
1794-
conn.execute("SET LOCK_TIMEOUT TO DEFAULT")
1795-
(unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), id=lock.value).fetchone()
1787+
conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
1788+
(unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
17961789
if not unlocked:
17971790
raise RuntimeError("Error releasing DB lock!")
17981791
elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
1799-
conn.execute(text("select RELEASE_LOCK(:id)"), id=str(lock))
1792+
conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)})
18001793
elif dialect.name == "mssql":
18011794
# TODO: make locking work for MSSQL
18021795
pass

tests/test_utils/db.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def drop_tables_with_prefix(prefix):
8282
metadata = reflect_tables(None, session)
8383
for table_name, table in metadata.tables.items():
8484
if table_name.startswith(prefix):
85-
table.drop()
85+
table.drop(session.bind)
8686

8787

8888
def clear_db_serialized_dags():

tests/utils/test_db.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def test_resetdb(
230230
session_mock = MagicMock()
231231
resetdb(session_mock, skip_init=skip_init)
232232
mock_drop_airflow.assert_called_once_with(mock_connect.return_value)
233-
mock_drop_moved.assert_called_once_with(session_mock)
233+
mock_drop_moved.assert_called_once_with(mock_connect.return_value)
234234
if skip_init:
235235
mock_init.assert_not_called()
236236
else:

0 commit comments

Comments
 (0)