From 0ee257d6650dbe812e032b952d7904b69de2212d Mon Sep 17 00:00:00 2001 From: Tomasz Witwicki Date: Wed, 18 Dec 2024 23:15:45 +0100 Subject: [PATCH] add extra jdbc params --- .../providers/apache/hive/hooks/hive.py | 32 +++++++- .../tests/apache/hive/hooks/test_hive.py | 76 ++++++++++++++++++- 2 files changed, 104 insertions(+), 4 deletions(-) diff --git a/providers/src/airflow/providers/apache/hive/hooks/hive.py b/providers/src/airflow/providers/apache/hive/hooks/hive.py index dde421e01e69b..3b6d0da51bb96 100644 --- a/providers/src/airflow/providers/apache/hive/hooks/hive.py +++ b/providers/src/airflow/providers/apache/hive/hooks/hive.py @@ -120,9 +120,9 @@ def __init__( @classmethod def get_connection_form_widgets(cls) -> dict[str, Any]: """Return connection widgets to add to Hive Client Wrapper connection form.""" - from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext - from wtforms import BooleanField, StringField + from wtforms import BooleanField, PasswordField, StringField return { "use_beeline": BooleanField(lazy_gettext("Use Beeline"), default=True), @@ -131,6 +131,15 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: lazy_gettext("Principal"), widget=BS3TextFieldWidget(), default="hive/_HOST@EXAMPLE.COM" ), "high_availability": BooleanField(lazy_gettext("High Availability mode"), default=False), + "ssl_trust_store": StringField( + lazy_gettext("SSL trust store"), widget=BS3TextFieldWidget(), default="" + ), + "ssl_trust_store_password": PasswordField( + lazy_gettext("SSL trust store password"), widget=BS3PasswordFieldWidget(), default="" + ), + "transport_mode": StringField( + lazy_gettext("Transport mode"), widget=BS3TextFieldWidget(), default="" + ), } @classmethod @@ -183,6 +192,25 @@ def _prepare_cli_cmd(self) -> list[Any]: elif self.auth: jdbc_url += ";auth=" + self.auth + ssl_trust_store = conn.extra_dejson.get("ssl_trust_store", "") + if ssl_trust_store: + if ";" in ssl_trust_store: + raise RuntimeError("The SSL trust store should not contain the ';' character") + jdbc_url += ";sslTrustStore=" + ssl_trust_store + + transport_mode = conn.extra_dejson.get("transport_mode", "") + if transport_mode: + if ";" in transport_mode: + raise RuntimeError("The transport mode should not contain the ';' character") + jdbc_url += ";transportMode=" + transport_mode + + ssl_trust_store_password = conn.extra_dejson.get("ssl_trust_store_password", "") + if ssl_trust_store_password: + from urllib.parse import quote + + ssl_trust_store_password = quote(ssl_trust_store_password, safe="") + jdbc_url += f";trustStorePassword={ssl_trust_store_password}" + jdbc_url = f'"{jdbc_url}"' cmd_extra += ["-u", jdbc_url] diff --git a/providers/tests/apache/hive/hooks/test_hive.py b/providers/tests/apache/hive/hooks/test_hive.py index c7913a9406c8b..05caf19c21885 100644 --- a/providers/tests/apache/hive/hooks/test_hive.py +++ b/providers/tests/apache/hive/hooks/test_hive.py @@ -57,10 +57,57 @@ def __init__(self): @pytest.mark.db_test class TestHiveCliHook: + @pytest.mark.parametrize( + "extra_dejson, expected_jdbc_url", + [ + ( + { + "ssl_trust_store": "", + "transport_mode": "", + "ssl_trust_store_password": "", + }, + "jdbc:hive2://localhost:10000/default", + ), + ( + { + "ssl_trust_store": "/path/to/truststore", + "transport_mode": "http", + "ssl_trust_store_password": "@password123;", + }, + "jdbc:hive2://localhost:10000/default;sslTrustStore=/path/to/truststore;transportMode=http;trustStorePassword=%40password123%3B", + ), + ( + {"ssl_trust_store": "", "transport_mode": "http", "ssl_trust_store_password": ""}, + "jdbc:hive2://localhost:10000/default;transportMode=http", + ), + ( + { + "ssl_trust_store": "/path/to/truststore", + "transport_mode": "", + "ssl_trust_store_password": "", + }, + "jdbc:hive2://localhost:10000/default;sslTrustStore=/path/to/truststore", + ), + ( + { + "ssl_trust_store": "", + "transport_mode": "", + "ssl_trust_store_password": "!@#$%^&*()_+-=,<.>/?[{]}:'", + }, + "jdbc:hive2://localhost:10000/default;trustStorePassword=%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D%2C%3C.%3E%2F%3F%5B%7B%5D%7D%3A%27", + ), + ], + ) @mock.patch("tempfile.tempdir", "/tmp/") @mock.patch("tempfile._RandomNameSequence.__next__") @mock.patch("subprocess.Popen") - def test_run_cli(self, mock_popen, mock_temp_dir): + def test_run_cli( + self, + mock_popen, + mock_temp_dir, + extra_dejson, + expected_jdbc_url, + ): mock_subprocess = MockSubProcess() mock_popen.return_value = mock_subprocess mock_temp_dir.return_value = "test_run_cli" @@ -79,12 +126,14 @@ def test_run_cli(self, mock_popen, mock_temp_dir): }, ): hook = MockHiveCliHook() + hook.conn.extra_dejson = extra_dejson + hook.run_cli("SHOW DATABASES") date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" hive_cmd = [ "beeline", "-u", - '"jdbc:hive2://localhost:10000/default"', + f'"{expected_jdbc_url}"', "-hiveconf", "airflow.ctx.dag_id=test_dag_id", "-hiveconf", @@ -205,6 +254,7 @@ def test_run_cli_with_hive_conf(self, mock_popen): }, ): hook = MockHiveCliHook() + hook.conn.extra_dejson = {} mock_popen.return_value = MockSubProcess(output=mock_output) output = hook.run_cli(hql=hql, hive_conf={"key": "value"}) @@ -969,3 +1019,25 @@ def test_high_availability(self, extra_dejson, expected_keys): assert expected_keys in result[2] else: assert expected_keys not in result[2] + + def test_get_wrong_ssl_trust_store(self): + hook = MockHiveCliHook() + returner = mock.MagicMock() + returner.extra_dejson = {"ssl_trust_store": "SSL trust store with ; semicolon"} + hook.use_beeline = True + hook.conn = returner + + # Run + with pytest.raises(RuntimeError, match="The SSL trust store should not contain the ';' character"): + hook._prepare_cli_cmd() + + def test_get_wrong_transport_mode(self): + hook = MockHiveCliHook() + returner = mock.MagicMock() + returner.extra_dejson = {"transport_mode": "Transport mode with ; semicolon"} + hook.use_beeline = True + hook.conn = returner + + # Run + with pytest.raises(RuntimeError, match="The transport mode should not contain the ';' character"): + hook._prepare_cli_cmd()