Skip to content

Commit

Permalink
add extra jdbc params
Browse files (browse the repository at this point in the history)
  • Loading branch information
tomwit-nx committed Dec 22, 2024
1 parent a8c4871 commit 0ee257d
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 4 deletions.
32 changes: 30 additions & 2 deletions providers/src/airflow/providers/apache/hive/hooks/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def __init__(
@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Return connection widgets to add to Hive Client Wrapper connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import BooleanField, StringField
from wtforms import BooleanField, PasswordField, StringField

return {
"use_beeline": BooleanField(lazy_gettext("Use Beeline"), default=True),
Expand All @@ -131,6 +131,15 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
lazy_gettext("Principal"), widget=BS3TextFieldWidget(), default="hive/[email protected]"
),
"high_availability": BooleanField(lazy_gettext("High Availability mode"), default=False),
"ssl_trust_store": StringField(
lazy_gettext("SSL trust store"), widget=BS3TextFieldWidget(), default=""
),
"ssl_trust_store_password": PasswordField(
lazy_gettext("SSL trust store password"), widget=BS3PasswordFieldWidget(), default=""
),
"transport_mode": StringField(
lazy_gettext("Transport mode"), widget=BS3TextFieldWidget(), default=""
),
}

@classmethod
Expand Down Expand Up @@ -183,6 +192,25 @@ def _prepare_cli_cmd(self) -> list[Any]:
elif self.auth:
jdbc_url += ";auth=" + self.auth

ssl_trust_store = conn.extra_dejson.get("ssl_trust_store", "")
if ssl_trust_store:
if ";" in ssl_trust_store:
raise RuntimeError("The SSL trust store should not contain the ';' character")
jdbc_url += ";sslTrustStore=" + ssl_trust_store

transport_mode = conn.extra_dejson.get("transport_mode", "")
if transport_mode:
if ";" in transport_mode:
raise RuntimeError("The transport mode should not contain the ';' character")
jdbc_url += ";transportMode=" + transport_mode

ssl_trust_store_password = conn.extra_dejson.get("ssl_trust_store_password", "")
if ssl_trust_store_password:
from urllib.parse import quote

ssl_trust_store_password = quote(ssl_trust_store_password, safe="")
jdbc_url += f";trustStorePassword={ssl_trust_store_password}"

jdbc_url = f'"{jdbc_url}"'

cmd_extra += ["-u", jdbc_url]
Expand Down
76 changes: 74 additions & 2 deletions providers/tests/apache/hive/hooks/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,57 @@ def __init__(self):

@pytest.mark.db_test
class TestHiveCliHook:
@pytest.mark.parametrize(
"extra_dejson, expected_jdbc_url",
[
(
{
"ssl_trust_store": "",
"transport_mode": "",
"ssl_trust_store_password": "",
},
"jdbc:hive2://localhost:10000/default",
),
(
{
"ssl_trust_store": "/path/to/truststore",
"transport_mode": "http",
"ssl_trust_store_password": "@password123;",
},
"jdbc:hive2://localhost:10000/default;sslTrustStore=/path/to/truststore;transportMode=http;trustStorePassword=%40password123%3B",
),
(
{"ssl_trust_store": "", "transport_mode": "http", "ssl_trust_store_password": ""},
"jdbc:hive2://localhost:10000/default;transportMode=http",
),
(
{
"ssl_trust_store": "/path/to/truststore",
"transport_mode": "",
"ssl_trust_store_password": "",
},
"jdbc:hive2://localhost:10000/default;sslTrustStore=/path/to/truststore",
),
(
{
"ssl_trust_store": "",
"transport_mode": "",
"ssl_trust_store_password": "!@#$%^&*()_+-=,<.>/?[{]}:'",
},
"jdbc:hive2://localhost:10000/default;trustStorePassword=%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D%2C%3C.%3E%2F%3F%5B%7B%5D%7D%3A%27",
),
],
)
@mock.patch("tempfile.tempdir", "/tmp/")
@mock.patch("tempfile._RandomNameSequence.__next__")
@mock.patch("subprocess.Popen")
def test_run_cli(self, mock_popen, mock_temp_dir):
def test_run_cli(
self,
mock_popen,
mock_temp_dir,
extra_dejson,
expected_jdbc_url,
):
mock_subprocess = MockSubProcess()
mock_popen.return_value = mock_subprocess
mock_temp_dir.return_value = "test_run_cli"
Expand All @@ -79,12 +126,14 @@ def test_run_cli(self, mock_popen, mock_temp_dir):
},
):
hook = MockHiveCliHook()
hook.conn.extra_dejson = extra_dejson

hook.run_cli("SHOW DATABASES")
date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date"
hive_cmd = [
"beeline",
"-u",
'"jdbc:hive2://localhost:10000/default"',
f'"{expected_jdbc_url}"',
"-hiveconf",
"airflow.ctx.dag_id=test_dag_id",
"-hiveconf",
Expand Down Expand Up @@ -205,6 +254,7 @@ def test_run_cli_with_hive_conf(self, mock_popen):
},
):
hook = MockHiveCliHook()
hook.conn.extra_dejson = {}
mock_popen.return_value = MockSubProcess(output=mock_output)

output = hook.run_cli(hql=hql, hive_conf={"key": "value"})
Expand Down Expand Up @@ -969,3 +1019,25 @@ def test_high_availability(self, extra_dejson, expected_keys):
assert expected_keys in result[2]
else:
assert expected_keys not in result[2]

def test_get_wrong_ssl_trust_store(self):
    """Check that an ``ssl_trust_store`` extra containing ';' is rejected.

    The hook is put in beeline mode and given a mocked connection whose
    extras carry a trust store value with a semicolon; building the CLI
    command is then expected to raise ``RuntimeError``.
    """
    hook = MockHiveCliHook()
    # MagicMock keyword arguments become attributes on the created mock.
    fake_conn = mock.MagicMock(
        extra_dejson={"ssl_trust_store": "SSL trust store with ; semicolon"},
    )
    hook.use_beeline = True
    hook.conn = fake_conn

    expected_error = "The SSL trust store should not contain the ';' character"
    with pytest.raises(RuntimeError, match=expected_error):
        hook._prepare_cli_cmd()

def test_get_wrong_transport_mode(self):
    """Check that a ``transport_mode`` extra containing ';' is rejected.

    The hook is put in beeline mode and given a mocked connection whose
    extras carry a transport mode value with a semicolon; building the CLI
    command is then expected to raise ``RuntimeError``.
    """
    hook = MockHiveCliHook()
    # MagicMock keyword arguments become attributes on the created mock.
    fake_conn = mock.MagicMock(
        extra_dejson={"transport_mode": "Transport mode with ; semicolon"},
    )
    hook.use_beeline = True
    hook.conn = fake_conn

    expected_error = "The transport mode should not contain the ';' character"
    with pytest.raises(RuntimeError, match=expected_error):
        hook._prepare_cli_cmd()

0 comments on commit 0ee257d

Please sign in to comment.