Skip to content

Commit

Permalink
Bump minimum version of databricks-sql-connector to 3.0.0 (apache#43272)
Browse files Browse the repository at this point in the history
* update databricks-sql-connector to 3.0.0 for databricks provider

* rename Row definition for no-redef error

* fix: patches mypy errors for typing

* fix: fixes static checks and moves Row import

* fix: patches Row typing errors when databricks sql imported

* remove databricks import from snowflake test module

* introduce MockRow class for object comparison checks

* Update providers/src/airflow/providers/databricks/hooks/databricks_sql.py

---------

Co-authored-by: olharuban <[email protected]>
Co-authored-by: Elad Kalif <[email protected]>
  • Loading branch information
3 people authored and Lefteris Gilmaz committed Jan 5, 2025
1 parent c66e893 commit 01e6e1d
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 53 deletions.
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@
"aiohttp>=3.9.2, <4",
"apache-airflow-providers-common-sql>=1.20.0",
"apache-airflow>=2.8.0",
"databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
"databricks-sql-connector>=3.0.0",
"mergedeep>=1.3.4",
"pandas>=1.5.3,<2.2;python_version<\"3.9\"",
"pandas>=2.1.2,<2.2;python_version>=\"3.9\"",
Expand Down
43 changes: 22 additions & 21 deletions providers/src/airflow/providers/databricks/hooks/databricks_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,20 @@
)

from databricks import sql # type: ignore[attr-defined]
from databricks.sql.types import Row

from airflow.exceptions import (
AirflowException,
AirflowProviderDeprecationWarning,
)
from airflow.models.connection import Connection as AirflowConnection
from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.providers.databricks.exceptions import DatabricksSqlExecutionError, DatabricksSqlExecutionTimeout
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

if TYPE_CHECKING:
from databricks.sql.client import Connection
from databricks.sql.types import Row


LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")

Expand Down Expand Up @@ -103,7 +105,7 @@ def __init__(
**kwargs,
) -> None:
super().__init__(databricks_conn_id, caller=caller)
self._sql_conn = None
self._sql_conn: Connection | None = None
self._token: str | None = None
self._http_path = http_path
self._sql_endpoint_name = sql_endpoint_name
Expand Down Expand Up @@ -143,7 +145,7 @@ def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
else:
return endpoint

def get_conn(self) -> Connection:
def get_conn(self) -> AirflowConnection:
"""Return a Databricks SQL connection object."""
if not self._http_path:
if self._sql_endpoint_name:
Expand All @@ -158,20 +160,15 @@ def get_conn(self) -> Connection:
"or sql_endpoint_name should be specified"
)

requires_init = True
if not self._token:
self._token = self._get_token(raise_error=True)
else:
new_token = self._get_token(raise_error=True)
if new_token != self._token:
self._token = new_token
else:
requires_init = False
prev_token = self._token
new_token = self._get_token(raise_error=True)
if not self._token or new_token != self._token:
self._token = new_token

if not self.session_config:
self.session_config = self.databricks_conn.extra_dejson.get("session_configuration")

if not self._sql_conn or requires_init:
if not self._sql_conn or prev_token != new_token:
if self._sql_conn: # close already existing connection
self._sql_conn.close()
self._sql_conn = sql.connect(
Expand All @@ -186,7 +183,10 @@ def get_conn(self) -> Connection:
**self._get_extra_config(),
**self.additional_params,
)
return self._sql_conn

if self._sql_conn is None:
raise AirflowException("SQL connection is not initialized")
return cast(AirflowConnection, self._sql_conn)

