From 6141f77ac4d3ab56dd986f2aa356d4bf6dcc8df0 Mon Sep 17 00:00:00 2001 From: nilan3 Date: Tue, 27 Aug 2024 12:30:14 +0100 Subject: [PATCH 01/10] add support for extra odbc connection properties --- .github/workflows/integration.yml | 1 + .github/workflows/release-internal.yml | 1 + .github/workflows/release-prep.yml | 1 + dagger/run_dbt_spark_tests.py | 2 +- dbt/adapters/spark/connections.py | 56 ++++++++++++------- tests/conftest.py | 15 +++++ .../test_incremental_on_schema_change.py | 4 +- .../test_incremental_strategies.py | 6 +- tests/functional/adapter/test_constraints.py | 6 +- tests/functional/adapter/test_python_model.py | 8 +-- .../adapter/test_store_test_failures.py | 2 +- 11 files changed, 67 insertions(+), 35 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 699d45391..35bd9cae0 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -76,6 +76,7 @@ jobs: test: - "apache_spark" - "spark_session" + - "spark_http_odbc" - "databricks_sql_endpoint" - "databricks_cluster" - "databricks_http_cluster" diff --git a/.github/workflows/release-internal.yml b/.github/workflows/release-internal.yml index d4e7a3c93..1a5090312 100644 --- a/.github/workflows/release-internal.yml +++ b/.github/workflows/release-internal.yml @@ -79,6 +79,7 @@ jobs: test: - "apache_spark" - "spark_session" + - "spark_http_odbc" - "databricks_sql_endpoint" - "databricks_cluster" - "databricks_http_cluster" diff --git a/.github/workflows/release-prep.yml b/.github/workflows/release-prep.yml index 9cb2c3e19..9937463d3 100644 --- a/.github/workflows/release-prep.yml +++ b/.github/workflows/release-prep.yml @@ -482,6 +482,7 @@ jobs: test: - "apache_spark" - "spark_session" + - "spark_http_odbc" - "databricks_sql_endpoint" - "databricks_cluster" - "databricks_http_cluster" diff --git a/dagger/run_dbt_spark_tests.py b/dagger/run_dbt_spark_tests.py index 15f9cf2c2..67fa56587 100644 --- a/dagger/run_dbt_spark_tests.py +++ b/dagger/run_dbt_spark_tests.py @@ -137,7 +137,7 @@ async def test_spark(test_args): spark_ctr, spark_host = get_spark_container(client) tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr) - elif test_profile in ["databricks_cluster", "databricks_sql_endpoint"]: + elif test_profile in ["databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc"]: tst_container = ( tst_container.with_workdir("/") .with_exec(["./scripts/configure_odbc.sh"]) diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 0405eaf5b..f5dcbe6a1 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -78,6 +78,7 @@ class SparkCredentials(Credentials): auth: Optional[str] = None kerberos_service_name: Optional[str] = None organization: str = "0" + connection_str_extra: Optional[str] = None connect_retries: int = 0 connect_timeout: int = 10 use_ssl: bool = False @@ -154,7 +155,7 @@ def __post_init__(self) -> None: f"ImportError({e.msg})" ) from e - if self.method != SparkConnectionMethod.SESSION: + if self.method != SparkConnectionMethod.SESSION and self.host is not None: self.host = self.host.rstrip("/") self.server_side_parameters = { @@ -483,38 +484,51 @@ def open(cls, connection: Connection) -> Connection: http_path = cls.SPARK_SQL_ENDPOINT_HTTP_PATH.format( endpoint=creds.endpoint ) + elif creds.connection_str_extra is not None: + required_fields = ["driver", "host", "port", "connection_str_extra"] else: raise DbtConfigError( - "Either 
`cluster` or `endpoint` must set when" + "Either `cluster`, `endpoint`, `connection_str_extra` must set when" " using the odbc method to connect to Spark" ) cls.validate_creds(creds, required_fields) - dbt_spark_version = __version__.version user_agent_entry = ( f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)" # noqa ) - # http://simba.wpengine.com/products/Spark/doc/ODBC_InstallGuide/unix/content/odbc/hi/configuring/serverside.htm ssp = {f"SSP_{k}": f"{{{v}}}" for k, v in creds.server_side_parameters.items()} - - # https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm - connection_str = _build_odbc_connnection_string( - DRIVER=creds.driver, - HOST=creds.host, - PORT=creds.port, - UID="token", - PWD=creds.token, - HTTPPath=http_path, - AuthMech=3, - SparkServerType=3, - ThriftTransport=2, - SSL=1, - UserAgentEntry=user_agent_entry, - LCaseSspKeyName=0 if ssp else 1, - **ssp, - ) + if creds.token is not None: + # https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm + connection_str = _build_odbc_connnection_string( + DRIVER=creds.driver, + HOST=creds.host, + PORT=creds.port, + UID="token", + PWD=creds.token, + HTTPPath=http_path, + AuthMech=3, + SparkServerType=3, + ThriftTransport=2, + SSL=1, + UserAgentEntry=user_agent_entry, + LCaseSspKeyName=0 if ssp else 1, + **ssp, + ) + else: + connection_str = _build_odbc_connnection_string( + DRIVER=creds.driver, + HOST=creds.host, + PORT=creds.port, + ThriftTransport=2, + SSL=1, + UserAgentEntry=user_agent_entry, + LCaseSspKeyName=0 if ssp else 1, + **ssp, + ) + if creds.connection_str_extra is not None: + connection_str = connection_str + ";" + creds.connection_str_extra conn = pyodbc.connect(connection_str, autocommit=True) handle = PyodbcConnectionWrapper(conn) diff --git a/tests/conftest.py b/tests/conftest.py index efba41a5f..fe4174f5c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,8 @@ def dbt_profile_target(request): target = databricks_http_cluster_target() elif profile_type == "spark_session": target = spark_session_target() + elif profile_type == "spark_http_odbc": + target = spark_http_odbc_target() else: raise ValueError(f"Invalid profile type '{profile_type}'") return target @@ -101,6 +103,19 @@ def spark_session_target(): "method": "session", } +def spark_http_odbc_target(): + return { + "type": "spark", + "method": "odbc", + "host": os.getenv("DBT_DATABRICKS_HOST_NAME"), + "port": 443, + "driver": os.getenv("ODBC_DRIVER"), + "connection_str_extra": f'UID=token;PWD={os.getenv("DBT_DATABRICKS_TOKEN")};HTTPPath=/sql/1.0/endpoints/{os.getenv("DBT_DATABRICKS_ENDPOINT")};AuthMech=3;SparkServerType=3', + "connect_retries": 3, + "connect_timeout": 5, + "retry_all": True, + } + @pytest.fixture(autouse=True) def skip_by_profile_type(request): diff --git a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py index 478329668..7e05290ad 100644 --- a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py +++ b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py @@ -21,7 +21,7 @@ def test_run_incremental_fail_on_schema_change(self, project): assert "Compilation Error" in results_two[1].message -@pytest.mark.skip_profile("databricks_sql_endpoint") +@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_http_odbc") class TestAppendOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail): 
@pytest.fixture(scope="class") def project_config_update(self): @@ -32,7 +32,7 @@ def project_config_update(self): } -@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session") +@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_http_odbc" class TestInsertOverwriteOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail): @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py b/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py index b05fcb279..eb447ee4f 100644 --- a/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py +++ b/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py @@ -55,7 +55,7 @@ def run_and_test(self, project): check_relations_equal(project.adapter, ["default_append", "expected_append"]) @pytest.mark.skip_profile( - "databricks_http_cluster", "databricks_sql_endpoint", "spark_session" + "databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc" ) def test_default_append(self, project): self.run_and_test(project) @@ -77,7 +77,7 @@ def run_and_test(self, project): check_relations_equal(project.adapter, ["insert_overwrite_partitions", "expected_upsert"]) @pytest.mark.skip_profile( - "databricks_http_cluster", "databricks_sql_endpoint", "spark_session" + "databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc" ) def test_insert_overwrite(self, project): self.run_and_test(project) @@ -103,7 +103,7 @@ def run_and_test(self, project): check_relations_equal(project.adapter, ["merge_update_columns", "expected_partial_upsert"]) @pytest.mark.skip_profile( - "apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session" + "apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc" ) def test_delta_strategies(self, project): self.run_and_test(project) diff --git a/tests/functional/adapter/test_constraints.py b/tests/functional/adapter/test_constraints.py index e35a13a64..0b5b80e6e 100644 --- a/tests/functional/adapter/test_constraints.py +++ b/tests/functional/adapter/test_constraints.py @@ -183,7 +183,7 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster" + "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc" ) class TestSparkTableConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseTableConstraintsColumnsEqual @@ -198,7 +198,7 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster" + "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc" ) class TestSparkViewConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseViewConstraintsColumnsEqual @@ -213,7 +213,7 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster" + "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc" ) class TestSparkIncrementalConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseIncrementalConstraintsColumnsEqual diff --git a/tests/functional/adapter/test_python_model.py b/tests/functional/adapter/test_python_model.py index cd798d1da..60125be09 100644 --- 
a/tests/functional/adapter/test_python_model.py +++ b/tests/functional/adapter/test_python_model.py @@ -8,12 +8,12 @@ from dbt.tests.adapter.python_model.test_spark import BasePySparkTests -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc") class TestPythonModelSpark(BasePythonModelTests): pass -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc") class TestPySpark(BasePySparkTests): def test_different_dataframes(self, project): """ @@ -33,7 +33,7 @@ def test_different_dataframes(self, project): assert len(results) == 3 -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc") class TestPythonIncrementalModelSpark(BasePythonIncrementalTests): @pytest.fixture(scope="class") def project_config_update(self): @@ -78,7 +78,7 @@ def model(dbt, spark): """ -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc") class TestChangingSchemaSpark: """ Confirm that we can setup a spot instance and parse required packages into the Databricks job. diff --git a/tests/functional/adapter/test_store_test_failures.py b/tests/functional/adapter/test_store_test_failures.py index e78bd4f71..91f52e4b4 100644 --- a/tests/functional/adapter/test_store_test_failures.py +++ b/tests/functional/adapter/test_store_test_failures.py @@ -7,7 +7,7 @@ ) -@pytest.mark.skip_profile("spark_session", "databricks_cluster", "databricks_sql_endpoint") +@pytest.mark.skip_profile("spark_session", "databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc") class TestSparkStoreTestFailures(StoreTestFailuresBase): @pytest.fixture(scope="class") def project_config_update(self): From 0e91dbd46c5c5dcdc6fd76934411ad2fa2132cb2 Mon Sep 17 00:00:00 2001 From: nilan3 Date: Tue, 27 Aug 2024 12:46:37 +0100 Subject: [PATCH 02/10] clean up --- dbt/adapters/spark/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index f5dcbe6a1..7a1639816 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -155,7 +155,7 @@ def __post_init__(self) -> None: f"ImportError({e.msg})" ) from e - if self.method != SparkConnectionMethod.SESSION and self.host is not None: + if self.method != SparkConnectionMethod.SESSION: self.host = self.host.rstrip("/") self.server_side_parameters = { From 8e9718545a3560cbacf5b3cfe3061856dd80d42e Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 29 Aug 2024 16:56:16 -0700 Subject: [PATCH 03/10] fix typo in test_incremental_on_schema_change.py --- .../adapter/incremental/test_incremental_on_schema_change.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py index 7e05290ad..6f881697c 100644 --- a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py +++ b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py @@ -32,7 +32,7 @@ def 
project_config_update(self): } -@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_http_odbc" +@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_http_odbc") class TestInsertOverwriteOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail): @pytest.fixture(scope="class") def project_config_update(self): From a1e68965cb1c196859c5a368cc950532f44297ef Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 29 Aug 2024 17:11:48 -0700 Subject: [PATCH 04/10] fix formatting --- tests/conftest.py | 1 + .../test_incremental_strategies.py | 6 +++++- tests/functional/adapter/test_constraints.py | 18 +++++++++++++++--- tests/functional/adapter/test_python_model.py | 16 ++++++++++++---- .../adapter/test_store_test_failures.py | 4 +++- 5 files changed, 36 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index fe4174f5c..4656963cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -103,6 +103,7 @@ def spark_session_target(): "method": "session", } + def spark_http_odbc_target(): return { "type": "spark", diff --git a/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py b/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py index eb447ee4f..a44a1d23e 100644 --- a/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py +++ b/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py @@ -103,7 +103,11 @@ def run_and_test(self, project): check_relations_equal(project.adapter, ["merge_update_columns", "expected_partial_upsert"]) @pytest.mark.skip_profile( - "apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc" + "apache_spark", + "databricks_http_cluster", + "databricks_sql_endpoint", + "spark_session", + "spark_http_odbc", ) def test_delta_strategies(self, project): self.run_and_test(project) diff --git a/tests/functional/adapter/test_constraints.py b/tests/functional/adapter/test_constraints.py index 0b5b80e6e..f33359262 100644 --- a/tests/functional/adapter/test_constraints.py +++ b/tests/functional/adapter/test_constraints.py @@ -183,7 +183,11 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc" + "spark_session", + "apache_spark", + "databricks_sql_endpoint", + "databricks_cluster", + "spark_http_odbc", ) class TestSparkTableConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseTableConstraintsColumnsEqual @@ -198,7 +202,11 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc" + "spark_session", + "apache_spark", + "databricks_sql_endpoint", + "databricks_cluster", + "spark_http_odbc", ) class TestSparkViewConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseViewConstraintsColumnsEqual @@ -213,7 +221,11 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc" + "spark_session", + "apache_spark", + "databricks_sql_endpoint", + "databricks_cluster", + "spark_http_odbc", ) class TestSparkIncrementalConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseIncrementalConstraintsColumnsEqual diff --git a/tests/functional/adapter/test_python_model.py b/tests/functional/adapter/test_python_model.py index 60125be09..50132b883 100644 --- a/tests/functional/adapter/test_python_model.py 
+++ b/tests/functional/adapter/test_python_model.py @@ -8,12 +8,16 @@ from dbt.tests.adapter.python_model.test_spark import BasePySparkTests -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc" +) class TestPythonModelSpark(BasePythonModelTests): pass -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc" +) class TestPySpark(BasePySparkTests): def test_different_dataframes(self, project): """ @@ -33,7 +37,9 @@ def test_different_dataframes(self, project): assert len(results) == 3 -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc" +) class TestPythonIncrementalModelSpark(BasePythonIncrementalTests): @pytest.fixture(scope="class") def project_config_update(self): @@ -78,7 +84,9 @@ def model(dbt, spark): """ -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc" +) class TestChangingSchemaSpark: """ Confirm that we can setup a spot instance and parse required packages into the Databricks job. diff --git a/tests/functional/adapter/test_store_test_failures.py b/tests/functional/adapter/test_store_test_failures.py index 91f52e4b4..3d8a4c192 100644 --- a/tests/functional/adapter/test_store_test_failures.py +++ b/tests/functional/adapter/test_store_test_failures.py @@ -7,7 +7,9 @@ ) -@pytest.mark.skip_profile("spark_session", "databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc") +@pytest.mark.skip_profile( + "spark_session", "databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc" +) class TestSparkStoreTestFailures(StoreTestFailuresBase): @pytest.fixture(scope="class") def project_config_update(self): From 7438bc656a0d240819ed9d791b59ce2ea013626f Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Tue, 10 Sep 2024 17:58:57 -0400 Subject: [PATCH 05/10] changelog --- .changes/unreleased/Features-20240910-175846.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Features-20240910-175846.yaml diff --git a/.changes/unreleased/Features-20240910-175846.yaml b/.changes/unreleased/Features-20240910-175846.yaml new file mode 100644 index 000000000..7a877567b --- /dev/null +++ b/.changes/unreleased/Features-20240910-175846.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support custom ODBC connection parameters via `connection_str_extra` config +time: 2024-09-10T17:58:46.141332-04:00 +custom: + Author: colin-rogers-dbt jpoley + Issue: "1092" From fc3d0ed053d8ddde9e94e9fadac07673b8813d40 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 11 Sep 2024 10:53:21 -0700 Subject: [PATCH 06/10] Add unit test and refactor unit test fixtures --- dbt/adapters/spark/connections.py | 12 +- tests/conftest.py | 2 +- tests/unit/conftest.py | 1 + tests/unit/fixtures/__init__.py | 0 tests/unit/fixtures/profiles.py | 175 ++++++++++++++++++++++ tests/unit/test_adapter.py | 231 +++++++++--------------------- 6 files changed, 247 insertions(+), 174 deletions(-) create mode 100644 tests/unit/conftest.py create mode 100644 
tests/unit/fixtures/__init__.py create mode 100644 tests/unit/fixtures/profiles.py diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 7a1639816..d9b615ecb 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -78,7 +78,7 @@ class SparkCredentials(Credentials): auth: Optional[str] = None kerberos_service_name: Optional[str] = None organization: str = "0" - connection_str_extra: Optional[str] = None + connection_string_suffix: Optional[str] = None connect_retries: int = 0 connect_timeout: int = 10 use_ssl: bool = False @@ -484,11 +484,11 @@ def open(cls, connection: Connection) -> Connection: http_path = cls.SPARK_SQL_ENDPOINT_HTTP_PATH.format( endpoint=creds.endpoint ) - elif creds.connection_str_extra is not None: - required_fields = ["driver", "host", "port", "connection_str_extra"] + elif creds.connection_string_suffix is not None: + required_fields = ["driver", "host", "port", "connection_string_suffix"] else: raise DbtConfigError( - "Either `cluster`, `endpoint`, `connection_str_extra` must set when" + "Either `cluster`, `endpoint`, `connection_string_suffix` must set when" " using the odbc method to connect to Spark" ) @@ -527,8 +527,8 @@ def open(cls, connection: Connection) -> Connection: LCaseSspKeyName=0 if ssp else 1, **ssp, ) - if creds.connection_str_extra is not None: - connection_str = connection_str + ";" + creds.connection_str_extra + if creds.connection_string_suffix is not None: + connection_str = connection_str + ";" + creds.connection_string_suffix conn = pyodbc.connect(connection_str, autocommit=True) handle = PyodbcConnectionWrapper(conn) diff --git a/tests/conftest.py b/tests/conftest.py index 4656963cc..09b31f406 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -111,7 +111,7 @@ def spark_http_odbc_target(): "host": os.getenv("DBT_DATABRICKS_HOST_NAME"), "port": 443, "driver": os.getenv("ODBC_DRIVER"), - "connection_str_extra": f'UID=token;PWD={os.getenv("DBT_DATABRICKS_TOKEN")};HTTPPath=/sql/1.0/endpoints/{os.getenv("DBT_DATABRICKS_ENDPOINT")};AuthMech=3;SparkServerType=3', + "connection_string_suffix": f'UID=token;PWD={os.getenv("DBT_DATABRICKS_TOKEN")};HTTPPath=/sql/1.0/endpoints/{os.getenv("DBT_DATABRICKS_ENDPOINT")};AuthMech=3;SparkServerType=3', "connect_retries": 3, "connect_timeout": 5, "retry_all": True, diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..c3b000352 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1 @@ +from .fixtures.profiles import * diff --git a/tests/unit/fixtures/__init__.py b/tests/unit/fixtures/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/fixtures/profiles.py b/tests/unit/fixtures/profiles.py new file mode 100644 index 000000000..7c125a00f --- /dev/null +++ b/tests/unit/fixtures/profiles.py @@ -0,0 +1,175 @@ +import pytest +from dbt.config import RuntimeConfig + +from tests.unit.utils import config_from_parts_or_dicts + + +@pytest.fixture(scope="session", autouse=True) +def base_project_cfg(): + return { + "name": "X", + "version": "0.1", + "profile": "test", + "project-root": "/tmp/dbt/does-not-exist", + "quoting": { + "identifier": False, + "schema": False, + }, + "config-version": 2, + } + + +@pytest.fixture(scope="session", autouse=True) +def target_http(base_project_cfg) -> RuntimeConfig: + config = config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "http", + "schema": "analytics", + "host": 
"myorg.sparkhost.com", + "port": 443, + "token": "abc123", + "organization": "0123456789", + "cluster": "01234-23423-coffeetime", + "server_side_parameters": {"spark.driver.memory": "4g"}, + } + }, + "target": "test", + }, + ) + return config + + +@pytest.fixture(scope="session", autouse=True) +def target_thrift(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "thrift", + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 10001, + "user": "dbt", + } + }, + "target": "test", + }, + ) + + +@pytest.fixture(scope="session", autouse=True) +def target_thrift_kerberos(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "thrift", + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 10001, + "user": "dbt", + "auth": "KERBEROS", + "kerberos_service_name": "hive", + } + }, + "target": "test", + }, + ) + + +@pytest.fixture(scope="session", autouse=True) +def target_use_ssl_thrift(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "thrift", + "use_ssl": True, + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 10001, + "user": "dbt", + } + }, + "target": "test", + }, + ) + + +@pytest.fixture(scope="session", autouse=True) +def target_odbc_cluster(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "odbc", + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 443, + "token": "abc123", + "organization": "0123456789", + "cluster": "01234-23423-coffeetime", + "driver": "Simba", + } + }, + "target": "test", + }, + ) + + +@pytest.fixture(scope="session", autouse=True) +def target_odbc_sql_endpoint(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "odbc", + "schema": "analytics", + "host": "myorg.sparkhost.com", + "port": 443, + "token": "abc123", + "endpoint": "012342342393920a", + "driver": "Simba", + } + }, + "target": "test", + }, + ) + + +@pytest.fixture(scope="session", autouse=True) +def target_odbc_with_extra_conn(base_project_cfg): + return config_from_parts_or_dicts( + base_project_cfg, + { + "outputs": { + "test": { + "type": "spark", + "method": "odbc", + "host": "myorg.sparkhost.com", + "schema": "analytics", + "port": 443, + "driver": "Simba", + "connection_string_suffix": "someExtraValues", + "connect_retries": 3, + "connect_timeout": 5, + "retry_all": True, + } + }, + "target": "test", + }, + ) diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 54e9f0158..6a40d9cca 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -1,153 +1,41 @@ import unittest +import pytest from multiprocessing import get_context from unittest import mock -import dbt.flags as flags from dbt.exceptions import DbtRuntimeError from agate import Row from pyhive import hive from dbt.adapters.spark import SparkAdapter, SparkRelation from .utils import config_from_parts_or_dicts +pytest.mark.usefixtures("target_http") -class TestSparkAdapter(unittest.TestCase): - def setUp(self): - flags.STRICT_MODE = False - - self.project_cfg = { - "name": "X", - "version": "0.1", - "profile": "test", - "project-root": "/tmp/dbt/does-not-exist", - "quoting": { - "identifier": False, - "schema": False, - }, - 
"config-version": 2, - } - - def _get_target_http(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "http", - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 443, - "token": "abc123", - "organization": "0123456789", - "cluster": "01234-23423-coffeetime", - "server_side_parameters": {"spark.driver.memory": "4g"}, - } - }, - "target": "test", - }, - ) - def _get_target_thrift(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "thrift", - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 10001, - "user": "dbt", - } - }, - "target": "test", - }, - ) - - def _get_target_thrift_kerberos(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "thrift", - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 10001, - "user": "dbt", - "auth": "KERBEROS", - "kerberos_service_name": "hive", - } - }, - "target": "test", - }, - ) - - def _get_target_use_ssl_thrift(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "thrift", - "use_ssl": True, - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 10001, - "user": "dbt", - } - }, - "target": "test", - }, - ) - - def _get_target_odbc_cluster(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "odbc", - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 443, - "token": "abc123", - "organization": "0123456789", - "cluster": "01234-23423-coffeetime", - "driver": "Simba", - } - }, - "target": "test", - }, - ) - - def _get_target_odbc_sql_endpoint(self, project): - return config_from_parts_or_dicts( - project, - { - "outputs": { - "test": { - "type": "spark", - "method": "odbc", - "schema": "analytics", - "host": "myorg.sparkhost.com", - "port": 443, - "token": "abc123", - "endpoint": "012342342393920a", - "driver": "Simba", - } - }, - "target": "test", - }, - ) +class TestSparkAdapter(unittest.TestCase): + @pytest.fixture(autouse=True) + def set_up_fixtures( + self, + target_http, + target_odbc_with_extra_conn, + target_thrift, + target_thrift_kerberos, + target_odbc_sql_endpoint, + target_odbc_cluster, + target_use_ssl_thrift, + base_project_cfg, + ): + self.base_project_cfg = base_project_cfg + self.target_http = target_http + self.target_odbc_with_extra_conn = target_odbc_with_extra_conn + self.target_odbc_sql_endpoint = target_odbc_sql_endpoint + self.target_odbc_cluster = target_odbc_cluster + self.target_thrift = target_thrift + self.target_thrift_kerberos = target_thrift_kerberos + self.target_use_ssl_thrift = target_use_ssl_thrift def test_http_connection(self): - config = self._get_target_http(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_http, get_context("spawn")) def hive_http_connect(thrift_transport, configuration): self.assertEqual(thrift_transport.scheme, "https") @@ -171,7 +59,7 @@ def hive_http_connect(thrift_transport, configuration): self.assertIsNone(connection.credentials.database) def test_thrift_connection(self): - config = self._get_target_thrift(self.project_cfg) + config = self.target_thrift adapter = SparkAdapter(config, get_context("spawn")) def hive_thrift_connect( @@ -195,8 +83,7 @@ def hive_thrift_connect( 
self.assertIsNone(connection.credentials.database) def test_thrift_ssl_connection(self): - config = self._get_target_use_ssl_thrift(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_use_ssl_thrift, get_context("spawn")) def hive_thrift_connect(thrift_transport, configuration): self.assertIsNotNone(thrift_transport) @@ -215,8 +102,7 @@ def hive_thrift_connect(thrift_transport, configuration): self.assertIsNone(connection.credentials.database) def test_thrift_connection_kerberos(self): - config = self._get_target_thrift_kerberos(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_thrift_kerberos, get_context("spawn")) def hive_thrift_connect( host, port, username, auth, kerberos_service_name, password, configuration @@ -239,8 +125,7 @@ def hive_thrift_connect( self.assertIsNone(connection.credentials.database) def test_odbc_cluster_connection(self): - config = self._get_target_odbc_cluster(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_odbc_cluster, get_context("spawn")) def pyodbc_connect(connection_str, autocommit): self.assertTrue(autocommit) @@ -266,8 +151,7 @@ def pyodbc_connect(connection_str, autocommit): self.assertIsNone(connection.credentials.database) def test_odbc_endpoint_connection(self): - config = self._get_target_odbc_sql_endpoint(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_odbc_sql_endpoint, get_context("spawn")) def pyodbc_connect(connection_str, autocommit): self.assertTrue(autocommit) @@ -291,6 +175,26 @@ def pyodbc_connect(connection_str, autocommit): self.assertEqual(connection.credentials.schema, "analytics") self.assertIsNone(connection.credentials.database) + def test_odbc_with_extra_connection_string(self): + adapter = SparkAdapter(self.target_odbc_with_extra_conn, get_context("spawn")) + + def pyodbc_connect(connection_str, autocommit): + self.assertTrue(autocommit) + self.assertIn("driver=simba;", connection_str.lower()) + self.assertIn("port=443;", connection_str.lower()) + self.assertIn("host=myorg.sparkhost.com;", connection_str.lower()) + self.assertIn("someExtraValues", connection_str) + + with mock.patch( + "dbt.adapters.spark.connections.pyodbc.connect", new=pyodbc_connect + ): # noqa + connection = adapter.acquire_connection("dummy") + connection.handle # trigger lazy-load + + self.assertEqual(connection.state, "open") + self.assertIsNotNone(connection.handle) + self.assertIsNone(connection.credentials.database) + def test_parse_relation(self): self.maxDiff = None rel_type = SparkRelation.get_relation_type.Table @@ -329,8 +233,7 @@ def test_parse_relation(self): input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] - config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended( + rows = SparkAdapter(self.target_http, get_context("spawn")).parse_describe_extended( relation, input_cols ) self.assertEqual(len(rows), 4) @@ -420,8 +323,7 @@ def test_parse_relation_with_integer_owner(self): input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] - config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended( + rows = SparkAdapter(self.target_http, get_context("spawn")).parse_describe_extended( relation, input_cols ) @@ -458,8 +360,7 @@ def 
test_parse_relation_with_statistics(self): input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] - config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended( + rows = SparkAdapter(self.target_http, get_context("spawn")).parse_describe_extended( relation, input_cols ) self.assertEqual(len(rows), 1) @@ -489,8 +390,7 @@ def test_parse_relation_with_statistics(self): ) def test_relation_with_database(self): - config = self._get_target_http(self.project_cfg) - adapter = SparkAdapter(config, get_context("spawn")) + adapter = SparkAdapter(self.target_http, get_context("spawn")) # fine adapter.Relation.create(schema="different", identifier="table") with self.assertRaises(DbtRuntimeError): @@ -516,7 +416,7 @@ def test_profile_with_database(self): "target": "test", } with self.assertRaises(DbtRuntimeError): - config_from_parts_or_dicts(self.project_cfg, profile) + config_from_parts_or_dicts(self.base_project_cfg, profile) def test_profile_with_cluster_and_sql_endpoint(self): profile = { @@ -536,7 +436,7 @@ def test_profile_with_cluster_and_sql_endpoint(self): "target": "test", } with self.assertRaises(DbtRuntimeError): - config_from_parts_or_dicts(self.project_cfg, profile) + config_from_parts_or_dicts(self.base_project_cfg, profile) def test_parse_columns_from_information_with_table_type_and_delta_provider(self): self.maxDiff = None @@ -570,10 +470,9 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self) schema="default_schema", identifier="mytable", type=rel_type, information=information ) - config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information( - relation - ) + columns = SparkAdapter( + self.target_http, get_context("spawn") + ).parse_columns_from_information(relation) self.assertEqual(len(columns), 4) self.assertEqual( columns[0].to_column_dict(omit_none=False), @@ -657,10 +556,9 @@ def test_parse_columns_from_information_with_view_type(self): schema="default_schema", identifier="myview", type=rel_type, information=information ) - config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information( - relation - ) + columns = SparkAdapter( + self.target_http, get_context("spawn") + ).parse_columns_from_information(relation) self.assertEqual(len(columns), 4) self.assertEqual( columns[1].to_column_dict(omit_none=False), @@ -725,10 +623,9 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel schema="default_schema", identifier="mytable", type=rel_type, information=information ) - config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information( - relation - ) + columns = SparkAdapter( + self.target_http, get_context("spawn") + ).parse_columns_from_information(relation) self.assertEqual(len(columns), 4) self.assertEqual( From e8ad85c551387fb951e49f924622645df0add403 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 11 Sep 2024 11:03:00 -0700 Subject: [PATCH 07/10] update changie --- .changes/unreleased/Features-20240910-175846.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.changes/unreleased/Features-20240910-175846.yaml b/.changes/unreleased/Features-20240910-175846.yaml index 7a877567b..29cb58798 100644 --- a/.changes/unreleased/Features-20240910-175846.yaml +++ 
b/.changes/unreleased/Features-20240910-175846.yaml @@ -1,5 +1,5 @@ kind: Features -body: Support custom ODBC connection parameters via `connection_str_extra` config +body: Support custom ODBC connection parameters via `connection_string_suffix` config time: 2024-09-10T17:58:46.141332-04:00 custom: Author: colin-rogers-dbt jpoley From 051fe05c62fc75c0a704c25acd0d2587a80bf029 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 11 Sep 2024 11:03:48 -0700 Subject: [PATCH 08/10] update changie --- .changes/unreleased/Features-20240910-175846.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.changes/unreleased/Features-20240910-175846.yaml b/.changes/unreleased/Features-20240910-175846.yaml index 29cb58798..68ef8551e 100644 --- a/.changes/unreleased/Features-20240910-175846.yaml +++ b/.changes/unreleased/Features-20240910-175846.yaml @@ -2,5 +2,5 @@ kind: Features body: Support custom ODBC connection parameters via `connection_string_suffix` config time: 2024-09-10T17:58:46.141332-04:00 custom: - Author: colin-rogers-dbt jpoley + Author: colin-rogers-dbt jpoley nilan3 Issue: "1092" From c971e5276bfabcc8272f6833cde68db90c8052b1 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 11 Sep 2024 14:37:21 -0700 Subject: [PATCH 09/10] remove holdover code --- tests/unit/test_adapter.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 6a40d9cca..323e82a11 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -9,8 +9,6 @@ from dbt.adapters.spark import SparkAdapter, SparkRelation from .utils import config_from_parts_or_dicts -pytest.mark.usefixtures("target_http") - class TestSparkAdapter(unittest.TestCase): @pytest.fixture(autouse=True) From 3736d8e72fc0d451f04ee953e66471a7332c431b Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 11 Sep 2024 14:38:54 -0700 Subject: [PATCH 10/10] remove dbt-core ref --- tests/unit/fixtures/profiles.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/fixtures/profiles.py b/tests/unit/fixtures/profiles.py index 7c125a00f..c5f24581e 100644 --- a/tests/unit/fixtures/profiles.py +++ b/tests/unit/fixtures/profiles.py @@ -1,5 +1,4 @@ import pytest -from dbt.config import RuntimeConfig from tests.unit.utils import config_from_parts_or_dicts @@ -20,7 +19,7 @@ def base_project_cfg(): @pytest.fixture(scope="session", autouse=True) -def target_http(base_project_cfg) -> RuntimeConfig: +def target_http(base_project_cfg): config = config_from_parts_or_dicts( base_project_cfg, {
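
Reviewer note: the net effect of this series is a new optional `connection_string_suffix` credential for the `odbc` connection method (introduced as `connection_str_extra` in patch 01 and renamed in patch 06). Going by the `spark_http_odbc` target added to tests/conftest.py and the `target_odbc_with_extra_conn` unit-test fixture, a profiles.yml entry using it might look like the sketch below; the profile and target names, host, schema, driver, token, and endpoint values are placeholders, not taken from the patches. The suffix is appended verbatim to the generated ODBC connection string after a `;`, and when no `token` is configured the adapter omits UID/PWD/HTTPPath/AuthMech/SparkServerType so they can be supplied through the suffix instead.

    spark_odbc_profile:                      # placeholder profile name
      target: dev
      outputs:
        dev:
          type: spark
          method: odbc
          schema: analytics                  # placeholder schema
          host: myorg.cloud.databricks.com   # placeholder host
          port: 443
          driver: "Simba Spark ODBC Driver"  # placeholder; name/path of the installed ODBC driver
          # Appended as-is (after a ";") to the ODBC connection string built by the adapter.
          connection_string_suffix: "UID=token;PWD=<access-token>;HTTPPath=/sql/1.0/endpoints/<endpoint-id>;AuthMech=3;SparkServerType=3"
          connect_retries: 3
          connect_timeout: 5
          retry_all: true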