Skip to content

Commit

Permalink
Refactore SqlAlchemy session.execute() calls to 2.0 style in case of …
Browse files Browse the repository at this point in the history
…plain text SQL queries
  • Loading branch information
moiseenkov committed Jul 26, 2023
1 parent a85d546 commit 8591ba4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
5 changes: 2 additions & 3 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,9 +1221,8 @@ def _create_table_as(
)
else:
# Postgres and SQLite both support the same "CREATE TABLE a AS SELECT ..." syntax
session.execute(
f"CREATE TABLE {target_table_name} AS {source_query.selectable.compile(bind=session.get_bind())}"
)
select_table = source_query.selectable.compile(bind=session.get_bind())
session.execute(text(f"CREATE TABLE {target_table_name} AS {select_table}"))


def _move_dangling_data_to_new_table(
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_db_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pendulum
import pytest
from pytest import param
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.declarative import DeclarativeMeta

Expand Down Expand Up @@ -211,7 +212,7 @@ def test__build_query(self, table_name, date_add_kwargs, expected_to_delete, ext
)
stmt = CreateTableAs(target_table_name, query.selectable)
session.execute(stmt)
res = session.execute(f"SELECT COUNT(1) FROM {target_table_name}")
res = session.execute(text(f"SELECT COUNT(1) FROM {target_table_name}"))
for row in res:
assert row[0] == expected_to_delete

Expand Down
13 changes: 7 additions & 6 deletions tests/utils/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pytest
from kubernetes.client import models as k8s
from pytest import param
from sqlalchemy import text
from sqlalchemy.exc import StatementError

from airflow import settings
Expand Down Expand Up @@ -54,7 +55,7 @@ def setup_method(self):
# make sure NOT to run in UTC. Only postgres supports storing
# timezone information in the datetime field
if session.bind.dialect.name == "postgresql":
session.execute("SET timezone='Europe/Amsterdam'")
session.execute(text("SET timezone='Europe/Amsterdam'"))

self.session = session

Expand Down Expand Up @@ -208,17 +209,17 @@ def test_with_row_locks(

def test_prohibit_commit(self):
with prohibit_commit(self.session) as guard:
self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
with pytest.raises(RuntimeError):
self.session.commit()
self.session.rollback()

self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
guard.commit()

# Check the expected_commit is reset
with pytest.raises(RuntimeError):
self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
self.session.commit()

def test_prohibit_commit_specific_session_only(self):
Expand All @@ -233,12 +234,12 @@ def test_prohibit_commit_specific_session_only(self):
assert other_session is not self.session

with prohibit_commit(self.session):
self.session.execute("SELECT 1")
self.session.execute(text("SELECT 1"))
with pytest.raises(RuntimeError):
self.session.commit()
self.session.rollback()

other_session.execute("SELECT 1")
other_session.execute(text("SELECT 1"))
other_session.commit()

def teardown_method(self):
Expand Down

0 comments on commit 8591ba4

Please sign in to comment.