Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add spark_conf_string to thrift connection type #591

Closed
wants to merge 7 commits into from
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20230113-133859.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Add extras Spark config to thrift connection
time: 2023-01-13T13:38:59.257521+07:00
custom:
Author: Vinh Nguyen
Issue: "590"
PR: "591"
15 changes: 14 additions & 1 deletion dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class SparkCredentials(Credentials):
connect_timeout: int = 10
use_ssl: bool = False
server_side_parameters: Dict[str, Any] = field(default_factory=dict)
spark_conf_string: Optional[str] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does this need to be a string? Can this be a dict? Doing so would allow us to escape special characters in the configuration

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be clear, my purpose is to be able to change the Spark configuration with profile.yml when connecting via Thrift connection. These configs will be adjusted according to the DBT models.

To avoid change detection on profile.yml and do a full re-parse(docs) . The simplest way is to pass these configs from environment variables through the jinja format(ref). spark_conf_string: "{{ env_var('SPARK_CONFIG_STRING') }}"

With the server_side_parameters approach, its data type is dict, it is difficult to bypass the re-parse limitation if we want to use the environment variable mentioned above.

With these specific requirements, it led me to create another parameter with the String data type.

Hi @colin-rogers-dbt ,
As I mentioned in the previous comment, the main reason is to prevent full-reparse when we need to change the spark configuration via dbt profile sequentially

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my bad, should have made the connection.

retry_all: bool = False

@classmethod
Expand Down Expand Up @@ -370,6 +371,8 @@ def open(cls, connection):
elif creds.method == SparkConnectionMethod.THRIFT:
cls.validate_creds(creds, ["host", "port", "user", "schema"])

configuration = None if not creds.spark_conf_string else parse_spark_config(creds.spark_conf_string)

if creds.use_ssl:
transport = build_ssl_transport(
host=creds.host,
Expand All @@ -379,7 +382,7 @@ def open(cls, connection):
kerberos_service_name=creds.kerberos_service_name,
password=creds.password,
)
conn = hive.connect(thrift_transport=transport)
conn = hive.connect(thrift_transport=transport, configuration=configuration)
else:
conn = hive.connect(
host=creds.host,
Expand All @@ -388,6 +391,7 @@ def open(cls, connection):
auth=creds.auth,
kerberos_service_name=creds.kerberos_service_name,
password=creds.password,
configuration=configuration
) # noqa
handle = PyhiveConnectionWrapper(conn)
elif creds.method == SparkConnectionMethod.ODBC:
Expand Down Expand Up @@ -540,3 +544,12 @@ def _is_retryable_error(exc: Exception) -> str:
return str(exc)
else:
return ""


def parse_spark_config(spark_conf_string: str) -> Dict:
try:
return dict(map(lambda x: x.split('='), spark_conf_string.split(';')))
except:
raise dbt.exceptions.DbtProfileError(
f"invalid spark_conf_string: {spark_conf_string}. Parse error."
)
39 changes: 39 additions & 0 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ def _get_target_use_ssl_thrift(self, project):
'target': 'test'
})

def _get_target_thrift_with_spark_conf_string(self, project):
return config_from_parts_or_dicts(project, {
'outputs': {
'test': {
'type': 'spark',
'method': 'thrift',
'schema': 'analytics',
'host': 'myorg.sparkhost.com',
'port': 10001,
'user': 'dbt',
'spark_conf_string': 'spark.executor.memory=1g;spark.executor.cores=1'
}
},
'target': 'test'
})

def _get_target_odbc_cluster(self, project):
return config_from_parts_or_dicts(project, {
'outputs': {
Expand Down Expand Up @@ -211,6 +227,29 @@ def hive_thrift_connect(host, port, username, auth, kerberos_service_name, passw
self.assertEqual(connection.credentials.schema, 'analytics')
self.assertIsNone(connection.credentials.database)

def test_thrift_connection_with_spark_conf_string(self):
config = self._get_target_thrift_with_spark_conf_string(self.project_cfg)
adapter = SparkAdapter(config)

def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password):
self.assertEqual(host, 'myorg.sparkhost.com')
self.assertEqual(port, 10001)
self.assertEqual(username, 'dbt')
self.assertIsNone(auth)
self.assertIsNone(kerberos_service_name)
self.assertIsNone(password)

with mock.patch.object(hive, 'connect', new=hive_thrift_connect):
connection = adapter.acquire_connection('dummy')
connection.handle # trigger lazy-load

self.assertEqual(connection.state, 'open')
self.assertIsNotNone(connection.handle)
self.assertEqual(connection.credentials.schema, 'analytics')
self.assertIsNone(connection.credentials.database)
spark_conf = {'spark.executor.memory': '1g', 'spark.executor.cores': '1'}
self.assertEqual(connection.credentials.configuration, spark_conf)

def test_odbc_cluster_connection(self):
config = self._get_target_odbc_cluster(self.project_cfg)
adapter = SparkAdapter(config)
Expand Down