Commit

Fix/98 sparkexpectations bump version to 220 (#99)
dannymeijer authored Nov 11, 2024
1 parent 6ca9f30 commit 76207e7
Showing 6 changed files with 131 additions and 31 deletions.
6 changes: 6 additions & 0 deletions makefile
@@ -40,6 +40,12 @@ hatch-install:
fi
init: hatch-install

.PHONY: sync ## hatch - Update dependencies if you changed project dependencies in pyproject.toml
.PHONY: update ## hatch - alias for sync (if you are used to poetry, this is similar to running `poetry update`)
sync:
@hatch run dev:uv sync --all-extras
update: sync

# Code Quality
.PHONY: black black-fmt ## code quality - Use black to (re)format the codebase
black-fmt:
18 changes: 9 additions & 9 deletions pyproject.toml
@@ -60,16 +60,11 @@ box = ["boxsdk[jwt]==3.8.1"]
pandas = ["pandas>=1.3", "setuptools", "numpy<2.0.0", "pandas-stubs"]
pyspark = ["pyspark>=3.2.0", "pyarrow>13"]
pyspark_connect = ["pyspark[connect]>=3.5"]
se = ["spark-expectations>=2.1.0"]
# SFTP dependencies in to_csv line_iterator
sftp = ["paramiko>=2.6.0"]
delta = ["delta-spark>=2.2"]
excel = ["openpyxl>=3.0.0"]
# Tableau dependencies
tableau = ["tableauhyperapi>=0.0.19484", "tableauserverclient>=0.25"]
# Snowflake dependencies
snowflake = ["snowflake-connector-python>=3.12.0"]
# Development dependencies
dev = ["ruff", "mypy", "pylint", "colorama", "types-PyYAML", "types-requests"]
test = [
"chispa",
@@ -104,6 +99,10 @@ docs = [
"pymdown-extensions>=10.7.0",
"black",
]
se = ["spark-expectations>=2.2.1,<2.3.0"]

[tool.hatch.metadata]
allow-direct-references = true


### ~~~~~~~~~~~~~~~ ###
@@ -237,14 +236,15 @@ features = [
"async",
"async_http",
"box",
"delta",
"dev",
"excel",
"pandas",
"pyspark",
"se",
"sftp",
"delta",
"excel",
"snowflake",
"tableau",
"dev",
"test",
]

@@ -416,7 +416,7 @@ features = [
"box",
"pandas",
"pyspark",
# "se",
"se",
"sftp",
"snowflake",
"delta",
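As a quick sanity check outside the diff, one can confirm that the installed spark-expectations release falls inside the new `>=2.2.1,<2.3.0` pin. The snippet below is a minimal sketch using only the standard library and assumes the `se` extra has been installed in the current environment.

# Minimal sketch: verify the installed spark-expectations version against the
# new pin ">=2.2.1,<2.3.0" declared in pyproject.toml (assumes the "se" extra
# is installed).
from importlib.metadata import version

se_version = version("spark-expectations")
major, minor = (int(part) for part in se_version.split(".")[:2])
assert (major, minor) == (2, 2), f"unexpected spark-expectations version: {se_version}"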
5 changes: 0 additions & 5 deletions src/koheesio/integrations/spark/dq/spark_expectations.py
@@ -15,15 +15,10 @@

from pydantic import Field

import pyspark

from koheesio.spark import DataFrame
from koheesio.spark.transformations import Transformation
from koheesio.spark.writers import BatchOutputMode

if pyspark.__version__.startswith("3.5"):
raise ImportError("Spark Expectations is not supported for Spark 3.5")


class SparkExpectationsTransformation(Transformation):
"""
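With the guard gone and the dependency pinned to the 2.2.x line, importing the integration on Spark 3.5 should no longer raise. A small smoke-test sketch along these lines (assuming pyspark and the `se` extra are installed) illustrates the intent:

# Smoke-test sketch: with spark-expectations 2.2.x pinned, this import no
# longer raises ImportError on Spark 3.5 (it previously did via the removed
# module-level version guard).
import pyspark

from koheesio.integrations.spark.dq.spark_expectations import (
    SparkExpectationsTransformation,
)

print(pyspark.__version__)  # e.g. "3.5.x" -- previously this would have raised
print(SparkExpectationsTransformation.__name__)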
32 changes: 23 additions & 9 deletions src/koheesio/spark/utils/common.py
@@ -80,10 +80,10 @@ def get_spark_minor_version() -> float:
def check_if_pyspark_connect_is_supported() -> bool:
"""Check if the current version of PySpark supports the connect module"""
result = False
module_name: str = "pyspark"

if SPARK_MINOR_VERSION >= 3.5:
try:
importlib.import_module(f"{module_name}.sql.connect")
importlib.import_module("pyspark.sql.connect")
from pyspark.sql.connect.column import Column

_col: Column
@@ -119,9 +119,13 @@ def check_if_pyspark_connect_is_supported() -> bool:
ParseException = (CapturedParseException, ConnectParseException)
DataType = Union[SqlDataType, ConnectDataType]
DataFrameReader = Union[sql.readwriter.DataFrameReader, DataFrameReader]
DataStreamReader = Union[sql.streaming.readwriter.DataStreamReader, DataStreamReader]
DataStreamReader = Union[
sql.streaming.readwriter.DataStreamReader, DataStreamReader
]
DataFrameWriter = Union[sql.readwriter.DataFrameWriter, DataFrameWriter]
DataStreamWriter = Union[sql.streaming.readwriter.DataStreamWriter, DataStreamWriter]
DataStreamWriter = Union[
sql.streaming.readwriter.DataStreamWriter, DataStreamWriter
]
StreamingQuery = StreamingQuery
else:
"""Import the regular PySpark modules if the current version of PySpark does not support the connect module"""
@@ -156,8 +160,9 @@ def check_if_pyspark_connect_is_supported() -> bool:

def get_active_session() -> SparkSession: # type: ignore
"""Get the active Spark session"""
print("Entering get_active_session")
if check_if_pyspark_connect_is_supported():
from pyspark.sql.connect.session import SparkSession as _ConnectSparkSession
from pyspark.sql.connect import SparkSession as _ConnectSparkSession

session = _ConnectSparkSession.getActiveSession() or sql.SparkSession.getActiveSession() # type: ignore
else:
@@ -292,14 +297,18 @@ def spark_data_type_is_array(data_type: DataType) -> bool: # type: ignore

def spark_data_type_is_numeric(data_type: DataType) -> bool: # type: ignore
"""Check if the column's dataType is of type ArrayType"""
return isinstance(data_type, (IntegerType, LongType, FloatType, DoubleType, DecimalType))
return isinstance(
data_type, (IntegerType, LongType, FloatType, DoubleType, DecimalType)
)


def schema_struct_to_schema_str(schema: StructType) -> str:
"""Converts a StructType to a schema str"""
if not schema:
return ""
return ",\n".join([f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields])
return ",\n".join(
[f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields]
)


def import_pandas_based_on_pyspark_version() -> ModuleType:
@@ -314,7 +323,9 @@ def import_pandas_based_on_pyspark_version() -> ModuleType:
pyspark_version = get_spark_minor_version()
pandas_version = pd.__version__

if (pyspark_version < 3.4 and pandas_version >= "2") or (pyspark_version >= 3.4 and pandas_version < "2"):
if (pyspark_version < 3.4 and pandas_version >= "2") or (
pyspark_version >= 3.4 and pandas_version < "2"
):
raise ImportError(
f"For PySpark {pyspark_version}, "
f"please install Pandas version {'< 2' if pyspark_version < 3.4 else '>= 2'}"
@@ -379,7 +390,10 @@ def get_column_name(col: Column) -> str: # type: ignore
# In case of a 'regular' Column object, we can directly access the name attribute through the _jc attribute
# noinspection PyProtectedMember
name = col._jc.toString() # type: ignore[operator]
elif any(cls.__module__ == "pyspark.sql.connect.column" for cls in inspect.getmro(col.__class__)):
elif any(
cls.__module__ == "pyspark.sql.connect.column"
for cls in inspect.getmro(col.__class__)
):
# noinspection PyProtectedMember
name = col._expr.name()
else:
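For reference, the helper exercised by the new tests further down is meant to be backend-agnostic. A usage sketch (assuming an active session already exists, created through either classic PySpark or Spark Connect) looks like this:

# Usage sketch for get_active_session: it resolves the active session whether
# it is a classic SparkSession or a Spark Connect session, and raises
# RuntimeError when no session is active.
from koheesio.spark.utils.common import get_active_session

session = get_active_session()
df = session.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
print(df.count())  # -> 2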
3 changes: 0 additions & 3 deletions tests/spark/integrations/dq/test_spark_expectations.py
@@ -11,9 +11,6 @@

pytestmark = pytest.mark.spark

if pyspark.__version__.startswith("3.5"):
pytestmark = pytest.mark.skip("Spark Expectations is not supported for Spark 3.5")


class TestSparkExpectationsTransform:
"""
98 changes: 93 additions & 5 deletions tests/spark/test_spark_utils.py
@@ -1,5 +1,5 @@
from os import environ
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

@@ -12,10 +12,89 @@
schema_struct_to_schema_str,
show_string,
)
from koheesio.spark.utils.common import (
check_if_pyspark_connect_is_supported,
get_active_session,
get_spark_minor_version,
)


class TestGetActiveSession:
def test_unhappy_get_active_session_spark_connect(self):
"""Test that get_active_session raises an error when no active session is found when using spark connect."""
with (
# ensure that we are forcing the code to think that we are using spark connect
patch(
"koheesio.spark.utils.common.check_if_pyspark_connect_is_supported",
return_value=True,
),
# make sure that spark session is not found
patch("pyspark.sql.SparkSession.getActiveSession", return_value=None),
):
session = MagicMock(
SparkSession=MagicMock(getActiveSession=MagicMock(return_value=None))
)
with patch.dict("sys.modules", {"pyspark.sql.connect": session}):
with pytest.raises(
RuntimeError,
match="No active Spark session found. Please create a Spark session before using module "
"connect_utils. Or perform local import of the module.",
):
get_active_session()

def test_unhappy_get_active_session(self):
"""Test that get_active_session raises an error when no active session is found."""
with (
patch(
"koheesio.spark.utils.common.check_if_pyspark_connect_is_supported",
return_value=False,
),
patch("pyspark.sql.SparkSession.getActiveSession", return_value=None),
):
with pytest.raises(
RuntimeError,
match="No active Spark session found. Please create a Spark session before using module connect_utils. "
"Or perform local import of the module.",
):
get_active_session()

def test_get_active_session_with_spark(self, spark):
"""Test get_active_session when an active session is found"""
session = get_active_session()
assert session is not None


class TestCheckIfPysparkConnectIsSupported:
def test_if_pyspark_connect_is_not_supported(self):
"""Test that check_if_pyspark_connect_is_supported returns False when pyspark connect is not supported."""
with patch.dict("sys.modules", {"pyspark.sql.connect": None}):
assert check_if_pyspark_connect_is_supported() is False

def test_check_if_pyspark_connect_is_supported(self):
"""Test that check_if_pyspark_connect_is_supported returns True when pyspark connect is supported."""
with (
patch("koheesio.spark.utils.common.SPARK_MINOR_VERSION", 3.5),
patch.dict(
"sys.modules",
{
"pyspark.sql.connect.column": MagicMock(Column=MagicMock()),
"pyspark.sql.connect": MagicMock(),
},
),
):
assert check_if_pyspark_connect_is_supported() is True


def test_get_spark_minor_version():
"""Test that get_spark_minor_version returns the correctly formatted version."""
with patch("koheesio.spark.utils.common.spark_version", "9.9.42"):
assert get_spark_minor_version() == 9.9


def test_schema_struct_to_schema_str():
struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())])
struct_schema = StructType(
[StructField("a", StringType()), StructField("b", StringType())]
)
val = schema_struct_to_schema_str(struct_schema)
assert val == "a STRING,\nb STRING"
assert schema_struct_to_schema_str(None) == ""
@@ -40,12 +119,21 @@ def test_on_databricks(env_var_value, expected_result):
(3.3, "1.2.3", None), # PySpark 3.3, pandas < 2, should not raise an error
(3.4, "2.3.4", None), # PySpark not 3.3, pandas >= 2, should not raise an error
(3.3, "2.3.4", ImportError), # PySpark 3.3, pandas >= 2, should raise an error
(3.4, "1.2.3", ImportError), # PySpark not 3.3, pandas < 2, should raise an error
(
3.4,
"1.2.3",
ImportError,
), # PySpark not 3.3, pandas < 2, should raise an error
],
)
def test_import_pandas_based_on_pyspark_version(spark_version, pandas_version, expected_error):
def test_import_pandas_based_on_pyspark_version(
spark_version, pandas_version, expected_error
):
with (
patch("koheesio.spark.utils.common.get_spark_minor_version", return_value=spark_version),
patch(
"koheesio.spark.utils.common.get_spark_minor_version",
return_value=spark_version,
),
patch("pandas.__version__", new=pandas_version),
):
if expected_error:
