diff --git a/.changes/unreleased/Features-20230113-133859.yaml b/.changes/unreleased/Features-20230113-133859.yaml
new file mode 100644
index 000000000..21a38fe1f
--- /dev/null
+++ b/.changes/unreleased/Features-20230113-133859.yaml
@@ -0,0 +1,7 @@
+kind: Features
+body: Add extra Spark config to thrift connection
+time: 2023-01-13T13:38:59.257521+07:00
+custom:
+  Author: Vinh Nguyen
+  Issue: "590"
+  PR: "591"
diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py
index 2a7f8188d..02b62090d 100644
--- a/dbt/adapters/spark/connections.py
+++ b/dbt/adapters/spark/connections.py
@@ -74,6 +74,7 @@ class SparkCredentials(Credentials):
     connect_timeout: int = 10
     use_ssl: bool = False
     server_side_parameters: Dict[str, Any] = field(default_factory=dict)
+    spark_conf_string: Optional[str] = None
     retry_all: bool = False
 
     @classmethod
@@ -381,6 +382,8 @@ def open(cls, connection: Connection) -> Connection:
             elif creds.method == SparkConnectionMethod.THRIFT:
                 cls.validate_creds(creds, ["host", "port", "user", "schema"])
 
+                configuration = None if not creds.spark_conf_string else parse_spark_config(creds.spark_conf_string)
+
                 if creds.use_ssl:
                     transport = build_ssl_transport(
                         host=creds.host,
@@ -390,7 +393,7 @@ def open(cls, connection: Connection) -> Connection:
                         kerberos_service_name=creds.kerberos_service_name,
                         password=creds.password,
                     )
-                    conn = hive.connect(thrift_transport=transport)
+                    conn = hive.connect(thrift_transport=transport, configuration=configuration)
                 else:
                     conn = hive.connect(
                         host=creds.host,
@@ -399,6 +402,7 @@ def open(cls, connection: Connection) -> Connection:
                         auth=creds.auth,
                         kerberos_service_name=creds.kerberos_service_name,
                         password=creds.password,
+                        configuration=configuration
                     )  # noqa
                 handle = PyhiveConnectionWrapper(conn)
             elif creds.method == SparkConnectionMethod.ODBC:
@@ -571,3 +575,12 @@ def _is_retryable_error(exc: Exception) -> str:
         return str(exc)
     else:
         return ""
+
+
+def parse_spark_config(spark_conf_string: str) -> Dict:
+    # Parse "key1=value1;key2=value2" into a dict for hive.connect(configuration=...)
+    try:
+        return dict(map(lambda x: x.split('='), spark_conf_string.split(';')))
+    except ValueError:
+        raise dbt.exceptions.DbtProfileError(
+            f"invalid spark_conf_string: {spark_conf_string}. Parse error."
+        )
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 3c7fccd35..bb782e94b 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -102,6 +102,22 @@ def _get_target_use_ssl_thrift(self, project):
             },
         )
 
+    def _get_target_thrift_with_spark_conf_string(self, project):
+        return config_from_parts_or_dicts(project, {
+            'outputs': {
+                'test': {
+                    'type': 'spark',
+                    'method': 'thrift',
+                    'schema': 'analytics',
+                    'host': 'myorg.sparkhost.com',
+                    'port': 10001,
+                    'user': 'dbt',
+                    'spark_conf_string': 'spark.executor.memory=1g;spark.executor.cores=1'
+                }
+            },
+            'target': 'test'
+        })
+
     def _get_target_odbc_cluster(self, project):
         return config_from_parts_or_dicts(
             project,
@@ -228,6 +244,29 @@ def hive_thrift_connect(host, port, username, auth, kerberos_service_name, passw
         self.assertEqual(connection.credentials.schema, "analytics")
         self.assertIsNone(connection.credentials.database)
 
+    def test_thrift_connection_with_spark_conf_string(self):
+        config = self._get_target_thrift_with_spark_conf_string(self.project_cfg)
+        adapter = SparkAdapter(config)
+
+        def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password, configuration):
+            self.assertEqual(host, 'myorg.sparkhost.com')
+            self.assertEqual(port, 10001)
+            self.assertEqual(username, 'dbt')
+            self.assertIsNone(auth)
+            self.assertIsNone(kerberos_service_name)
+            self.assertIsNone(password)
+            spark_conf = {'spark.executor.memory': '1g', 'spark.executor.cores': '1'}
+            self.assertEqual(configuration, spark_conf)
+
+        with mock.patch.object(hive, 'connect', new=hive_thrift_connect):
+            connection = adapter.acquire_connection('dummy')
+            connection.handle  # trigger lazy-load
+
+        self.assertEqual(connection.state, 'open')
+        self.assertIsNotNone(connection.handle)
+        self.assertEqual(connection.credentials.schema, 'analytics')
+        self.assertIsNone(connection.credentials.database)
+
     def test_odbc_cluster_connection(self):
         config = self._get_target_odbc_cluster(self.project_cfg)
         adapter = SparkAdapter(config)
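
For reference, a minimal standalone sketch (not part of the diff) of what the new spark_conf_string setting does: the profile value is a semicolon-separated list of key=value pairs, and parse_spark_config turns it into the dict handed to hive.connect as configuration. The property names below are taken from the unit test and are only illustrative.

# Sketch of the parsing behaviour added in connections.py above.
spark_conf_string = "spark.executor.memory=1g;spark.executor.cores=1"

# Split on ';' into "key=value" pairs, then on '=' into (key, value) items,
# giving the configuration dict passed to hive.connect(configuration=...).
configuration = dict(pair.split("=") for pair in spark_conf_string.split(";"))

print(configuration)
# {'spark.executor.memory': '1g', 'spark.executor.cores': '1'}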