49
49
50
50
from airflow .models .base import Base
51
51
52
-
53
52
log = logging .getLogger (__name__ )
54
53
55
54
REVISION_HEADS_MAP = {
@@ -686,21 +685,28 @@ def create_default_connections(session: Session = NEW_SESSION):
686
685
)
687
686
688
687
689
- def _create_db_from_orm (session ):
690
- from alembic import command
688
def _get_flask_db(sql_database_uri):
    """Return a Flask-SQLAlchemy ``db`` handle bound to *sql_database_uri*.

    A throwaway Flask app exists only so Flask-SQLAlchemy can be configured;
    registering ``AirflowDatabaseSessionInterface`` pulls the ``session``
    table into the metadata so ``db.create_all()`` / ``db.drop_all()``
    manage it alongside the rest.
    """
    from flask import Flask
    from flask_sqlalchemy import SQLAlchemy

    from airflow.www.session import AirflowDatabaseSessionInterface

    app = Flask(__name__)
    app.config.update(
        SQLALCHEMY_DATABASE_URI=sql_database_uri,
        SQLALCHEMY_TRACK_MODIFICATIONS=False,
    )
    db = SQLAlchemy(app)
    AirflowDatabaseSessionInterface(app=app, db=db, table="session", key_prefix="")
    return db
701
+
702
+ def _create_db_from_orm (session ):
703
+ from alembic import command
704
+
694
705
from airflow .models .base import Base
695
706
from airflow .www .fab_security .sqla .models import Model
696
- from airflow .www .session import AirflowDatabaseSessionInterface
697
707
698
708
def _create_flask_session_tbl (sql_database_uri ):
699
- flask_app = Flask (__name__ )
700
- flask_app .config ["SQLALCHEMY_DATABASE_URI" ] = sql_database_uri
701
- flask_app .config ["SQLALCHEMY_TRACK_MODIFICATIONS" ] = False
702
- db = SQLAlchemy (flask_app )
703
- AirflowDatabaseSessionInterface (app = flask_app , db = db , table = "session" , key_prefix = "" )
709
+ db = _get_flask_db (sql_database_uri )
704
710
db .create_all ()
705
711
706
712
with create_global_lock (session = session , lock = DBLocks .MIGRATIONS ):
@@ -1004,15 +1010,16 @@ def reflect_tables(tables: list[Base | str] | None, session):
1004
1010
"""
1005
1011
import sqlalchemy .schema
1006
1012
1007
- metadata = sqlalchemy .schema .MetaData (session .bind )
1013
+ bind = session .bind
1014
+ metadata = sqlalchemy .schema .MetaData ()
1008
1015
1009
1016
if tables is None :
1010
- metadata .reflect (resolve_fks = False )
1017
+ metadata .reflect (bind = bind , resolve_fks = False )
1011
1018
else :
1012
1019
for tbl in tables :
1013
1020
try :
1014
1021
table_name = tbl if isinstance (tbl , str ) else tbl .__tablename__
1015
- metadata .reflect (only = [table_name ], extend_existing = True , resolve_fks = False )
1022
+ metadata .reflect (bind = bind , only = [table_name ], extend_existing = True , resolve_fks = False )
1016
1023
except exc .InvalidRequestError :
1017
1024
continue
1018
1025
return metadata
@@ -1633,8 +1640,9 @@ def resetdb(session: Session = NEW_SESSION, skip_init: bool = False):
1633
1640
connection = settings .engine .connect ()
1634
1641
1635
1642
with create_global_lock (session = session , lock = DBLocks .MIGRATIONS ):
1636
- drop_airflow_models (connection )
1637
- drop_airflow_moved_tables (session )
1643
+ with connection .begin ():
1644
+ drop_airflow_models (connection )
1645
+ drop_airflow_moved_tables (connection )
1638
1646
1639
1647
if not skip_init :
1640
1648
initdb (session = session )
@@ -1701,27 +1709,12 @@ def drop_airflow_models(connection):
1701
1709
:return: None
1702
1710
"""
1703
1711
from airflow .models .base import Base
1704
-
1705
- # Drop connection and chart - those tables have been deleted and in case you
1706
- # run resetdb on schema with chart or users table will fail
1707
- chart = Table ("chart" , Base .metadata )
1708
- chart .drop (settings .engine , checkfirst = True )
1709
- user = Table ("user" , Base .metadata )
1710
- user .drop (settings .engine , checkfirst = True )
1711
- users = Table ("users" , Base .metadata )
1712
- users .drop (settings .engine , checkfirst = True )
1713
- dag_stats = Table ("dag_stats" , Base .metadata )
1714
- dag_stats .drop (settings .engine , checkfirst = True )
1715
- session = Table ("session" , Base .metadata )
1716
- session .drop (settings .engine , checkfirst = True )
1712
+ from airflow .www .fab_security .sqla .models import Model
1717
1713
1718
1714
Base .metadata .drop_all (connection )
1719
- # we remove the Tables here so that if resetdb is run metadata does not keep the old tables.
1720
- Base .metadata .remove (session )
1721
- Base .metadata .remove (dag_stats )
1722
- Base .metadata .remove (users )
1723
- Base .metadata .remove (user )
1724
- Base .metadata .remove (chart )
1715
+ Model .metadata .drop_all (connection )
1716
+ db = _get_flask_db (connection .engine .url )
1717
+ db .drop_all ()
1725
1718
# alembic adds significant import time, so we import it lazily
1726
1719
from alembic .migration import MigrationContext
1727
1720
@@ -1731,11 +1724,11 @@ def drop_airflow_models(connection):
1731
1724
version .drop (connection )
1732
1725
1733
1726
1734
def drop_airflow_moved_tables(connection):
    """Drop all tables created by schema migrations under the moved-table prefix.

    :param connection: an open SQLAlchemy connection; ``resetdb`` calls this
        inside ``with connection.begin():``, so every DDL statement must run
        on this same connection to stay within that transaction/lock.
    """
    from airflow.models.base import Base
    from airflow.settings import AIRFLOW_MOVED_TABLE_PREFIX

    tables = set(inspect(connection).get_table_names())
    to_delete = [Table(x, Base.metadata) for x in tables if x.startswith(AIRFLOW_MOVED_TABLE_PREFIX)]
    for tbl in to_delete:
        # Drop via the caller's connection, not settings.engine: a separate
        # engine connection would run outside the caller's transaction and can
        # deadlock against the MIGRATIONS global lock held by resetdb.
        tbl.drop(connection, checkfirst=False)
        # Forget the Table again so repeated resetdb runs don't accumulate
        # stale moved-table definitions in the shared Base.metadata.
        Base.metadata.remove(tbl)
