diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 3b3101552..486768676 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 1.7.0a1 +current_version = 1.8.0a1 parse = (?P<major>[\d]+) # major version number \.(?P<minor>[\d]+) # minor version number \.(?P<patch>[\d]+) # patch version number diff --git a/.changes/unreleased/Dependencies-20230424-230630.yaml b/.changes/unreleased/Dependencies-20230424-230630.yaml deleted file mode 100644 index 1f96daad1..000000000 --- a/.changes/unreleased/Dependencies-20230424-230630.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: "Dependencies" -body: "Update tox requirement from ~=3.0 to ~=4.5" -time: 2023-04-24T23:06:30.00000Z -custom: - Author: dependabot[bot] - PR: 741 diff --git a/.changes/unreleased/Dependencies-20230424-230645.yaml b/.changes/unreleased/Dependencies-20230424-230645.yaml deleted file mode 100644 index 83e1bb44b..000000000 --- a/.changes/unreleased/Dependencies-20230424-230645.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: "Dependencies" -body: "Update pyodbc requirement from ~=4.0.30 to ~=4.0.39" -time: 2023-04-24T23:06:45.00000Z -custom: - Author: dependabot[bot] - PR: 742 diff --git a/.changes/unreleased/Dependencies-20230501-231003.yaml b/.changes/unreleased/Dependencies-20230501-231003.yaml deleted file mode 100644 index b3e3a750e..000000000 --- a/.changes/unreleased/Dependencies-20230501-231003.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: "Dependencies" -body: "Update pre-commit requirement from ~=2.21 to ~=3.3" -time: 2023-05-01T23:10:03.00000Z -custom: - Author: dependabot[bot] - PR: 748 diff --git a/.changes/unreleased/Dependencies-20230501-231035.yaml b/.changes/unreleased/Dependencies-20230501-231035.yaml deleted file mode 100644 index 7bbf98202..000000000 --- a/.changes/unreleased/Dependencies-20230501-231035.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: "Dependencies" -body: "Update types-requests requirement from ~=2.28 to ~=2.29" -time: 2023-05-01T23:10:35.00000Z -custom: - Author: dependabot[bot] - PR: 749 diff --git a/.changes/unreleased/Dependencies-20230510-230725.yaml b/.changes/unreleased/Dependencies-20230510-230725.yaml deleted file mode 100644 index dfd04ad3b..000000000 --- a/.changes/unreleased/Dependencies-20230510-230725.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: "Dependencies" -body: "Bump mypy from 1.2.0 to 1.3.0" -time: 2023-05-10T23:07:25.00000Z -custom: - Author: dependabot[bot] - PR: 768 diff --git a/.changes/unreleased/Dependencies-20230803-224622.yaml b/.changes/unreleased/Dependencies-20230803-224622.yaml deleted file mode 100644 index 119a08e51..000000000 --- a/.changes/unreleased/Dependencies-20230803-224622.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: "Dependencies" -body: "Update flake8 requirement from ~=6.0 to ~=6.1" -time: 2023-08-03T22:46:22.00000Z -custom: - Author: dependabot[bot] - PR: 849 diff --git a/.changes/unreleased/Dependencies-20230803-224626.yaml b/.changes/unreleased/Dependencies-20230803-224626.yaml deleted file mode 100644 index c8b9ef04a..000000000 --- a/.changes/unreleased/Dependencies-20230803-224626.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: "Dependencies" -body: "Update pytest-xdist requirement from ~=3.2 to ~=3.3" -time: 2023-08-03T22:46:26.00000Z -custom: - Author: dependabot[bot] - PR: 851 diff --git a/.changes/unreleased/Dependencies-20230803-224629.yaml b/.changes/unreleased/Dependencies-20230803-224629.yaml deleted file mode 100644 index 6865c7c74..000000000 --- a/.changes/unreleased/Dependencies-20230803-224629.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: 
"Dependencies" -body: "Update pytest requirement from ~=7.3 to ~=7.4" -time: 2023-08-03T22:46:29.00000Z -custom: - Author: dependabot[bot] - PR: 852 diff --git a/.changes/unreleased/Dependencies-20230804-225232.yaml b/.changes/unreleased/Dependencies-20230804-225232.yaml deleted file mode 100644 index f4a09b6b0..000000000 --- a/.changes/unreleased/Dependencies-20230804-225232.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: "Dependencies" -body: "Update pip-tools requirement from ~=6.13 to ~=7.2" -time: 2023-08-04T22:52:32.00000Z -custom: - Author: dependabot[bot] - PR: 856 diff --git a/.changes/unreleased/Features-20230707-104150.yaml b/.changes/unreleased/Features-20230707-104150.yaml deleted file mode 100644 index 183a37b45..000000000 --- a/.changes/unreleased/Features-20230707-104150.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: Features -body: Support server_side_parameters for Spark session connection method -time: 2023-07-07T10:41:50.01541+02:00 -custom: - Author: alarocca-apixio - Issue: "690" diff --git a/.changes/unreleased/Features-20230707-113337.yaml b/.changes/unreleased/Features-20230707-113337.yaml deleted file mode 100644 index de0a50fe8..000000000 --- a/.changes/unreleased/Features-20230707-113337.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: Features -body: Add server_side_parameters to HTTP connection method -time: 2023-07-07T11:33:37.794112+02:00 -custom: - Author: Fokko,JCZuurmond - Issue: "824" diff --git a/.changes/unreleased/Features-20230707-114650.yaml b/.changes/unreleased/Features-20230707-114650.yaml deleted file mode 100644 index 6f1b3d38a..000000000 --- a/.changes/unreleased/Features-20230707-114650.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: Features -body: Enforce server side parameters keys and values to be strings -time: 2023-07-07T11:46:50.390918+02:00 -custom: - Author: Fokko,JCZuurmond - Issue: "826" diff --git a/.changes/unreleased/Under the Hood-20230724-165508.yaml b/.changes/unreleased/Under the Hood-20230724-165508.yaml deleted file mode 100644 index 889484644..000000000 --- a/.changes/unreleased/Under the Hood-20230724-165508.yaml +++ /dev/null @@ -1,6 +0,0 @@ -kind: Under the Hood -body: Update stale workflow to use centralized version -time: 2023-07-24T16:55:08.096947-04:00 -custom: - Author: mikealfare - Issue: "842" diff --git a/.circleci/config.yml b/.circleci/config.yml index 71ca356cf..f2a3b6357 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -11,23 +11,24 @@ jobs: - run: tox -e flake8,unit # Turning off for now due to flaky runs of tests will turn back on at later date. 
- # integration-spark-session: - # environment: - # DBT_INVOCATION_ENV: circle - # docker: - # - image: godatadriven/pyspark:3.1 - # steps: - # - checkout - # - run: apt-get update - # - run: python3 -m pip install --upgrade pip - # - run: apt-get install -y git gcc g++ unixodbc-dev libsasl2-dev - # - run: python3 -m pip install tox - # - run: - # name: Run integration tests - # command: tox -e integration-spark-session - # no_output_timeout: 1h - # - store_artifacts: - # path: ./logs + integration-spark-session: + environment: + DBT_INVOCATION_ENV: circle + docker: + - image: godatadriven/pyspark:3.1 + steps: + - checkout + - run: apt-get update + - run: conda install python=3.10 + - run: python3 -m pip install --upgrade pip + - run: apt-get install -y git gcc g++ unixodbc-dev libsasl2-dev libxml2-dev libxslt-dev + - run: python3 -m pip install tox + - run: + name: Run integration tests + command: tox -e integration-spark-session + no_output_timeout: 1h + - store_artifacts: + path: ./logs integration-spark-thrift: environment: @@ -116,9 +117,9 @@ workflows: test-everything: jobs: - unit - # - integration-spark-session: - # requires: - # - unit + - integration-spark-session: + requires: + - unit - integration-spark-thrift: requires: - unit diff --git a/.github/workflows/docs-issues.yml b/.github/workflows/docs-issues.yml new file mode 100644 index 000000000..00a098df8 --- /dev/null +++ b/.github/workflows/docs-issues.yml @@ -0,0 +1,43 @@ +# **what?** +# Open an issue in docs.getdbt.com when a PR is labeled `user docs` + +# **why?** +# To reduce barriers for keeping docs up to date + +# **when?** +# When a PR is labeled `user docs` and is merged. Runs on pull_request_target to run off the workflow already merged, +# not the workflow that existed on the PR branch. This allows old PRs to get comments. + + +name: Open issues in docs.getdbt.com repo when a PR is labeled +run-name: "Open an issue in docs.getdbt.com for PR #${{ github.event.pull_request.number }}" + +on: + pull_request_target: + types: [labeled, closed] + +defaults: + run: + shell: bash + +permissions: + issues: write # opens new issues + pull-requests: write # comments on PRs + + +jobs: + open_issues: + # we only want to run this when the PR has been merged or the label in the labeled event is `user docs`. Otherwise it runs the + # risk of duplication of issues being created due to merge and label both triggering this workflow to run and neither having + # generated the comment before the other runs. This lives here instead of the shared workflow because this is where we + # decide if it should run or not. + if: | + (github.event.pull_request.merged == true) && + ((github.event.action == 'closed' && contains( github.event.pull_request.labels.*.name, 'user docs')) || + (github.event.action == 'labeled' && github.event.label.name == 'user docs')) + uses: dbt-labs/actions/.github/workflows/open-issue-in-repo.yml@main + with: + issue_repository: "dbt-labs/docs.getdbt.com" + issue_title: "Docs Changes Needed from ${{ github.event.repository.name }} PR #${{ github.event.pull_request.number }}" + issue_body: "At a minimum, update body to include a link to the page on docs.getdbt.com requiring updates and what part(s) of the page you would like to see updated." 
+ secrets: inherit diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6b3d93b6e..30126325e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -79,7 +79,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] env: TOXENV: "unit" @@ -177,7 +177,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - name: Set up Python ${{ matrix.python-version }} diff --git a/Makefile b/Makefile index 876440a01..cc1d9f75d 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,7 @@ dev: ## Installs adapter in develop mode along with development dependencies dev-uninstall: ## Uninstalls all packages while maintaining the virtual environment ## Useful when updating versions, or if you accidentally installed into the system interpreter pip freeze | grep -v "^-e" | cut -d "@" -f1 | xargs pip uninstall -y + pip uninstall -y dbt-spark .PHONY: mypy mypy: ## Runs mypy against staged changes for static type checking. diff --git a/README.md b/README.md index fa286b1f7..2d2586795 100644 --- a/README.md +++ b/README.md @@ -26,18 +26,20 @@ more information, consult [the docs](https://docs.getdbt.com/docs/profile-spark) ## Running locally A `docker-compose` environment starts a Spark Thrift server and a Postgres database as a Hive Metastore backend. -Note: dbt-spark now supports Spark 3.1.1 (formerly on Spark 2.x). +Note: dbt-spark now supports Spark 3.3.2. -The following command would start two docker containers -``` +The following command starts two docker containers: + +```sh docker-compose up -d ``` + It will take a bit of time for the instance to start, you can check the logs of the two containers. If the instance doesn't start correctly, try the complete reset command listed below and then try start again. Create a profile like this one: -``` +```yaml spark_testing: target: local outputs: @@ -60,7 +62,7 @@ Connecting to the local spark instance: Note that the Hive metastore data is persisted under `./.hive-metastore/`, and the Spark-produced data under `./.spark-warehouse/`. 
To completely reset you environment run the following: -``` +```sh docker-compose down rm -rf ./.hive-metastore/ rm -rf ./.spark-warehouse/ diff --git a/dbt/adapters/spark/__version__.py b/dbt/adapters/spark/__version__.py index 874bd74c8..f15b401d1 100644 --- a/dbt/adapters/spark/__version__.py +++ b/dbt/adapters/spark/__version__.py @@ -1 +1 @@ -version = "1.7.0a1" +version = "1.8.0a1" diff --git a/dbt/adapters/spark/column.py b/dbt/adapters/spark/column.py index bde49a492..a57fa0565 100644 --- a/dbt/adapters/spark/column.py +++ b/dbt/adapters/spark/column.py @@ -3,13 +3,12 @@ from dbt.adapters.base.column import Column from dbt.dataclass_schema import dbtClassMixin -from hologram import JsonDict Self = TypeVar("Self", bound="SparkColumn") @dataclass -class SparkColumn(dbtClassMixin, Column): # type: ignore +class SparkColumn(dbtClassMixin, Column): table_database: Optional[str] = None table_schema: Optional[str] = None table_name: Optional[str] = None @@ -63,7 +62,7 @@ def convert_table_stats(raw_stats: Optional[str]) -> Dict[str, Any]: table_stats[f"stats:{key}:include"] = True return table_stats - def to_column_dict(self, omit_none: bool = True, validate: bool = False) -> JsonDict: + def to_column_dict(self, omit_none: bool = True, validate: bool = False) -> Dict[str, Any]: original_dict = self.to_dict(omit_none=omit_none) # If there are stats, merge them into the root of the dict original_stats = original_dict.pop("table_stats", None) diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 6c7899ad9..966f5584e 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -23,16 +23,18 @@ from datetime import datetime import sqlparams from dbt.contracts.connection import Connection -from hologram.helpers import StrEnum +from dbt.dataclass_schema import StrEnum from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Union, Tuple, List, Generator, Iterable +from typing import Any, Dict, Optional, Union, Tuple, List, Generator, Iterable, Sequence + +from abc import ABC, abstractmethod try: from thrift.transport.TSSLSocket import TSSLSocket import thrift import ssl - import sasl import thrift_sasl + from puresasl.client import SASLClient except ImportError: pass # done deliberately: setting modules to None explicitly violates MyPy contracts by degrading type semantics @@ -57,9 +59,10 @@ class SparkConnectionMethod(StrEnum): @dataclass class SparkCredentials(Credentials): - host: str - method: SparkConnectionMethod - database: Optional[str] # type: ignore + host: Optional[str] = None + schema: Optional[str] = None # type: ignore + method: SparkConnectionMethod = None # type: ignore + database: Optional[str] = None # type: ignore driver: Optional[str] = None cluster: Optional[str] = None endpoint: Optional[str] = None @@ -88,6 +91,13 @@ def cluster_id(self) -> Optional[str]: return self.cluster def __post_init__(self) -> None: + if self.method is None: + raise dbt.exceptions.DbtRuntimeError("Must specify `method` in profile") + if self.host is None: + raise dbt.exceptions.DbtRuntimeError("Must specify `host` in profile") + if self.schema is None: + raise dbt.exceptions.DbtRuntimeError("Must specify `schema` in profile") + # spark classifies database and schema as the same thing if self.database is not None and self.database != self.schema: raise dbt.exceptions.DbtRuntimeError( @@ -152,13 +162,48 @@ def type(self) -> str: @property def unique_field(self) -> str: - return self.host + return self.host # 
type: ignore def _connection_keys(self) -> Tuple[str, ...]: return "host", "port", "cluster", "endpoint", "schema", "organization" -class PyhiveConnectionWrapper(object): +class SparkConnectionWrapper(ABC): + @abstractmethod + def cursor(self) -> "SparkConnectionWrapper": + pass + + @abstractmethod + def cancel(self) -> None: + pass + + @abstractmethod + def close(self) -> None: + pass + + @abstractmethod + def rollback(self) -> None: + pass + + @abstractmethod + def fetchall(self) -> Optional[List]: + pass + + @abstractmethod + def execute(self, sql: str, bindings: Optional[List[Any]] = None) -> None: + pass + + @property + @abstractmethod + def description( + self, + ) -> Sequence[ + Tuple[str, Any, Optional[int], Optional[int], Optional[int], Optional[int], bool] + ]: + pass + + +class PyhiveConnectionWrapper(SparkConnectionWrapper): """Wrap a Spark connection in a way that no-ops transactions""" # https://forums.databricks.com/questions/2157/in-apache-spark-sql-can-we-roll-back-the-transacti.html # noqa @@ -268,7 +313,11 @@ def _fix_binding(cls, value: Any) -> Union[float, str]: return value @property - def description(self) -> Tuple[Tuple[str, Any, int, int, int, int, bool]]: + def description( + self, + ) -> Sequence[ + Tuple[str, Any, Optional[int], Optional[int], Optional[int], Optional[int], bool] + ]: assert self._cursor, "Cursor not available" return self._cursor.description @@ -354,7 +403,7 @@ def open(cls, connection: Connection) -> Connection: creds = connection.credentials exc = None - handle: Any + handle: SparkConnectionWrapper for i in range(1 + creds.connect_retries): try: @@ -398,7 +447,10 @@ def open(cls, connection: Connection) -> Connection: kerberos_service_name=creds.kerberos_service_name, password=creds.password, ) - conn = hive.connect(thrift_transport=transport) + conn = hive.connect( + thrift_transport=transport, + configuration=creds.server_side_parameters, + ) else: conn = hive.connect( host=creds.host, @@ -407,6 +459,7 @@ def open(cls, connection: Connection) -> Connection: auth=creds.auth, kerberos_service_name=creds.kerberos_service_name, password=creds.password, + configuration=creds.server_side_parameters, ) # noqa handle = PyhiveConnectionWrapper(conn) elif creds.method == SparkConnectionMethod.ODBC: @@ -558,17 +611,15 @@ def build_ssl_transport( # to be nonempty. 
password = "x" - def sasl_factory() -> sasl.Client: - sasl_client = sasl.Client() - sasl_client.setAttr("host", host) + def sasl_factory() -> SASLClient: if sasl_auth == "GSSAPI": - sasl_client.setAttr("service", kerberos_service_name) + sasl_client = SASLClient(host, kerberos_service_name, mechanism=sasl_auth) elif sasl_auth == "PLAIN": - sasl_client.setAttr("username", username) - sasl_client.setAttr("password", password) + sasl_client = SASLClient( + host, mechanism=sasl_auth, username=username, password=password + ) else: raise AssertionError - sasl_client.init() return sasl_client transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 2864c4f30..feae34129 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -347,7 +347,9 @@ def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str, as_dict["table_database"] = None yield as_dict - def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]: + def get_catalog( + self, manifest: Manifest, selected_nodes: Optional[Set] = None + ) -> Tuple[agate.Table, List[Exception]]: schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: raise dbt.exceptions.CompilationError( diff --git a/dbt/adapters/spark/session.py b/dbt/adapters/spark/session.py index 0e3717172..b5b2bebdb 100644 --- a/dbt/adapters/spark/session.py +++ b/dbt/adapters/spark/session.py @@ -4,11 +4,14 @@ import datetime as dt from types import TracebackType -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Sequence +from dbt.adapters.spark.connections import SparkConnectionWrapper from dbt.events import AdapterLogger from dbt.utils import DECIMALS +from dbt.exceptions import DbtRuntimeError from pyspark.sql import DataFrame, Row, SparkSession +from pyspark.sql.utils import AnalysisException logger = AdapterLogger("Spark") @@ -44,13 +47,15 @@ def __exit__( @property def description( self, - ) -> List[Tuple[str, str, None, None, None, None, bool]]: + ) -> Sequence[ + Tuple[str, Any, Optional[int], Optional[int], Optional[int], Optional[int], bool] + ]: """ Get the description. Returns ------- - out : List[Tuple[str, str, None, None, None, None, bool]] + out : Sequence[Tuple[str, str, None, None, None, None, bool]] The description. 
Source @@ -107,13 +112,18 @@ def execute(self, sql: str, *parameters: Any) -> None: """ if len(parameters) > 0: sql = sql % parameters + builder = SparkSession.builder.enableHiveSupport() for parameter, value in self.server_side_parameters.items(): builder = builder.config(parameter, value) spark_session = builder.getOrCreate() - self._df = spark_session.sql(sql) + + try: + self._df = spark_session.sql(sql) + except AnalysisException as exc: + raise DbtRuntimeError(str(exc)) from exc def fetchall(self) -> Optional[List[Row]]: """ @@ -180,7 +190,7 @@ def cursor(self) -> Cursor: return Cursor(server_side_parameters=self.server_side_parameters) -class SessionConnectionWrapper(object): +class SessionConnectionWrapper(SparkConnectionWrapper): """Connection wrapper for the session connection method.""" handle: Connection @@ -220,7 +230,11 @@ def execute(self, sql: str, bindings: Optional[List[Any]] = None) -> None: self._cursor.execute(sql, *bindings) @property - def description(self) -> List[Tuple[str, str, None, None, None, None, bool]]: + def description( + self, + ) -> Sequence[ + Tuple[str, Any, Optional[int], Optional[int], Optional[int], Optional[int], bool] + ]: assert self._cursor, "Cursor not available" return self._cursor.description diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index 202564e4e..bfc1f198d 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -1,4 +1,8 @@ -{% macro dbt_spark_tblproperties_clause() -%} +{% macro tblproperties_clause() %} + {{ return(adapter.dispatch('tblproperties_clause', 'dbt')()) }} +{%- endmacro -%} + +{% macro spark__tblproperties_clause() -%} {%- set tblproperties = config.get('tblproperties') -%} {%- if tblproperties is not none %} tblproperties ( @@ -134,7 +138,7 @@ {#-- We can't use temporary tables with `create ... 
as ()` syntax --#} {% macro spark__create_temporary_view(relation, compiled_code) -%} - create temporary view {{ relation }} as + create or replace temporary view {{ relation }} as {{ compiled_code }} {%- endmacro -%} @@ -156,10 +160,12 @@ {% endif %} {{ file_format_clause() }} {{ options_clause() }} + {{ tblproperties_clause() }} {{ partition_cols(label="partitioned by") }} {{ clustered_cols(label="clustered by") }} {{ location_clause() }} {{ comment_clause() }} + as {{ compiled_code }} {%- endif -%} @@ -223,9 +229,30 @@ {% endfor %} {% endmacro %} +{% macro get_column_comment_sql(column_name, column_dict) -%} + {% if column_name in column_dict and column_dict[column_name]["description"] -%} + {% set escaped_description = column_dict[column_name]["description"] | replace("'", "\\'") %} + {% set column_comment_clause = "comment '" ~ escaped_description ~ "'" %} + {%- endif -%} + {{ adapter.quote(column_name) }} {{ column_comment_clause }} +{% endmacro %} + +{% macro get_persist_docs_column_list(model_columns, query_columns) %} + {% for column_name in query_columns %} + {{ get_column_comment_sql(column_name, model_columns) }} + {{- ", " if not loop.last else "" }} + {% endfor %} +{% endmacro %} {% macro spark__create_view_as(relation, sql) -%} create or replace view {{ relation }} + {% if config.persist_column_docs() -%} + {% set model_columns = model.columns %} + {% set query_columns = get_columns_in_query(sql) %} + ( + {{ get_persist_docs_column_list(model_columns, query_columns) }} + ) + {% endif %} {{ comment_clause() }} {%- set contract_config = config.get('contract') -%} {%- if contract_config.enforced -%} diff --git a/dev-requirements.txt b/dev-requirements.txt index 11fc038f3..6ea7b16aa 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -5,16 +5,16 @@ git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-tests-adapter&subdirectory= # if version 1.x or greater -> pin to major version # if version 0.x -> pin to minor -black~=23.3 +black~=23.9 bumpversion~=0.6.0 click~=8.1 flake8~=6.1;python_version>="3.8" flaky~=3.7 freezegun~=1.2 ipdb~=0.13.13 -mypy==1.3.0 # patch updates have historically introduced breaking changes -pip-tools~=7.2 -pre-commit~=3.3 +mypy==1.5.1 # patch updates have historically introduced breaking changes +pip-tools~=7.3 +pre-commit~=3.4 pre-commit-hooks~=4.4 pytest~=7.4 pytest-csv~=3.0 @@ -22,13 +22,12 @@ pytest-dotenv~=0.5.2 pytest-logbook~=1.2 pytest-xdist~=3.3 pytz~=2023.3 -tox~=4.5 +tox~=4.11 types-pytz~=2023.3 -types-requests~=2.29 +types-requests~=2.31 twine~=4.0 -wheel~=0.40 +wheel~=0.41 # Adapter specific dependencies -mock~=5.0 -sasl~=0.3.1 +mock~=5.1 thrift_sasl~=0.4.3 diff --git a/docker-compose.yml b/docker-compose.yml index 9bc9e509c..ad083eaf4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: "3.7" services: dbt-spark3-thrift: - image: godatadriven/spark:3.1.1 + build: docker/ ports: - "10000:10000" - "4040:4040" @@ -19,7 +19,7 @@ services: - WAIT_FOR=dbt-hive-metastore:5432 dbt-hive-metastore: - image: postgres:9.6.17-alpine + image: postgres:9-alpine volumes: - ./.hive-metastore/:/var/lib/postgresql/data environment: diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 000000000..bb4d378ed --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,30 @@ +ARG OPENJDK_VERSION=8 +FROM eclipse-temurin:${OPENJDK_VERSION}-jre + +ARG BUILD_DATE +ARG SPARK_VERSION=3.3.2 +ARG HADOOP_VERSION=3 + +LABEL org.label-schema.name="Apache Spark ${SPARK_VERSION}" \ + org.label-schema.build-date=$BUILD_DATE \ 
+ org.label-schema.version=$SPARK_VERSION + +ENV SPARK_HOME /usr/spark +ENV PATH="/usr/spark/bin:/usr/spark/sbin:${PATH}" + +RUN apt-get update && \ + apt-get install -y wget netcat procps libpostgresql-jdbc-java && \ + wget -q "http://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz" && \ + tar xzf "spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz" && \ + rm "spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz" && \ + mv "spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}" /usr/spark && \ + ln -s /usr/share/java/postgresql-jdbc4.jar /usr/spark/jars/postgresql-jdbc4.jar && \ + apt-get remove -y wget && \ + apt-get autoremove -y && \ + apt-get clean + +COPY entrypoint.sh /scripts/ +RUN chmod +x /scripts/entrypoint.sh + +ENTRYPOINT ["/scripts/entrypoint.sh"] +CMD ["--help"] diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh new file mode 100644 index 000000000..6a7591389 --- /dev/null +++ b/docker/entrypoint.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +if [ -n "$WAIT_FOR" ]; then + IFS=';' read -a HOSTPORT_ARRAY <<< "$WAIT_FOR" + for HOSTPORT in "${HOSTPORT_ARRAY[@]}" + do + WAIT_FOR_HOST=${HOSTPORT%:*} + WAIT_FOR_PORT=${HOSTPORT#*:} + + echo Waiting for $WAIT_FOR_HOST to listen on $WAIT_FOR_PORT... + while ! nc -z $WAIT_FOR_HOST $WAIT_FOR_PORT; do echo sleeping; sleep 2; done + done +fi + +exec spark-submit "$@" diff --git a/requirements.txt b/requirements.txt index e58ecdd4b..ea5d1ad2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -PyHive[hive]>=0.6.0,<0.7.0 -requests[python]>=2.28.1 +pyhive[hive_pure_sasl]~=0.7.0 +requests>=2.28.1 pyodbc~=4.0.39 sqlparams>=3.0.0 diff --git a/setup.py b/setup.py index c6713e895..301b4a41f 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,6 @@ print("Please upgrade to Python 3.8 or higher.") sys.exit(1) - # require version of setuptools that supports find_namespace_packages from setuptools import setup @@ -50,13 +49,13 @@ def _get_dbt_core_version(): package_name = "dbt-spark" -package_version = "1.7.0a1" +package_version = "1.8.0a1" dbt_core_version = _get_dbt_core_version() description = """The Apache Spark adapter plugin for dbt""" -odbc_extras = ["pyodbc~=4.0.30"] +odbc_extras = ["pyodbc~=4.0.39"] pyhive_extras = [ - "PyHive[hive]>=0.6.0,<0.7.0", + "PyHive[hive_pure_sasl]~=0.7.0", "thrift>=0.11.0,<0.17.0", ] session_extras = ["pyspark>=3.0.0,<4.0.0"] @@ -93,6 +92,7 @@ def _get_dbt_core_version(): "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], python_requires=">=3.8", ) diff --git a/tests/functional/adapter/persist_docs/fixtures.py b/tests/functional/adapter/persist_docs/fixtures.py index 3c351ab55..b884b7dec 100644 --- a/tests/functional/adapter/persist_docs/fixtures.py +++ b/tests/functional/adapter/persist_docs/fixtures.py @@ -21,11 +21,39 @@ select 1 as id, 'Joe' as name """ +_MODELS__VIEW_DELTA_MODEL = """ +{{ config(materialized='view') }} +select id, count(*) as count from {{ ref('table_delta_model') }} group by id +""" + _MODELS__TABLE_DELTA_MODEL_MISSING_COLUMN = """ {{ config(materialized='table', file_format='delta') }} select 1 as id, 'Joe' as different_name """ +_VIEW_PROPERTIES_MODELS = """ +version: 2 +models: + - name: view_delta_model + description: | + View model description "with double quotes" + and with 'single quotes' as welll as other; + '''abc123''' + reserved -- characters + -- + /* comment */ + Some $lbl$ labeled $lbl$ and 
$$ unlabeled $$ dollar-quoting + columns: + - name: id + description: | + id Column description "with double quotes" + and with 'single quotes' as welll as other; + '''abc123''' + reserved -- characters + -- + /* comment */ + Some $lbl$ labeled $lbl$ and $$ unlabeled $$ dollar-quoting +""" _PROPERTIES__MODELS = """ version: 2 diff --git a/tests/functional/adapter/persist_docs/test_persist_docs.py b/tests/functional/adapter/persist_docs/test_persist_docs.py index 0e3d102dc..ee02e5ef8 100644 --- a/tests/functional/adapter/persist_docs/test_persist_docs.py +++ b/tests/functional/adapter/persist_docs/test_persist_docs.py @@ -10,6 +10,8 @@ _PROPERTIES__MODELS, _PROPERTIES__SEEDS, _SEEDS__BASIC, + _MODELS__VIEW_DELTA_MODEL, + _VIEW_PROPERTIES_MODELS, ) @@ -76,6 +78,48 @@ def test_delta_comments(self, project): assert result[2].startswith("Some stuff here and then a call to") +@pytest.mark.skip_profile("apache_spark", "spark_session") +class TestPersistDocsDeltaView: + @pytest.fixture(scope="class") + def models(self): + return { + "table_delta_model.sql": _MODELS__TABLE_DELTA_MODEL, + "view_delta_model.sql": _MODELS__VIEW_DELTA_MODEL, + "schema.yml": _VIEW_PROPERTIES_MODELS, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "test": { + "+persist_docs": { + "relation": True, + "columns": True, + }, + } + }, + } + + def test_delta_comments(self, project): + run_dbt(["run"]) + + results = project.run_sql( + "describe extended {schema}.{table}".format( + schema=project.test_schema, table="view_delta_model" + ), + fetch="all", + ) + + for result in results: + if result[0] == "Comment": + assert result[1].startswith("View model description") + if result[0] == "id": + assert result[2].startswith("id Column description") + if result[0] == "count": + assert result[2] is None + + @pytest.mark.skip_profile("apache_spark", "spark_session") class TestPersistDocsMissingColumn: @pytest.fixture(scope="class") diff --git a/tests/functional/adapter/store_test_failures_tests/test_store_test_failures.py b/tests/functional/adapter/test_store_test_failures.py similarity index 67% rename from tests/functional/adapter/store_test_failures_tests/test_store_test_failures.py rename to tests/functional/adapter/test_store_test_failures.py index a5342f215..822cb57a2 100644 --- a/tests/functional/adapter/store_test_failures_tests/test_store_test_failures.py +++ b/tests/functional/adapter/test_store_test_failures.py @@ -1,5 +1,6 @@ import pytest +from dbt.tests.adapter.store_test_failures_tests import basic from dbt.tests.adapter.store_test_failures_tests.test_store_test_failures import ( StoreTestFailuresBase, TEST_AUDIT_SCHEMA_SUFFIX, @@ -53,3 +54,33 @@ def project_config_update(self): def test_store_and_assert_failure_with_delta(self, project): self.run_tests_store_one_failure(project) self.run_tests_store_failures_and_assert(project) + + +@pytest.mark.skip_profile("spark_session") +class TestStoreTestFailuresAsInteractions(basic.StoreTestFailuresAsInteractions): + pass + + +@pytest.mark.skip_profile("spark_session") +class TestStoreTestFailuresAsProjectLevelOff(basic.StoreTestFailuresAsProjectLevelOff): + pass + + +@pytest.mark.skip_profile("spark_session") +class TestStoreTestFailuresAsProjectLevelView(basic.StoreTestFailuresAsProjectLevelView): + pass + + +@pytest.mark.skip_profile("spark_session") +class TestStoreTestFailuresAsGeneric(basic.StoreTestFailuresAsGeneric): + pass + + +@pytest.mark.skip_profile("spark_session") +class 
TestStoreTestFailuresAsProjectLevelEphemeral(basic.StoreTestFailuresAsProjectLevelEphemeral): + pass + + +@pytest.mark.skip_profile("spark_session") +class TestStoreTestFailuresAsExceptions(basic.StoreTestFailuresAsExceptions): + pass diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 1eb818241..a7da63301 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -173,13 +173,16 @@ def test_thrift_connection(self): config = self._get_target_thrift(self.project_cfg) adapter = SparkAdapter(config) - def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password): + def hive_thrift_connect( + host, port, username, auth, kerberos_service_name, password, configuration + ): self.assertEqual(host, "myorg.sparkhost.com") self.assertEqual(port, 10001) self.assertEqual(username, "dbt") self.assertIsNone(auth) self.assertIsNone(kerberos_service_name) self.assertIsNone(password) + self.assertDictEqual(configuration, {}) with mock.patch.object(hive, "connect", new=hive_thrift_connect): connection = adapter.acquire_connection("dummy") @@ -194,11 +197,12 @@ def test_thrift_ssl_connection(self): config = self._get_target_use_ssl_thrift(self.project_cfg) adapter = SparkAdapter(config) - def hive_thrift_connect(thrift_transport): + def hive_thrift_connect(thrift_transport, configuration): self.assertIsNotNone(thrift_transport) transport = thrift_transport._trans self.assertEqual(transport.host, "myorg.sparkhost.com") self.assertEqual(transport.port, 10001) + self.assertDictEqual(configuration, {}) with mock.patch.object(hive, "connect", new=hive_thrift_connect): connection = adapter.acquire_connection("dummy") @@ -213,13 +217,16 @@ def test_thrift_connection_kerberos(self): config = self._get_target_thrift_kerberos(self.project_cfg) adapter = SparkAdapter(config) - def hive_thrift_connect(host, port, username, auth, kerberos_service_name, password): + def hive_thrift_connect( + host, port, username, auth, kerberos_service_name, password, configuration + ): self.assertEqual(host, "myorg.sparkhost.com") self.assertEqual(port, 10001) self.assertEqual(username, "dbt") self.assertEqual(auth, "KERBEROS") self.assertEqual(kerberos_service_name, "hive") self.assertIsNone(password) + self.assertDictEqual(configuration, {}) with mock.patch.object(hive, "connect", new=hive_thrift_connect): connection = adapter.acquire_connection("dummy") @@ -710,6 +717,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel config = self._get_target_http(self.project_cfg) columns = SparkAdapter(config).parse_columns_from_information(relation) self.assertEqual(len(columns), 4) + self.assertEqual( columns[2].to_column_dict(omit_none=False), {