
Upgrade databricks provider dependency #43272

2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -436,7 +436,7 @@
"aiohttp>=3.9.2, <4",
"apache-airflow-providers-common-sql>=1.20.0",
"apache-airflow>=2.8.0",
"databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
"databricks-sql-connector>=3.0.0",
"mergedeep>=1.3.4",
"pandas>=1.5.3,<2.2;python_version<\"3.9\"",
"pandas>=2.1.2,<2.2;python_version>=\"3.9\"",
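(generated/provider_dependencies.json is regenerated from the provider.yaml files, so this hunk simply mirrors the provider.yaml change further down.)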
43 changes: 22 additions & 21 deletions providers/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -33,18 +33,20 @@
)

from databricks import sql # type: ignore[attr-defined]
+from databricks.sql.types import Row

from airflow.exceptions import (
AirflowException,
AirflowProviderDeprecationWarning,
)
+from airflow.models.connection import Connection as AirflowConnection
from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.providers.databricks.exceptions import DatabricksSqlExecutionError, DatabricksSqlExecutionTimeout
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

if TYPE_CHECKING:
from databricks.sql.client import Connection
from databricks.sql.types import Row


LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")

@@ -103,7 +105,7 @@ def __init__(
**kwargs,
) -> None:
super().__init__(databricks_conn_id, caller=caller)
-        self._sql_conn = None
+        self._sql_conn: Connection | None = None
self._token: str | None = None
self._http_path = http_path
self._sql_endpoint_name = sql_endpoint_name
@@ -143,7 +145,7 @@ def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
else:
return endpoint

-    def get_conn(self) -> Connection:
+    def get_conn(self) -> AirflowConnection:
"""Return a Databricks SQL connection object."""
if not self._http_path:
if self._sql_endpoint_name:
@@ -158,20 +160,15 @@ def get_conn(self) -> Connection:
"or sql_endpoint_name should be specified"
)

-        requires_init = True
-        if not self._token:
-            self._token = self._get_token(raise_error=True)
-        else:
-            new_token = self._get_token(raise_error=True)
-            if new_token != self._token:
-                self._token = new_token
-            else:
-                requires_init = False
+        prev_token = self._token
+        new_token = self._get_token(raise_error=True)
+        if not self._token or new_token != self._token:
+            self._token = new_token

        if not self.session_config:
            self.session_config = self.databricks_conn.extra_dejson.get("session_configuration")

-        if not self._sql_conn or requires_init:
+        if not self._sql_conn or prev_token != new_token:
            if self._sql_conn:  # close already existing connection
                self._sql_conn.close()
            self._sql_conn = sql.connect(
@@ -186,7 +183,10 @@ def get_conn(self) -> Connection:
                **self._get_extra_config(),
                **self.additional_params,
            )
-        return self._sql_conn
+
+        if self._sql_conn is None:
+            raise AirflowException("SQL connection is not initialized")
+        return cast(AirflowConnection, self._sql_conn)
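The refactored logic re-creates the SQL connection only when the token rotates or no connection exists yet. A minimal, runnable sketch of the same caching pattern, where fetch_token and FakeConn are hypothetical stand-ins for the hook's _get_token() and the databricks connection:

    from __future__ import annotations

    import itertools
    from dataclasses import dataclass

    _calls = itertools.count()

    def fetch_token() -> str:
        # Hypothetical token source; rotates the token every third call.
        return f"token-{next(_calls) // 3}"

    @dataclass
    class FakeConn:
        token: str

        def close(self) -> None:
            print(f"closing connection opened with {self.token}")

    class TokenCachingHook:
        def __init__(self) -> None:
            self._conn: FakeConn | None = None
            self._token: str | None = None

        def get_conn(self) -> FakeConn:
            prev_token = self._token
            new_token = fetch_token()  # always re-fetch, like _get_token(raise_error=True)
            if not self._token or new_token != self._token:
                self._token = new_token
            # Rebuild the connection only if none exists or the token rotated.
            if self._conn is None or prev_token != new_token:
                if self._conn is not None:  # close the stale connection first
                    self._conn.close()
                self._conn = FakeConn(token=self._token)
            return self._conn

    hook = TokenCachingHook()
    assert hook.get_conn() is hook.get_conn()  # token unchanged -> connection reused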

@overload # type: ignore[override]
def run(
@@ -307,22 +307,23 @@ def run(
else:
return results

-    def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
+    def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple[Any, ...] | list[tuple[Any, ...]]:
        """Transform the databricks Row objects into namedtuples."""
        # The ignores below respect the namedtuple docstring, but mypy does not support
        # dynamically instantiated namedtuples and never will: https://github.com/python/mypy/issues/848
        if isinstance(result, list):
-            rows: list[Row] = result
+            rows: Sequence[Row] = result
            if not rows:
                return []
            rows_fields = tuple(rows[0].__fields__)
            rows_object = namedtuple("Row", rows_fields, rename=True)  # type: ignore
-            return cast(list[tuple], [rows_object(*row) for row in rows])
-        else:
-            row: Row = result
-            row_fields = tuple(row.__fields__)
+            return cast(list[tuple[Any, ...]], [rows_object(*row) for row in rows])
+        elif isinstance(result, Row):
+            row_fields = tuple(result.__fields__)
            row_object = namedtuple("Row", row_fields, rename=True)  # type: ignore
-            return cast(tuple, row_object(*row))
+            return cast(tuple[Any, ...], row_object(*result))
+        else:
+            raise TypeError(f"Expected Sequence[Row] or Row, but got {type(result)}")

def bulk_dump(self, table, tmp_file):
raise NotImplementedError()
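The elif/else restructuring also makes the conversion easy to exercise on its own. A self-contained sketch, where FakeRow mimics databricks.sql.types.Row just enough for this purpose (iterable values plus the __fields__ attribute the hook reads):

    from collections import namedtuple
    from typing import Any, Iterator

    class FakeRow:
        # Stand-in for databricks.sql.types.Row: iterable, exposing __fields__.
        def __init__(self, **kwargs: Any) -> None:
            self.__fields__ = list(kwargs)
            self._values = tuple(kwargs.values())

        def __iter__(self) -> Iterator[Any]:
            return iter(self._values)

    def make_common_data_structure(result):
        # Mirrors the hook method above, with FakeRow in place of Row.
        if isinstance(result, list):
            if not result:
                return []
            rows_object = namedtuple("Row", tuple(result[0].__fields__), rename=True)
            return [rows_object(*row) for row in result]
        elif isinstance(result, FakeRow):
            row_object = namedtuple("Row", tuple(result.__fields__), rename=True)
            return row_object(*result)
        raise TypeError(f"Expected list[FakeRow] or FakeRow, but got {type(result)}")

    print(make_common_data_structure([FakeRow(id=1, value="a")]))  # [Row(id=1, value='a')]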
5 changes: 1 addition & 4 deletions providers/src/airflow/providers/databricks/provider.yaml
@@ -75,10 +75,7 @@ dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-sql>=1.20.0
- requests>=2.27.0,<3
-  # The connector 2.9.0, released on Aug 10, 2023, has a bug: it does not properly declare its
-  # urllib3 dependency, so it needs to be excluded. See https://github.com/databricks/databricks-sql-python/issues/190
-  # The 2.9.1 release (expected soon) already contains the fix.
-  - databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0
+  - databricks-sql-connector>=3.0.0
- aiohttp>=3.9.2, <4
- mergedeep>=1.3.4
- pandas>=2.1.2,<2.2;python_version>="3.9"
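The deleted comment explained why 2.9.0 was excluded; the effect of the old and new pins can be checked with the packaging library (illustrative only, using the same specifier machinery pip relies on):

    from packaging.specifiers import SpecifierSet

    old_pin = SpecifierSet(">=2.0.0,<3.0.0,!=2.9.0")
    new_pin = SpecifierSet(">=3.0.0")

    print(old_pin.contains("2.9.0"))  # False: the release with the urllib3 bug was excluded
    print(old_pin.contains("3.0.0"))  # False: 3.x was blocked before this change
    print(new_pin.contains("3.0.0"))  # True: 3.x is now allowed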
54 changes: 27 additions & 27 deletions providers/tests/snowflake/operators/test_snowflake_sql.py
@@ -21,19 +21,19 @@
from unittest.mock import MagicMock, patch

import pytest
-from _pytest.outcomes import importorskip

from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

databricks = importorskip("databricks")

-try:
-    from databricks.sql.types import Row
-except ImportError:
-    # Row is used in the parametrize so it's parsed during collection and we need to have a viable
-    # replacement for the collection time when databricks is not installed (Python 3.12 for now)
-    def Row(*args, **kwargs):
-        return MagicMock()
+class MockRow:
+    def __init__(self, **kwargs):
+        self.__dict__.update(kwargs)
+
+    def __eq__(self, other):
+        return isinstance(other, MockRow) and self.__dict__ == other.__dict__
+
+    def __repr__(self):
+        return f"MockRow({self.__dict__})"


from airflow.models.connection import Connection
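Because MockRow compares by __dict__, two instances built with the same keyword arguments are equal, which is all the parametrized expectations below rely on:

    assert MockRow(id=1, value="value1") == MockRow(id=1, value="value1")
    assert MockRow(id=1, value="value1") != MockRow(id=2, value="value2")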
@@ -59,59 +59,59 @@ def Row(*args, **kwargs):
"select * from dummy",
True,
True,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[[("id",), ("value",)]],
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
id="Scalar: Single SQL statement, return_last, split statement",
),
pytest.param(
"select * from dummy;select * from dummy2",
True,
True,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[[("id",), ("value",)]],
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
id="Scalar: Multiple SQL statements, return_last, split statement",
),
pytest.param(
"select * from dummy",
False,
False,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[[("id",), ("value",)]],
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
id="Scalar: Single SQL statements, no return_last (doesn't matter), no split statement",
),
pytest.param(
"select * from dummy",
True,
False,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[[("id",), ("value",)]],
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
id="Scalar: Single SQL statements, return_last (doesn't matter), no split statement",
),
pytest.param(
["select * from dummy"],
False,
False,
[[Row(id=1, value="value1"), Row(id=2, value="value2")]],
[[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]],
[[("id",), ("value",)]],
[([Row(id=1, value="value1"), Row(id=2, value="value2")])],
[([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")])],
id="Non-Scalar: Single SQL statements in list, no return_last, no split statement",
),
pytest.param(
["select * from dummy", "select * from dummy2"],
False,
False,
[
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[MockRow(id2=1, value2="value1"), MockRow(id2=2, value2="value2")],
],
[[("id",), ("value",)], [("id2",), ("value2",)]],
[
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
([MockRow(id2=1, value2="value1"), MockRow(id2=2, value2="value2")]),
],
id="Non-Scalar: Multiple SQL statements in list, no return_last (no matter), no split statement",
),
@@ -120,13 +120,13 @@ def Row(*args, **kwargs):
True,
False,
[
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[MockRow(id2=1, value2="value1"), MockRow(id2=2, value2="value2")],
],
[[("id",), ("value",)], [("id2",), ("value2",)]],
[
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
([MockRow(id2=1, value2="value1"), MockRow(id2=2, value2="value2")]),
],
id="Non-Scalar: Multiple SQL statements in list, return_last (no matter), no split statement",
),