feat(airflow): create SparkSqlOnK8SOperator to execute Spark SQL commands (#109)
hussein-awala authored Dec 3, 2024
1 parent 8ae31ca commit d178937
Showing 4 changed files with 183 additions and 0 deletions.
49 changes: 49 additions & 0 deletions spark_on_k8s/airflow/operators.py
@@ -1,10 +1,12 @@
from __future__ import annotations

import contextlib
from pathlib import Path
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.template import literal
from spark_on_k8s.airflow.operator_links import SparkOnK8SOperatorLink
from spark_on_k8s.airflow.triggers import SparkOnK8STrigger
from spark_on_k8s.k8s.sync_client import KubernetesClientManager
@@ -501,3 +503,50 @@ def on_kill(self) -> None:
raise AirflowException(f"Invalid on_kill_action: {self.on_kill_action}")

self._persist_spark_history_ui_link(get_current_context())


class SparkSqlOnK8SOperator(SparkOnK8SOperator):
"""Execute Spark SQL commands on Kubernetes.
This operator passes the SQL commands to a Python script that runs them on Spark,
that's why it requires a PySpark docker image.
Args:
sql (str): SQL commands to execute.
**kwargs: Other keyword arguments for SparkOnK8SOperator.
"""

template_fields = (
"sql",
*SparkOnK8SOperator.template_fields,
)
template_ext = (".sql",)
template_fields_renderers = {
"sql": "sql",
}

def __init__(self, sql: str, **kwargs):
super().__init__(
app_path="/configmap/spark_sql.py",
app_arguments=literal(["/configmap/queries.sql"]),
**kwargs,
)
self.sql = sql

def execute(self, context: Context):
self.driver_ephemeral_configmaps_volumes = [
{
"mount_path": "/configmap",
"sources": [
{
"name": "spark_sql.py",
"text_path": f"{Path(__file__).parent}/scripts/spark_sql.py",
},
{
"name": "queries.sql",
"text": self.sql,
},
],
}
]
return super().execute(context)
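
For orientation, a minimal usage sketch of the new operator inside a DAG (the DAG id, task id, and image name here are placeholders, not values from this commit):

from airflow import DAG
from spark_on_k8s.airflow.operators import SparkSqlOnK8SOperator

with DAG(dag_id="spark_sql_example", schedule=None) as dag:
    # sql is a template field, so Jinja expressions are rendered; thanks to
    # template_ext, a value ending in .sql is treated as a template file path.
    run_sql = SparkSqlOnK8SOperator(
        task_id="run_sql",
        image="my-pyspark-image",  # placeholder: any image with PySpark installed
        sql="SELECT 1; SELECT 2",  # two commands, split on semicolons
        app_waiter="no_wait",
    )

At execute time the operator mounts both the runner script and the rendered SQL into the driver pod via an ephemeral configmap volume; because app_arguments is wrapped in literal(), Airflow does not try to re-template the /configmap/queries.sql argument during rendering.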
50 changes: 50 additions & 0 deletions spark_on_k8s/airflow/scripts/spark_sql.py
@@ -0,0 +1,50 @@
from __future__ import annotations

import logging

from pyspark.sql import SparkSession

logger = logging.getLogger("SparkSqlOnK8s")


def read_sql_file(file_path):
"""
Reads an SQL file and splits commands by semicolon ";".
"""
with open(file_path) as f:
file_content = f.read()
# Split commands by semicolon while ignoring empty lines
return [command.strip() for command in file_content.split(";") if command.strip()]


def execute_sql_commands(spark_session: SparkSession, commands):
"""
Executes a list of SQL commands on a Spark session.
"""
for command in commands:
try:
logger.info(f"Executing: {command}")
spark_session.sql(command).show(truncate=False)
except Exception:
logger.exception(f"Error executing SQL command: {command}")


if __name__ == "__main__":
import sys

if len(sys.argv) != 2:
logger.error("Usage: spark-submit run_sql_commands.py <path_to_sql_file>")
sys.exit(1)

sql_file_path = sys.argv[1]

_spark_session = SparkSession.builder.getOrCreate()

# Read SQL commands from file
sql_commands = read_sql_file(sql_file_path)

# Execute each SQL command
execute_sql_commands(_spark_session, sql_commands)

# Stop SparkSession
_spark_session.stop()
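
As a quick illustration of the splitting logic above (a standalone sketch, not part of the commit):

# how read_sql_file's semicolon splitting behaves on a sample file body
content = "SELECT 1;\n\nSELECT 2;  \n"
commands = [c.strip() for c in content.split(";") if c.strip()]
assert commands == ["SELECT 1", "SELECT 2"]
# caveat: a semicolon inside a string literal, e.g. SELECT ';', would be
# split into two invalid fragments by this simple approach.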
71 changes: 71 additions & 0 deletions tests/airflow/test_operators.py
@@ -306,3 +306,74 @@ def test_startup_timeout(
poll_interval=10,
startup_timeout=10,
)


@pytest.mark.skipif(PYTHON_313_OR_ABOVE, reason="Python 3.13+ is not supported by Airflow")
class TestSparkSqlOnK8SOperator:
@mock.patch("spark_on_k8s.client.SparkOnK8S.submit_app")
def test_rendering_templates(self, mock_submit_app, root_path):
from spark_on_k8s.airflow.operators import SparkSqlOnK8SOperator

spark_app_task = SparkSqlOnK8SOperator(
task_id="spark_sql_application",
image="{{ template_image }}",
sql="SELECT * FROM {{ template_table }}",
app_waiter="no_wait",
)
spark_app_task.render_template_fields(
context={
"template_image": "pyspark-job",
"template_table": "test_table",
},
)
spark_app_task.execute(
{
"ti": mock.MagicMock(
xcom_pull=mock.MagicMock(return_value=None),
)
}
)
mock_submit_app.assert_called_once_with(
image="pyspark-job",
app_path="/configmap/spark_sql.py",
driver_ephemeral_configmaps_volumes=[
{
"mount_path": "/configmap",
"sources": [
{
"name": "spark_sql.py",
"text_path": f"{root_path}/spark_on_k8s/airflow/scripts/spark_sql.py",
},
{
"name": "queries.sql",
"text": "SELECT * FROM test_table",
},
],
}
],
namespace="default",
service_account="spark",
app_arguments=["/configmap/queries.sql"],
app_waiter="no_wait",
image_pull_policy="IfNotPresent",
app_name=None,
spark_conf=None,
class_name=None,
packages=None,
ui_reverse_proxy=False,
driver_resources=None,
executor_resources=None,
executor_instances=None,
secret_values=None,
volumes=None,
driver_volume_mounts=None,
executor_volume_mounts=None,
driver_node_selector=None,
executor_node_selector=None,
driver_labels=None,
executor_labels=None,
driver_annotations=None,
executor_annotations=None,
driver_tolerations=None,
executor_pod_template_path=None,
)
13 changes: 13 additions & 0 deletions tests/conftest.py
@@ -1,5 +1,18 @@
from __future__ import annotations

import subprocess
import sys

import pytest

PYTHON_313_OR_ABOVE = sys.version_info.major == 3 and sys.version_info.minor >= 13


@pytest.fixture(scope="session")
def root_path() -> str:
return (
subprocess.Popen(["git", "rev-parse", "--show-toplevel"], stdout=subprocess.PIPE)
.communicate()[0]
.rstrip()
.decode("utf-8")
)
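
A side note on the fixture: subprocess.check_output would express the same lookup more compactly and also raise on a non-zero git exit code, e.g. (sketch):

import subprocess

def git_root() -> str:
    # equivalent to the Popen/communicate chain above, with error checking
    return subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).decode("utf-8").rstrip()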