@@ -1749,7 +1742,7 @@ def check(session: Session = NEW_SESSION):
1749
1742
1750
1743
:param session: session of the sqlalchemy
1751
1744
"""
1752
- session .execute ("select 1 as is_alive;" )
1745
+ session .execute (text ( "select 1 as is_alive;" ) )
1753
1746
log .info ("Connection successful." )
1754
1747
1755
1748
@@ -1780,23 +1773,23 @@ def create_global_lock(
1780
1773
dialect = conn .dialect
1781
1774
try :
1782
1775
if dialect .name == "postgresql" :
1783
- conn .execute (text ("SET LOCK_TIMEOUT to :timeout" ), timeout = lock_timeout )
1784
- conn .execute (text ("SELECT pg_advisory_lock(:id)" ), id = lock .value )
1776
+ conn .execute (text ("SET LOCK_TIMEOUT to :timeout" ), { " timeout" : lock_timeout } )
1777
+ conn .execute (text ("SELECT pg_advisory_lock(:id)" ), { "id" : lock .value } )
1785
1778
elif dialect .name == "mysql" and dialect .server_version_info >= (5 , 6 ):
1786
- conn .execute (text ("SELECT GET_LOCK(:id, :timeout)" ), id = str (lock ), timeout = lock_timeout )
1779
+ conn .execute (text ("SELECT GET_LOCK(:id, :timeout)" ), { "id" : str (lock ), " timeout" : lock_timeout } )
1787
1780
elif dialect .name == "mssql" :
1788
1781
# TODO: make locking work for MSSQL
1789
1782
pass
1790
1783
1791
1784
yield
1792
1785
finally :
1793
1786
if dialect .name == "postgresql" :
1794
- conn .execute ("SET LOCK_TIMEOUT TO DEFAULT" )
1795
- (unlocked ,) = conn .execute (text ("SELECT pg_advisory_unlock(:id)" ), id = lock .value ).fetchone ()
1787
+ conn .execute (text ( "SET LOCK_TIMEOUT TO DEFAULT" ) )
1788
+ (unlocked ,) = conn .execute (text ("SELECT pg_advisory_unlock(:id)" ), { "id" : lock .value } ).fetchone ()
1796
1789
if not unlocked :
1797
1790
raise RuntimeError ("Error releasing DB lock!" )
1798
1791
elif dialect .name == "mysql" and dialect .server_version_info >= (5 , 6 ):
1799
- conn .execute (text ("select RELEASE_LOCK(:id)" ), id = str (lock ))
1792
+ conn .execute (text ("select RELEASE_LOCK(:id)" ), { "id" : str (lock )} )
1800
1793
elif dialect .name == "mssql" :
1801
1794
# TODO: make locking work for MSSQL
1802
1795
pass
0 commit comments