From 4e7f5d53d73c5a93903401c79637f2e2cb13339f Mon Sep 17 00:00:00 2001 From: vakarisbk Date: Fri, 8 Sep 2023 15:15:32 +0300 Subject: [PATCH] Add spark-connect connection method --- .../unreleased/Features-20231004-191452.yaml | 6 ++ dagger/run_dbt_spark_tests.py | 29 ++++++++- dbt/adapters/spark/connections.py | 62 ++++++++++++++++++- dbt/adapters/spark/session.py | 35 +++++++++-- requirements.txt | 3 + setup.py | 12 +++- tests/conftest.py | 12 ++-- .../adapter/dbt_clone/test_dbt_clone.py | 2 +- .../test_incremental_merge_exclude_columns.py | 2 +- .../test_incremental_on_schema_change.py | 4 +- .../test_incremental_predicates.py | 4 +- .../incremental/test_incremental_unique_id.py | 2 +- .../test_incremental_strategies.py | 6 +- .../adapter/persist_docs/test_persist_docs.py | 6 +- tests/functional/adapter/test_basic.py | 4 +- tests/functional/adapter/test_constraints.py | 44 +++++++++---- tests/functional/adapter/test_grants.py | 10 +-- tests/functional/adapter/test_python_model.py | 16 +++-- .../adapter/test_store_test_failures.py | 5 +- tests/functional/adapter/utils/test_utils.py | 9 ++- 20 files changed, 220 insertions(+), 53 deletions(-) create mode 100644 .changes/unreleased/Features-20231004-191452.yaml diff --git a/.changes/unreleased/Features-20231004-191452.yaml b/.changes/unreleased/Features-20231004-191452.yaml new file mode 100644 index 000000000..7a9cfa0bd --- /dev/null +++ b/.changes/unreleased/Features-20231004-191452.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add support for Spark Connect +time: 2023-10-04T19:14:52.858895+03:00 +custom: + Author: vakarisbk + Issue: "899" diff --git a/dagger/run_dbt_spark_tests.py b/dagger/run_dbt_spark_tests.py index 436cb1e92..315336344 100644 --- a/dagger/run_dbt_spark_tests.py +++ b/dagger/run_dbt_spark_tests.py @@ -85,6 +85,29 @@ def get_spark_container(client: dagger.Client) -> (dagger.Service, str): return spark_ctr, "spark_db" +def get_spark_connect_container(client: dagger.Client) -> (dagger.Container, str): + spark_ctr_base = ( + client.container() + .from_("spark:3.5.0-scala2.12-java17-ubuntu") + .with_exec( + [ + "/opt/spark/bin/spark-submit", + "--class", + "org.apache.spark.sql.connect.service.SparkConnectServer", + "--conf", + "spark.sql.catalogImplementation=hive", + "--packages", + "org.apache.spark:spark-connect_2.12:3.5.0", + "--conf", + "spark.jars.ivy=/tmp", + ] + ) + .with_exposed_port(15002) + .as_service() + ) + return spark_ctr_base, "localhost" + + async def test_spark(test_args): async with dagger.Connection(dagger.Config(log_output=sys.stderr)) as client: test_profile = test_args.profile @@ -133,7 +156,11 @@ async def test_spark(test_args): ) elif test_profile == "spark_session": - tst_container = tst_container.with_exec(["pip", "install", "pyspark"]) + tst_container = tst_container.with_exec(["apt-get", "install", "openjdk-17-jre", "-y"]) + + elif test_profile == "spark_connect": + spark_ctr, spark_host = get_spark_connect_container(client) + tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr) tst_container = tst_container.with_exec(["apt-get", "install", "openjdk-17-jre", "-y"]) tst_container = tst_container.with_(env_variables(TESTING_ENV_VARS)) diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 83048f921..8fb9b04ba 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -60,6 +60,7 @@ class SparkConnectionMethod(StrEnum): HTTP = "http" ODBC = "odbc" SESSION = "session" + CONNECT = "connect" @dataclass 
@@ -154,6 +155,21 @@ def __post_init__(self) -> None: f"ImportError({e.msg})" ) from e + if self.method == SparkConnectionMethod.CONNECT: + try: + import pyspark # noqa: F401 F811 + import grpc # noqa: F401 + import pyarrow # noqa: F401 + import pandas # noqa: F401 + except ImportError as e: + raise dbt.exceptions.DbtRuntimeError( + f"{self.method} connection method requires " + "additional dependencies.\n" + "Install the additional required dependencies with " + "`pip install dbt-spark[connect]`\n\n" + f"ImportError({e.msg})" + ) from e + if self.method != SparkConnectionMethod.SESSION: self.host = self.host.rstrip("/") @@ -524,8 +540,52 @@ def open(cls, connection: Connection) -> Connection: SessionConnectionWrapper, ) + # Pass session type (session or connect) into SessionConnectionWrapper + handle = SessionConnectionWrapper( + Connection( + conn_method=creds.method, + conn_url="localhost", + server_side_parameters=creds.server_side_parameters, + ) + ) + elif creds.method == SparkConnectionMethod.CONNECT: + # Build the Spark Connect connection URL + + host = creds.host + port = creds.port + token = creds.token + use_ssl = creds.use_ssl + user = creds.user + + # URL Format: sc://localhost:15002/;user_id=str;token=str;use_ssl=bool + if not host.startswith("sc://"): + host = f"sc://{host}" + base_url = f"{host}:{port}" + + url_extensions = [] + if user: + url_extensions.append(f"user_id={user}") + if use_ssl: + url_extensions.append(f"use_ssl={use_ssl}") + if token: + url_extensions.append(f"token={token}") + + conn_url = f"{base_url}/;{';'.join(url_extensions)}" if url_extensions else base_url + + logger.debug("connection url: {}".format(conn_url)) + + from .session import ( # noqa: F401 + Connection, + SessionConnectionWrapper, + ) + + # Pass session type (session or connect) into SessionConnectionWrapper handle = SessionConnectionWrapper( - Connection(server_side_parameters=creds.server_side_parameters) + Connection( + conn_method=creds.method, + conn_url=conn_url, + server_side_parameters=creds.server_side_parameters, + ) ) ) else: raise DbtConfigError(f"invalid credential method: {creds.method}") diff --git a/dbt/adapters/spark/session.py b/dbt/adapters/spark/session.py index 7a6982e50..0d4f31881 100644 --- a/dbt/adapters/spark/session.py +++ b/dbt/adapters/spark/session.py @@ -6,7 +6,7 @@ from types import TracebackType from typing import Any, Dict, List, Optional, Tuple, Union, Sequence -from dbt.adapters.spark.connections import SparkConnectionWrapper +from dbt.adapters.spark.connections import SparkConnectionMethod, SparkConnectionWrapper from dbt.adapters.events.logging import AdapterLogger from dbt_common.utils.encoding import DECIMALS from dbt_common.exceptions import DbtRuntimeError @@ -27,9 +27,17 @@ class Cursor: https://github.com/mkleehammer/pyodbc/wiki/Cursor """ - def __init__(self, *, server_side_parameters: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, + *, + conn_method: SparkConnectionMethod, + conn_url: str, + server_side_parameters: Optional[Dict[str, Any]] = None, + ) -> None: self._df: Optional[DataFrame] = None self._rows: Optional[List[Row]] = None + self.conn_method: SparkConnectionMethod = conn_method + self.conn_url: str = conn_url self.server_side_parameters = server_side_parameters or {} def __enter__(self) -> Cursor: @@ -113,12 +121,15 @@ def execute(self, sql: str, *parameters: Any) -> None: if len(parameters) > 0: sql = sql % parameters - builder = SparkSession.builder.enableHiveSupport() + builder = SparkSession.builder for parameter, value in self.server_side_parameters.items(): builder =
builder.config(parameter, value) - spark_session = builder.getOrCreate() + if self.conn_method == SparkConnectionMethod.CONNECT: + spark_session = builder.remote(self.conn_url).getOrCreate() + elif self.conn_method == SparkConnectionMethod.SESSION: + spark_session = builder.enableHiveSupport().getOrCreate() try: self._df = spark_session.sql(sql) @@ -175,7 +186,15 @@ class Connection: https://github.com/mkleehammer/pyodbc/wiki/Connection """ - def __init__(self, *, server_side_parameters: Optional[Dict[Any, str]] = None) -> None: + def __init__( + self, + *, + conn_method: SparkConnectionMethod, + conn_url: str, + server_side_parameters: Optional[Dict[Any, str]] = None, + ) -> None: + self.conn_method = conn_method + self.conn_url = conn_url self.server_side_parameters = server_side_parameters or {} def cursor(self) -> Cursor: @@ -187,7 +206,11 @@ def cursor(self) -> Cursor: out : Cursor The cursor. """ - return Cursor(server_side_parameters=self.server_side_parameters) + return Cursor( + conn_method=self.conn_method, + conn_url=self.conn_url, + server_side_parameters=self.server_side_parameters, + ) class SessionConnectionWrapper(SparkConnectionWrapper): diff --git a/requirements.txt b/requirements.txt index 18ccc77fd..a690377cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,8 @@ sqlparams>=3.0.0 thrift>=0.13.0 sqlparse>=0.4.2 # not directly required, pinned by Snyk to avoid a vulnerability +# spark-connect +pyspark[connect]>=3.5.0,<4 + types-PyYAML types-python-dateutil diff --git a/setup.py b/setup.py index 2d6e00e53..469370cc9 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,16 @@ def _get_plugin_version_dict(): "thrift>=0.11.0,<0.17.0", ] session_extras = ["pyspark>=3.0.0,<4.0.0"] -all_extras = odbc_extras + pyhive_extras + session_extras +connect_extras = [ + "pyspark==3.5.0", + "pandas>=1.0.5", + "pyarrow>=4.0.0", + "numpy>=1.15", + "grpcio>=1.46,<1.57", + "grpcio-status>=1.46,<1.57", + "googleapis-common-protos==1.56.4", +] +all_extras = odbc_extras + pyhive_extras + session_extras + connect_extras setup( name=package_name, @@ -71,6 +80,7 @@ def _get_plugin_version_dict(): "ODBC": odbc_extras, "PyHive": pyhive_extras, "session": session_extras, + "connect": connect_extras, "all": all_extras, }, zip_safe=False, diff --git a/tests/conftest.py b/tests/conftest.py index efba41a5f..eb590f251 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,8 @@ def dbt_profile_target(request): target = databricks_http_cluster_target() elif profile_type == "spark_session": target = spark_session_target() + elif profile_type == "spark_connect": + target = spark_connect_target() else: raise ValueError(f"Invalid profile type '{profile_type}'") return target @@ -95,11 +97,11 @@ def databricks_http_cluster_target(): def spark_session_target(): - return { - "type": "spark", - "host": "localhost", - "method": "session", - } + return {"type": "spark", "host": "localhost", "method": "session"} + + +def spark_connect_target(): + return {"type": "spark", "host": "localhost", "port": 15002, "method": "connect"} @pytest.fixture(autouse=True) diff --git a/tests/functional/adapter/dbt_clone/test_dbt_clone.py b/tests/functional/adapter/dbt_clone/test_dbt_clone.py index a5e8d70e0..412327d8b 100644 --- a/tests/functional/adapter/dbt_clone/test_dbt_clone.py +++ b/tests/functional/adapter/dbt_clone/test_dbt_clone.py @@ -14,7 +14,7 @@ ) -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class
TestSparkBigqueryClonePossible(BaseClonePossible): @pytest.fixture(scope="class") def models(self): diff --git a/tests/functional/adapter/incremental/test_incremental_merge_exclude_columns.py b/tests/functional/adapter/incremental/test_incremental_merge_exclude_columns.py index 7560b25ce..168e56f02 100644 --- a/tests/functional/adapter/incremental/test_incremental_merge_exclude_columns.py +++ b/tests/functional/adapter/incremental/test_incremental_merge_exclude_columns.py @@ -5,7 +5,7 @@ ) -@pytest.mark.skip_profile("spark_session", "apache_spark") +@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect") class TestMergeExcludeColumns(BaseMergeExcludeColumns): @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py index 478329668..b3ae9d145 100644 --- a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py +++ b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py @@ -32,7 +32,7 @@ def project_config_update(self): } -@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session") +@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_connect") class TestInsertOverwriteOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail): @pytest.fixture(scope="class") def project_config_update(self): @@ -45,7 +45,7 @@ def project_config_update(self): } -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestDeltaOnSchemaChange(BaseIncrementalOnSchemaChangeSetup): @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/functional/adapter/incremental/test_incremental_predicates.py b/tests/functional/adapter/incremental/test_incremental_predicates.py index 52c01a747..fe631d9d5 100644 --- a/tests/functional/adapter/incremental/test_incremental_predicates.py +++ b/tests/functional/adapter/incremental/test_incremental_predicates.py @@ -27,7 +27,7 @@ """ -@pytest.mark.skip_profile("spark_session", "apache_spark") +@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect") class TestIncrementalPredicatesMergeSpark(BaseIncrementalPredicates): @pytest.fixture(scope="class") def project_config_update(self): @@ -46,7 +46,7 @@ def models(self): } -@pytest.mark.skip_profile("spark_session", "apache_spark") +@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect") class TestPredicatesMergeSpark(BaseIncrementalPredicates): @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/functional/adapter/incremental/test_incremental_unique_id.py b/tests/functional/adapter/incremental/test_incremental_unique_id.py index de8cb652c..f431691b2 100644 --- a/tests/functional/adapter/incremental/test_incremental_unique_id.py +++ b/tests/functional/adapter/incremental/test_incremental_unique_id.py @@ -2,7 +2,7 @@ from dbt.tests.adapter.incremental.test_incremental_unique_id import BaseIncrementalUniqueKey -@pytest.mark.skip_profile("spark_session", "apache_spark") +@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect") class TestUniqueKeySpark(BaseIncrementalUniqueKey): @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py 
b/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py index b05fcb279..3273305e8 100644 --- a/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py +++ b/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py @@ -103,7 +103,11 @@ def run_and_test(self, project): check_relations_equal(project.adapter, ["merge_update_columns", "expected_partial_upsert"]) @pytest.mark.skip_profile( - "apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session" + "apache_spark", + "databricks_http_cluster", + "databricks_sql_endpoint", + "spark_session", + "spark_connect", ) def test_delta_strategies(self, project): self.run_and_test(project) diff --git a/tests/functional/adapter/persist_docs/test_persist_docs.py b/tests/functional/adapter/persist_docs/test_persist_docs.py index ee02e5ef8..5cf89c48b 100644 --- a/tests/functional/adapter/persist_docs/test_persist_docs.py +++ b/tests/functional/adapter/persist_docs/test_persist_docs.py @@ -15,7 +15,7 @@ ) -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestPersistDocsDeltaTable: @pytest.fixture(scope="class") def models(self): @@ -78,7 +78,7 @@ def test_delta_comments(self, project): assert result[2].startswith("Some stuff here and then a call to") -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestPersistDocsDeltaView: @pytest.fixture(scope="class") def models(self): @@ -120,7 +120,7 @@ def test_delta_comments(self, project): assert result[2] is None -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestPersistDocsMissingColumn: @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py index 072d211d6..6a7f84565 100644 --- a/tests/functional/adapter/test_basic.py +++ b/tests/functional/adapter/test_basic.py @@ -50,7 +50,7 @@ class TestGenericTestsSpark(BaseGenericTests): # These tests were not enabled in the dbtspec files, so skipping here. # Error encountered was: Error running query: java.lang.ClassNotFoundException: delta.DefaultSource -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestSnapshotCheckColsSpark(BaseSnapshotCheckCols): @pytest.fixture(scope="class") def project_config_update(self): @@ -66,7 +66,7 @@ def project_config_update(self): # These tests were not enabled in the dbtspec files, so skipping here. 
# Error encountered was: Error running query: java.lang.ClassNotFoundException: delta.DefaultSource -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestSnapshotTimestampSpark(BaseSnapshotTimestamp): @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/functional/adapter/test_constraints.py b/tests/functional/adapter/test_constraints.py index 41b50ef81..50c282487 100644 --- a/tests/functional/adapter/test_constraints.py +++ b/tests/functional/adapter/test_constraints.py @@ -147,7 +147,9 @@ def data_types(self, int_type, schema_int_type, string_type, schema_string_type) ] -@pytest.mark.skip_profile("spark_session", "apache_spark", "databricks_http_cluster") +@pytest.mark.skip_profile( + "spark_session", "apache_spark", "databricks_http_cluster", "spark_connect" +) class TestSparkTableConstraintsColumnsEqualPyodbc(PyodbcSetup, BaseTableConstraintsColumnsEqual): @pytest.fixture(scope="class") def models(self): @@ -158,7 +160,9 @@ def models(self): } -@pytest.mark.skip_profile("spark_session", "apache_spark", "databricks_http_cluster") +@pytest.mark.skip_profile( + "spark_session", "apache_spark", "databricks_http_cluster", "spark_connect" +) class TestSparkViewConstraintsColumnsEqualPyodbc(PyodbcSetup, BaseViewConstraintsColumnsEqual): @pytest.fixture(scope="class") def models(self): @@ -169,7 +173,9 @@ def models(self): } -@pytest.mark.skip_profile("spark_session", "apache_spark", "databricks_http_cluster") +@pytest.mark.skip_profile( + "spark_session", "apache_spark", "databricks_http_cluster", "spark_connect" +) class TestSparkIncrementalConstraintsColumnsEqualPyodbc( PyodbcSetup, BaseIncrementalConstraintsColumnsEqual ): @@ -183,7 +189,11 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster" + "spark_session", + "apache_spark", + "databricks_sql_endpoint", + "databricks_cluster", + "spark_connect", ) class TestSparkTableConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseTableConstraintsColumnsEqual @@ -198,7 +208,11 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster" + "spark_session", + "apache_spark", + "databricks_sql_endpoint", + "databricks_cluster", + "spark_connect", ) class TestSparkViewConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseViewConstraintsColumnsEqual @@ -213,7 +227,11 @@ def models(self): @pytest.mark.skip_profile( - "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster" + "spark_session", + "apache_spark", + "databricks_sql_endpoint", + "databricks_cluster", + "spark_connect", ) class TestSparkIncrementalConstraintsColumnsEqualDatabricksHTTP( DatabricksHTTPSetup, BaseIncrementalConstraintsColumnsEqual @@ -241,7 +259,7 @@ def expected_sql(self): return _expected_sql_spark -@pytest.mark.skip_profile("spark_session", "apache_spark") +@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect") class TestSparkTableConstraintsDdlEnforcement( BaseSparkConstraintsDdlEnforcementSetup, BaseConstraintsRuntimeDdlEnforcement ): @@ -254,7 +272,7 @@ def models(self): } -@pytest.mark.skip_profile("spark_session", "apache_spark") +@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect") class TestSparkIncrementalConstraintsDdlEnforcement( BaseSparkConstraintsDdlEnforcementSetup, 
BaseIncrementalConstraintsRuntimeDdlEnforcement ): @@ -267,7 +285,9 @@ def models(self): } -@pytest.mark.skip_profile("spark_session", "apache_spark", "databricks_http_cluster") +@pytest.mark.skip_profile( + "spark_session", "apache_spark", "databricks_http_cluster", "spark_connect" +) class TestSparkConstraintQuotedColumn(PyodbcSetup, BaseConstraintQuotedColumn): @pytest.fixture(scope="class") def models(self): @@ -326,7 +346,7 @@ def assert_expected_error_messages(self, error_message, expected_error_messages) assert any(msg in error_message for msg in expected_error_messages) -@pytest.mark.skip_profile("spark_session", "apache_spark") +@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect") class TestSparkTableConstraintsRollback( BaseSparkConstraintsRollbackSetup, BaseConstraintsRollback ): @@ -345,7 +365,7 @@ def expected_color(self): return "red" -@pytest.mark.skip_profile("spark_session", "apache_spark") +@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect") class TestSparkIncrementalConstraintsRollback( BaseSparkConstraintsRollbackSetup, BaseIncrementalConstraintsRollback ): @@ -362,7 +382,7 @@ def models(self): # TODO: Like the tests above, this does test that model-level constraints don't # result in errors, but it does not verify that they are actually present in # Spark and that the ALTER TABLE statement actually ran. -@pytest.mark.skip_profile("spark_session", "apache_spark") +@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect") class TestSparkModelConstraintsRuntimeEnforcement(BaseModelConstraintsRuntimeEnforcement): @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/functional/adapter/test_grants.py b/tests/functional/adapter/test_grants.py index 1b1a005ad..2d2921a41 100644 --- a/tests/functional/adapter/test_grants.py +++ b/tests/functional/adapter/test_grants.py @@ -6,7 +6,7 @@ from dbt.tests.adapter.grants.test_snapshot_grants import BaseSnapshotGrants -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestModelGrantsSpark(BaseModelGrants): def privilege_grantee_name_overrides(self): # insert --> modify @@ -18,7 +18,7 @@ def privilege_grantee_name_overrides(self): } -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestIncrementalGrantsSpark(BaseIncrementalGrants): @pytest.fixture(scope="class") def project_config_update(self): @@ -30,7 +30,7 @@ def project_config_update(self): } -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestSeedGrantsSpark(BaseSeedGrants): # seeds in dbt-spark are currently "full refreshed," in such a way that # the grants are not carried over @@ -39,7 +39,7 @@ def seeds_support_partial_refresh(self): return False -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestSnapshotGrantsSpark(BaseSnapshotGrants): @pytest.fixture(scope="class") def project_config_update(self): @@ -51,7 +51,7 @@ def project_config_update(self): } -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestInvalidGrantsSpark(BaseInvalidGrants): def grantee_does_not_exist_error(self): return "RESOURCE_DOES_NOT_EXIST" diff --git 
a/tests/functional/adapter/test_python_model.py b/tests/functional/adapter/test_python_model.py index 1195cbd3e..fdf6ff1bb 100644 --- a/tests/functional/adapter/test_python_model.py +++ b/tests/functional/adapter/test_python_model.py @@ -8,17 +8,23 @@ from dbt.tests.adapter.python_model.test_spark import BasePySparkTests -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_connect" +) class TestPythonModelSpark(BasePythonModelTests): pass -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_connect" +) class TestPySpark(BasePySparkTests): pass -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_connect" +) class TestPythonIncrementalModelSpark(BasePythonIncrementalTests): @pytest.fixture(scope="class") def project_config_update(self): @@ -63,7 +69,9 @@ def model(dbt, spark): """ -@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint") +@pytest.mark.skip_profile( + "apache_spark", "spark_session", "databricks_sql_endpoint", "spark_connect" +) class TestChangingSchemaSpark: @pytest.fixture(scope="class") def models(self): diff --git a/tests/functional/adapter/test_store_test_failures.py b/tests/functional/adapter/test_store_test_failures.py index e78bd4f71..c1ad87106 100644 --- a/tests/functional/adapter/test_store_test_failures.py +++ b/tests/functional/adapter/test_store_test_failures.py @@ -34,7 +34,7 @@ def test_store_and_assert(self, project): self.run_tests_store_failures_and_assert(project) -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestSparkStoreTestFailuresWithDelta(StoreTestFailuresBase): @pytest.fixture(scope="class") def project_config_update(self): @@ -82,7 +82,8 @@ class TestStoreTestFailuresAsProjectLevelView(basic.StoreTestFailuresAsProjectLe pass -@pytest.mark.skip_profile("spark_session") +# spark connect fails because of issue [ADAP-931] +@pytest.mark.skip_profile("spark_session", "spark_connect") class TestStoreTestFailuresAsGeneric(basic.StoreTestFailuresAsGeneric): pass diff --git a/tests/functional/adapter/utils/test_utils.py b/tests/functional/adapter/utils/test_utils.py index 0dc526564..bac35507d 100644 --- a/tests/functional/adapter/utils/test_utils.py +++ b/tests/functional/adapter/utils/test_utils.py @@ -62,6 +62,7 @@ class TestArrayConstruct(BaseArrayConstruct): pass +@pytest.mark.skip_profile("spark_connect") class TestBoolOr(BaseBoolOr): pass @@ -80,16 +81,18 @@ class TestCurrentTimestamp(BaseCurrentTimestampNaive): pass +@pytest.mark.skip_profile("spark_connect") class TestDateAdd(BaseDateAdd): pass # this generates too much SQL to run successfully in our testing environments :( -@pytest.mark.skip_profile("apache_spark", "spark_session") +@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect") class TestDateDiff(BaseDateDiff): pass +@pytest.mark.skip_profile("spark_connect") class TestDateTrunc(BaseDateTrunc): pass @@ -139,12 +142,12 @@ class TestPosition(BasePosition): pass -@pytest.mark.skip_profile("spark_session") +@pytest.mark.skip_profile("spark_session", "spark_connect") class TestReplace(BaseReplace): pass 
-@pytest.mark.skip_profile("spark_session") +@pytest.mark.skip_profile("spark_session", "spark_connect") class TestRight(BaseRight): pass
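
Not part of the patch, for reviewers: a minimal sketch of how the new connect method is exercised end to end. It assumes dbt-spark is installed with the connect extras (pip install dbt-spark[connect]) and that a Spark Connect server is listening on localhost:15002, as provided by get_spark_connect_container() in dagger/run_dbt_spark_tests.py. The URL shape and the builder.remote() call mirror the connect branch of open() in connections.py and the new Cursor.execute(); the host and port match spark_connect_target() in tests/conftest.py. Variable names and the sample query below are illustrative only.

    # Sketch only: connect to a running Spark Connect server the way the new Cursor does.
    from pyspark.sql import SparkSession

    # URL format documented in connections.py:
    #   sc://<host>:<port>/;user_id=str;token=str;use_ssl=bool
    conn_url = "sc://localhost:15002"

    # Cursor.execute() calls builder.remote(conn_url) for method "connect",
    # instead of builder.enableHiveSupport() used for method "session".
    spark = SparkSession.builder.remote(conn_url).getOrCreate()
    spark.sql("SELECT 1 AS ok").show()

The matching profiles.yml target carries the same fields as the test fixture: type: spark, method: connect, host: localhost, port: 15002.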