From 67c782d3eef9c65bbc39acb93678a2fe2360d679 Mon Sep 17 00:00:00 2001
From: Vinh Nguyen
Date: Fri, 13 Jan 2023 13:01:35 +0700
Subject: [PATCH 2/2] add spark_conf_string to thrift connection type

---
 .../unreleased/Features-20230113-133859.yaml |  7 ++++
 dbt/adapters/spark/connections.py            | 15 ++++++-
 tests/unit/test_adapter.py                   | 39 +++++++++++++++++++
 3 files changed, 60 insertions(+), 1 deletion(-)
 create mode 100644 .changes/unreleased/Features-20230113-133859.yaml

diff --git a/.changes/unreleased/Features-20230113-133859.yaml b/.changes/unreleased/Features-20230113-133859.yaml
new file mode 100644
index 000000000..21a38fe1f
--- /dev/null
+++ b/.changes/unreleased/Features-20230113-133859.yaml
@@ -0,0 +1,7 @@
+kind: Features
+body: Add extra Spark config to thrift connection
+time: 2023-01-13T13:38:59.257521+07:00
+custom:
+  Author: Vinh Nguyen
+  Issue: "590"
+  PR: "591"
diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py
index a606beb78..f42fd1e77 100644
--- a/dbt/adapters/spark/connections.py
+++ b/dbt/adapters/spark/connections.py
@@ -74,6 +74,7 @@ class SparkCredentials(Credentials):
     connect_timeout: int = 10
     use_ssl: bool = False
     server_side_parameters: Dict[str, Any] = field(default_factory=dict)
+    spark_conf_string: Optional[str] = None
     retry_all: bool = False
 
     @classmethod
@@ -370,6 +371,8 @@ def open(cls, connection):
         elif creds.method == SparkConnectionMethod.THRIFT:
             cls.validate_creds(creds, ["host", "port", "user", "schema"])
 
"schema"]) + configuration = None if not creds.spark_conf_string else parse_spark_config(creds.spark_conf_string) + if creds.use_ssl: transport = build_ssl_transport( host=creds.host, @@ -379,7 +382,7 @@ def open(cls, connection): kerberos_service_name=creds.kerberos_service_name, password=creds.password, ) - conn = hive.connect(thrift_transport=transport) + conn = hive.connect(thrift_transport=transport, configuration=configuration) else: conn = hive.connect( host=creds.host, @@ -388,6 +391,7 @@ def open(cls, connection): auth=creds.auth, kerberos_service_name=creds.kerberos_service_name, password=creds.password, + configuration=configuration ) # noqa handle = PyhiveConnectionWrapper(conn) elif creds.method == SparkConnectionMethod.ODBC: @@ -540,3 +544,12 @@ def _is_retryable_error(exc: Exception) -> str: return str(exc) else: return "" + + +def parse_spark_config(spark_conf_string: str) -> Dict: + try: + return dict(map(lambda x: x.split('='), spark_conf_string.split(';'))) + except: + raise dbt.exceptions.DbtProfileError( + f"invalid spark_conf_string: {spark_conf_string}. Parse error." + ) diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index d24bc8a2f..39c9b0f52 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -91,6 +91,22 @@ def _get_target_use_ssl_thrift(self, project): 'target': 'test' }) + def _get_target_thrift_with_spark_conf_string(self, project): + return config_from_parts_or_dicts(project, { + 'outputs': { + 'test': { + 'type': 'spark', + 'method': 'thrift', + 'schema': 'analytics', + 'host': 'myorg.sparkhost.com', + 'port': 10001, + 'user': 'dbt', + 'spark_conf_string': 'spark.executor.memory=1g;spark.executor.cores=1' + } + }, + 'target': 'test' + }) + def _get_target_odbc_cluster(self, project): return config_from_parts_or_dicts(project, { 'outputs': { @@ -211,6 +227,29 @@ def hive_thrift_connect(host, port, username, auth, kerberos_service_name, passw self.assertEqual(connection.credentials.schema, 'analytics') self.assertIsNone(connection.credentials.database) + def test_thrift_connection_with_spark_conf_string(self): + config = self._get_target_thrift_with_spark_conf_string(self.project_cfg) + adapter = SparkAdapter(config) + + def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password): + self.assertEqual(host, 'myorg.sparkhost.com') + self.assertEqual(port, 10001) + self.assertEqual(username, 'dbt') + self.assertIsNone(auth) + self.assertIsNone(kerberos_service_name) + self.assertIsNone(password) + + with mock.patch.object(hive, 'connect', new=hive_thrift_connect): + connection = adapter.acquire_connection('dummy') + connection.handle # trigger lazy-load + + self.assertEqual(connection.state, 'open') + self.assertIsNotNone(connection.handle) + self.assertEqual(connection.credentials.schema, 'analytics') + self.assertIsNone(connection.credentials.database) + spark_conf = {'spark.executor.memory': '1g', 'spark.executor.cores': '1'} + self.assertEqual(connection.credentials.configuration, spark_conf) + def test_odbc_cluster_connection(self): config = self._get_target_odbc_cluster(self.project_cfg) adapter = SparkAdapter(config)