@overload # type: ignore[override]
def run(
Expand Down Expand Up @@ -307,22 +307,23 @@ def run(
else:
return results

def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple[Any, ...] | list[tuple[Any, ...]]:
"""Transform the databricks Row objects into namedtuple."""
# The ignored lines below follow the namedtuple docstring, but mypy does not support dynamically
# instantiated namedtuples, and never will: https://github.com/python/mypy/issues/848
if isinstance(result, list):
rows: list[Row] = result
rows: Sequence[Row] = result
if not rows:
return []
rows_fields = tuple(rows[0].__fields__)
rows_object = namedtuple("Row", rows_fields, rename=True) # type: ignore
return cast(list[tuple], [rows_object(*row) for row in rows])
else:
row: Row = result
row_fields = tuple(row.__fields__)
return cast(list[tuple[Any, ...]], [rows_object(*row) for row in rows])
elif isinstance(result, Row):
row_fields = tuple(result.__fields__)
row_object = namedtuple("Row", row_fields, rename=True) # type: ignore
return cast(tuple, row_object(*row))
return cast(tuple[Any, ...], row_object(*result))
else:
raise TypeError(f"Expected Sequence[Row] or Row, but got {type(result)}")

def bulk_dump(self, table, tmp_file):
raise NotImplementedError()
Expand Down
5 changes: 1 addition & 4 deletions providers/src/airflow/providers/databricks/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-sql>=1.20.0
- requests>=2.27.0,<3
# The connector 2.9.0 released on Aug 10, 2023 has a bug that it does not properly declare urllib3 and
# it needs to be excluded. See https://github.com/databricks/databricks-sql-python/issues/190
# The 2.9.1 (to be released soon) already contains the fix
- databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0
- databricks-sql-connector>=3.0.0
- aiohttp>=3.9.2, <4
- mergedeep>=1.3.4
- pandas>=2.1.2,<2.2;python_version>="3.9"
Expand Down
54 changes: 27 additions & 27 deletions providers/tests/snowflake/operators/test_snowflake_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@
from unittest.mock import MagicMock, patch

import pytest
from _pytest.outcomes import importorskip

from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

databricks = importorskip("databricks")

try:
from databricks.sql.types import Row
except ImportError:
# Row is used in the parametrize so it's parsed during collection and we need to have a viable
# replacement for the collection time when databricks is not installed (Python 3.12 for now)
def Row(*args, **kwargs):
return MagicMock()
class MockRow:
    """Minimal stand-in for a result row, used as parametrized test data.

    Every keyword argument becomes an instance attribute. Two instances compare
    equal when they hold exactly the same attribute names and values. Because
    ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    def __init__(self, **kwargs):
        # Store the fields directly as attributes so they are reachable as row.<name>.
        self.__dict__.update(kwargs)

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of answering False,
        # so its reflected __eq__ (e.g. a wildcard matcher) still gets a chance to run.
        if not isinstance(other, MockRow):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return f"MockRow({self.__dict__})"


from airflow.models.connection import Connection
Expand All @@ -59,59 +59,59 @@ def Row(*args, **kwargs):
"select * from dummy",
True,
True,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[[("id",), ("value",)]],
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
id="Scalar: Single SQL statement, return_last, split statement",
),
pytest.param(
"select * from dummy;select * from dummy2",
True,
True,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[[("id",), ("value",)]],
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
id="Scalar: Multiple SQL statements, return_last, split statement",
),
pytest.param(
"select * from dummy",
False,
False,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[[("id",), ("value",)]],
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
id="Scalar: Single SQL statements, no return_last (doesn't matter), no split statement",
),
pytest.param(
"select * from dummy",
True,
False,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[[("id",), ("value",)]],
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
id="Scalar: Single SQL statements, return_last (doesn't matter), no split statement",
),
pytest.param(
["select * from dummy"],
False,
False,
[[Row(id=1, value="value1"), Row(id=2, value="value2")]],
[[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]],
[[("id",), ("value",)]],
[([Row(id=1, value="value1"), Row(id=2, value="value2")])],
[([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")])],
id="Non-Scalar: Single SQL statements in list, no return_last, no split statement",
),
pytest.param(
["select * from dummy", "select * from dummy2"],
False,
False,
[
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[MockRow(id2=1, value2="value1"), MockRow(id2=2, value2="value2")],
],
[[("id",), ("value",)], [("id2",), ("value2",)]],
[
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
([MockRow(id2=1, value2="value1"), MockRow(id2=2, value2="value2")]),
],
id="Non-Scalar: Multiple SQL statements in list, no return_last (no matter), no split statement",
),
Expand All @@ -120,13 +120,13 @@ def Row(*args, **kwargs):
True,
False,
[
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
[MockRow(id=1, value="value1"), MockRow(id=2, value="value2")],
[MockRow(id2=1, value2="value1"), MockRow(id2=2, value2="value2")],
],
[[("id",), ("value",)], [("id2",), ("value2",)]],
[
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]),
([MockRow(id=1, value="value1"), MockRow(id=2, value="value2")]),
([MockRow(id2=1, value2="value1"), MockRow(id2=2, value2="value2")]),
],
id="Non-Scalar: Multiple SQL statements in list, return_last (no matter), no split statement",
),
Expand Down

0 comments on commit 01e6e1d

Please sign in to comment.