diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index 744284849..4de02c345 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -4,7 +4,7 @@ parse = (?P<major>\d+)
 	\.(?P<minor>\d+)
 	\.(?P<patch>\d+)
 	((?P<prerelease>a|b|rc)(?P<num>\d+))?
-serialize = 
+serialize =
 	{major}.{minor}.{patch}{prerelease}{num}
 	{major}.{minor}.{patch}
 commit = False
@@ -13,7 +13,7 @@ tag = False
 [bumpversion:part:prerelease]
 first_value = a
 optional_value = final
-values = 
+values =
 	a
 	b
 	rc
@@ -25,4 +25,3 @@ first_value = 1
 [bumpversion:file:setup.py]
 
 [bumpversion:file:dbt/adapters/spark/__version__.py]
-
diff --git a/.flake8 b/.flake8
new file mode 100644
index 000000000..f39d154c0
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,12 @@
+[flake8]
+select =
+    E
+    W
+    F
+ignore =
+    W503 # makes Flake8 work like black
+    W504
+    E203 # makes Flake8 work like black
+    E741
+    E501
+exclude = test
diff --git a/.github/ISSUE_TEMPLATE/dependabot.yml b/.github/ISSUE_TEMPLATE/dependabot.yml
index 8a8c85b9f..2a6f34492 100644
--- a/.github/ISSUE_TEMPLATE/dependabot.yml
+++ b/.github/ISSUE_TEMPLATE/dependabot.yml
@@ -5,4 +5,4 @@ updates:
     directory: "/"
     schedule:
       interval: "daily"
-    rebase-strategy: "disabled"
\ No newline at end of file
+    rebase-strategy: "disabled"
diff --git a/.github/ISSUE_TEMPLATE/release.md b/.github/ISSUE_TEMPLATE/release.md
index ac28792a3..a69349f54 100644
--- a/.github/ISSUE_TEMPLATE/release.md
+++ b/.github/ISSUE_TEMPLATE/release.md
@@ -7,4 +7,4 @@ assignees: ''
 
 ---
 
-### TBD
\ No newline at end of file
+### TBD
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 60e12779b..5928b1cbf 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -18,4 +18,4 @@ resolves #
 - [ ] I have signed the [CLA](https://docs.getdbt.com/docs/contributor-license-agreements)
 - [ ] I have run this code in development and it appears to resolve the stated issue
 - [ ] This PR includes tests, or tests are not required/relevant for this PR
-- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-spark next" section.
\ No newline at end of file
+- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-spark next" section.
diff --git a/.github/workflows/jira-creation.yml b/.github/workflows/jira-creation.yml index c84e106a7..b4016befc 100644 --- a/.github/workflows/jira-creation.yml +++ b/.github/workflows/jira-creation.yml @@ -13,7 +13,7 @@ name: Jira Issue Creation on: issues: types: [opened, labeled] - + permissions: issues: write diff --git a/.github/workflows/jira-label.yml b/.github/workflows/jira-label.yml index fd533a170..3da2e3a38 100644 --- a/.github/workflows/jira-label.yml +++ b/.github/workflows/jira-label.yml @@ -13,7 +13,7 @@ name: Jira Label Mirroring on: issues: types: [labeled, unlabeled] - + permissions: issues: read @@ -24,4 +24,3 @@ jobs: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} - diff --git a/.github/workflows/jira-transition.yml b/.github/workflows/jira-transition.yml index 71273c7a9..ed9f9cd4f 100644 --- a/.github/workflows/jira-transition.yml +++ b/.github/workflows/jira-transition.yml @@ -21,4 +21,4 @@ jobs: secrets: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} - JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} \ No newline at end of file + JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fbdbbbaae..56685bfc6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,19 +37,10 @@ defaults: jobs: code-quality: - name: ${{ matrix.toxenv }} + name: code-quality runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - toxenv: [flake8] - - env: - TOXENV: ${{ matrix.toxenv }} - PYTEST_ADDOPTS: "-v --color=yes" - steps: - name: Check out the repository uses: actions/checkout@v2 @@ -58,18 +49,19 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 - with: + with: python-version: '3.8' - name: Install python dependencies run: | sudo apt-get install libsasl2-dev pip install --user --upgrade pip - pip install tox - pip --version - tox --version - - name: Run tox - run: tox + pip install -r dev-requirements.txt + pre-commit --version + mypy --version + dbt --version + - name: pre-commit hooks + run: pre-commit run --all-files --show-diff-on-failure unit: name: unit test / python ${{ matrix.python-version }} @@ -153,7 +145,7 @@ jobs: - name: Check wheel contents run: | check-wheel-contents dist/*.whl --ignore W007,W008 - + - name: Check if this is an alpha version id: check-is-alpha run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b40371e8a..554e13a8d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -116,4 +116,3 @@ jobs: pip install twine wheel setuptools python setup.py sdist bdist_wheel twine upload --non-interactive dist/dbt_spark-${{env.version_number}}-py3-none-any.whl dist/dbt-spark-${{env.version_number}}.tar.gz - diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 2848ce8f7..a56455d55 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,5 +13,3 @@ jobs: stale-pr-message: "This PR has been marked as Stale because it has been open for 180 days with no activity. If you would like the PR to remain open, please remove the stale label or comment on the PR, or it will be closed in 7 days." 
# mark issues/PRs stale when they haven't seen activity in 180 days days-before-stale: 180 - # ignore checking issues with the following labels - exempt-issue-labels: "epic, discussion" \ No newline at end of file diff --git a/.gitignore b/.gitignore index cc586f5fe..189589cf4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,18 +1,47 @@ -.hive-metastore/ -.spark-warehouse/ -*.egg-info -env/ -*.pyc +# Byte-compiled / optimized / DLL files __pycache__ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +env*/ +dbt_env/ +dist/ +*.egg-info +logs/ + + +# Unit test .tox/ .env +test.env + + +# Django stuff +*.log + +# Mypy +*.pytest_cache/ + +# Vim +*.sw* + +# Pyenv +.python-version + +# pycharm .idea/ -build/ -dist/ -dbt-integration-tests -test/integration/.user.yml + +# MacOS .DS_Store -test.env + +# vscode .vscode -*.log -logs/ \ No newline at end of file + +# other +.hive-metastore/ +.spark-warehouse/ +dbt-integration-tests +test/integration/.user.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..e70156dcd --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,66 @@ +# For more on configuring pre-commit hooks (see https://pre-commit.com/) + +# TODO: remove global exclusion of tests when testing overhaul is complete +exclude: '^tests/.*' + +# Force all unspecified python hooks to run python 3.8 +default_language_version: + python: python3.8 + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: check-yaml + args: [--unsafe] + - id: check-json + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-case-conflict +- repo: https://github.com/psf/black + rev: 21.12b0 + hooks: + - id: black + additional_dependencies: ['click==8.0.4'] + args: + - "--line-length=99" + - "--target-version=py38" + - id: black + alias: black-check + stages: [manual] + additional_dependencies: ['click==8.0.4'] + args: + - "--line-length=99" + - "--target-version=py38" + - "--check" + - "--diff" +- repo: https://gitlab.com/pycqa/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + - id: flake8 + alias: flake8-check + stages: [manual] +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.950 + hooks: + - id: mypy + # N.B.: Mypy is... a bit fragile. + # + # By using `language: system` we run this hook in the local + # environment instead of a pre-commit isolated one. This is needed + # to ensure mypy correctly parses the project. + + # It may cause trouble in that it adds environmental variables out + # of our control to the mix. Unfortunately, there's nothing we can + # do about per pre-commit's author. + # See https://github.com/pre-commit/pre-commit/issues/730 for details. 
+ args: [--show-error-codes, --ignore-missing-imports] + files: ^dbt/adapters/.* + language: system + - id: mypy + alias: mypy-check + stages: [manual] + args: [--show-error-codes, --pretty, --ignore-missing-imports] + files: ^dbt/adapters + language: system diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ad68a5ce..77eb72581 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - rename file to match reference to dbt-core ([#344](https://github.com/dbt-labs/dbt-spark/pull/344)) ### Under the hood +- Add precommit tooling to this repo ([#356](https://github.com/dbt-labs/dbt-spark/pull/356)) - Use dbt.tests.adapter.basic in test suite ([#298](https://github.com/dbt-labs/dbt-spark/issues/298), [#299](https://github.com/dbt-labs/dbt-spark/pull/299)) - Make internal macros use macro dispatch to be overridable in child adapters ([#319](https://github.com/dbt-labs/dbt-spark/issues/319), [#320](https://github.com/dbt-labs/dbt-spark/pull/320)) - Override adapter method 'run_sql_for_tests' ([#323](https://github.com/dbt-labs/dbt-spark/issues/323), [#324](https://github.com/dbt-labs/dbt-spark/pull/324)) diff --git a/MANIFEST.in b/MANIFEST.in index 78412d5b8..cfbc714ed 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -recursive-include dbt/include *.sql *.yml *.md \ No newline at end of file +recursive-include dbt/include *.sql *.yml *.md diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..a520c425f --- /dev/null +++ b/Makefile @@ -0,0 +1,56 @@ +.DEFAULT_GOAL:=help + +.PHONY: dev +dev: ## Installs adapter in develop mode along with development depedencies + @\ + pip install -r dev-requirements.txt && pre-commit install + +.PHONY: mypy +mypy: ## Runs mypy against staged changes for static type checking. + @\ + pre-commit run --hook-stage manual mypy-check | grep -v "INFO" + +.PHONY: flake8 +flake8: ## Runs flake8 against staged changes to enforce style guide. + @\ + pre-commit run --hook-stage manual flake8-check | grep -v "INFO" + +.PHONY: black +black: ## Runs black against staged changes to enforce style guide. + @\ + pre-commit run --hook-stage manual black-check -v | grep -v "INFO" + +.PHONY: lint +lint: ## Runs flake8 and mypy code checks against staged changes. + @\ + pre-commit run flake8-check --hook-stage manual | grep -v "INFO"; \ + pre-commit run mypy-check --hook-stage manual | grep -v "INFO" + +.PHONY: linecheck +linecheck: ## Checks for all Python lines 100 characters or more + @\ + find dbt -type f -name "*.py" -exec grep -I -r -n '.\{100\}' {} \; + +.PHONY: unit +unit: ## Runs unit tests with py38. + @\ + tox -e py38 + +.PHONY: test +test: ## Runs unit tests with py38 and code checks against staged changes. + @\ + tox -p -e py38; \ + pre-commit run black-check --hook-stage manual | grep -v "INFO"; \ + pre-commit run flake8-check --hook-stage manual | grep -v "INFO"; \ + pre-commit run mypy-check --hook-stage manual | grep -v "INFO" + +.PHONY: clean + @echo "cleaning repo" + @git clean -f -X + +.PHONY: help +help: ## Show this help message. 
+ @echo 'usage: make [target]' + @echo + @echo 'targets:' + @grep -E '^[7+a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/dbt/adapters/spark/__init__.py b/dbt/adapters/spark/__init__.py index 469e202b9..6ecc5eccf 100644 --- a/dbt/adapters/spark/__init__.py +++ b/dbt/adapters/spark/__init__.py @@ -8,6 +8,5 @@ from dbt.include import spark Plugin = AdapterPlugin( - adapter=SparkAdapter, - credentials=SparkCredentials, - include_path=spark.PACKAGE_PATH) + adapter=SparkAdapter, credentials=SparkCredentials, include_path=spark.PACKAGE_PATH +) diff --git a/dbt/adapters/spark/column.py b/dbt/adapters/spark/column.py index fd377ad15..4df6b301b 100644 --- a/dbt/adapters/spark/column.py +++ b/dbt/adapters/spark/column.py @@ -1,11 +1,11 @@ from dataclasses import dataclass -from typing import TypeVar, Optional, Dict, Any +from typing import Any, Dict, Optional, TypeVar, Union from dbt.adapters.base.column import Column from dbt.dataclass_schema import dbtClassMixin from hologram import JsonDict -Self = TypeVar('Self', bound='SparkColumn') +Self = TypeVar("Self", bound="SparkColumn") @dataclass @@ -31,7 +31,7 @@ def literal(self, value): @property def quoted(self) -> str: - return '`{}`'.format(self.column) + return "`{}`".format(self.column) @property def data_type(self) -> str: @@ -42,26 +42,23 @@ def __repr__(self) -> str: @staticmethod def convert_table_stats(raw_stats: Optional[str]) -> Dict[str, Any]: - table_stats = {} + table_stats: Dict[str, Union[int, str, bool]] = {} if raw_stats: # format: 1109049927 bytes, 14093476 rows stats = { - stats.split(" ")[1]: int(stats.split(" ")[0]) - for stats in raw_stats.split(', ') + stats.split(" ")[1]: int(stats.split(" ")[0]) for stats in raw_stats.split(", ") } for key, val in stats.items(): - table_stats[f'stats:{key}:label'] = key - table_stats[f'stats:{key}:value'] = val - table_stats[f'stats:{key}:description'] = '' - table_stats[f'stats:{key}:include'] = True + table_stats[f"stats:{key}:label"] = key + table_stats[f"stats:{key}:value"] = val + table_stats[f"stats:{key}:description"] = "" + table_stats[f"stats:{key}:include"] = True return table_stats - def to_column_dict( - self, omit_none: bool = True, validate: bool = False - ) -> JsonDict: + def to_column_dict(self, omit_none: bool = True, validate: bool = False) -> JsonDict: original_dict = self.to_dict(omit_none=omit_none) # If there are stats, merge them into the root of the dict - original_stats = original_dict.pop('table_stats', None) + original_stats = original_dict.pop("table_stats", None) if original_stats: original_dict.update(original_stats) return original_dict diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 11163ccf0..59ceb9dd8 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -26,6 +26,7 @@ from hologram.helpers import StrEnum from dataclasses import dataclass, field from typing import Any, Dict, Optional + try: from thrift.transport.TSSLSocket import TSSLSocket import thrift @@ -33,11 +34,7 @@ import sasl import thrift_sasl except ImportError: - TSSLSocket = None - thrift = None - ssl = None - sasl = None - thrift_sasl = None + pass # done deliberately: setting modules to None explicitly violates MyPy contracts by degrading type semantics import base64 import time @@ -52,10 +49,10 @@ def _build_odbc_connnection_string(**kwargs) -> str: class SparkConnectionMethod(StrEnum): - THRIFT = 'thrift' - HTTP = 'http' - ODBC = 
'odbc' - SESSION = 'session' + THRIFT = "thrift" + HTTP = "http" + ODBC = "odbc" + SESSION = "session" @dataclass @@ -71,7 +68,7 @@ class SparkCredentials(Credentials): port: int = 443 auth: Optional[str] = None kerberos_service_name: Optional[str] = None - organization: str = '0' + organization: str = "0" connect_retries: int = 0 connect_timeout: int = 10 use_ssl: bool = False @@ -81,27 +78,24 @@ class SparkCredentials(Credentials): @classmethod def __pre_deserialize__(cls, data): data = super().__pre_deserialize__(data) - if 'database' not in data: - data['database'] = None + if "database" not in data: + data["database"] = None return data def __post_init__(self): # spark classifies database and schema as the same thing - if ( - self.database is not None and - self.database != self.schema - ): + if self.database is not None and self.database != self.schema: raise dbt.exceptions.RuntimeException( - f' schema: {self.schema} \n' - f' database: {self.database} \n' - f'On Spark, database must be omitted or have the same value as' - f' schema.' + f" schema: {self.schema} \n" + f" database: {self.database} \n" + f"On Spark, database must be omitted or have the same value as" + f" schema." ) self.database = None if self.method == SparkConnectionMethod.ODBC: try: - import pyodbc # noqa: F401 + import pyodbc # noqa: F401 except ImportError as e: raise dbt.exceptions.RuntimeException( f"{self.method} connection method requires " @@ -111,22 +105,16 @@ def __post_init__(self): f"ImportError({e.msg})" ) from e - if ( - self.method == SparkConnectionMethod.ODBC and - self.cluster and - self.endpoint - ): + if self.method == SparkConnectionMethod.ODBC and self.cluster and self.endpoint: raise dbt.exceptions.RuntimeException( "`cluster` and `endpoint` cannot both be set when" f" using {self.method} method to connect to Spark" ) if ( - self.method == SparkConnectionMethod.HTTP or - self.method == SparkConnectionMethod.THRIFT - ) and not ( - ThriftState and THttpClient and hive - ): + self.method == SparkConnectionMethod.HTTP + or self.method == SparkConnectionMethod.THRIFT + ) and not (ThriftState and THttpClient and hive): raise dbt.exceptions.RuntimeException( f"{self.method} connection method requires " "additional dependencies. 
\n" @@ -148,19 +136,19 @@ def __post_init__(self): @property def type(self): - return 'spark' + return "spark" @property def unique_field(self): return self.host def _connection_keys(self): - return ('host', 'port', 'cluster', - 'endpoint', 'schema', 'organization') + return ("host", "port", "cluster", "endpoint", "schema", "organization") class PyhiveConnectionWrapper(object): """Wrap a Spark connection in a way that no-ops transactions""" + # https://forums.databricks.com/questions/2157/in-apache-spark-sql-can-we-roll-back-the-transacti.html # noqa def __init__(self, handle): @@ -178,9 +166,7 @@ def cancel(self): try: self._cursor.cancel() except EnvironmentError as exc: - logger.debug( - "Exception while cancelling query: {}".format(exc) - ) + logger.debug("Exception while cancelling query: {}".format(exc)) def close(self): if self._cursor: @@ -189,9 +175,7 @@ def close(self): try: self._cursor.close() except EnvironmentError as exc: - logger.debug( - "Exception while closing cursor: {}".format(exc) - ) + logger.debug("Exception while closing cursor: {}".format(exc)) self.handle.close() def rollback(self, *args, **kwargs): @@ -247,23 +231,20 @@ def execute(self, sql, bindings=None): dbt.exceptions.raise_database_error(poll_state.errorMessage) elif state not in STATE_SUCCESS: - status_type = ThriftState._VALUES_TO_NAMES.get( - state, - 'Unknown<{!r}>'.format(state)) + status_type = ThriftState._VALUES_TO_NAMES.get(state, "Unknown<{!r}>".format(state)) - dbt.exceptions.raise_database_error( - "Query failed with status: {}".format(status_type)) + dbt.exceptions.raise_database_error("Query failed with status: {}".format(status_type)) logger.debug("Poll status: {}, query complete".format(state)) @classmethod def _fix_binding(cls, value): """Convert complex datatypes to primitives that can be loaded by - the Spark driver""" + the Spark driver""" if isinstance(value, NUMBERS): return float(value) elif isinstance(value, datetime): - return value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + return value.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] else: return value @@ -273,7 +254,6 @@ def description(self): class PyodbcConnectionWrapper(PyhiveConnectionWrapper): - def execute(self, sql, bindings=None): if sql.strip().endswith(";"): sql = sql.strip()[:-1] @@ -282,19 +262,17 @@ def execute(self, sql, bindings=None): self._cursor.execute(sql) else: # pyodbc only supports `qmark` sql params! 
- query = sqlparams.SQLParams('format', 'qmark') + query = sqlparams.SQLParams("format", "qmark") sql, bindings = query.format(sql, bindings) self._cursor.execute(sql, *bindings) class SparkConnectionManager(SQLConnectionManager): - TYPE = 'spark' + TYPE = "spark" SPARK_CLUSTER_HTTP_PATH = "/sql/protocolv1/o/{organization}/{cluster}" SPARK_SQL_ENDPOINT_HTTP_PATH = "/sql/1.0/endpoints/{endpoint}" - SPARK_CONNECTION_URL = ( - "{host}:{port}" + SPARK_CLUSTER_HTTP_PATH - ) + SPARK_CONNECTION_URL = "{host}:{port}" + SPARK_CLUSTER_HTTP_PATH @contextmanager def exception_handler(self, sql): @@ -308,7 +286,7 @@ def exception_handler(self, sql): raise thrift_resp = exc.args[0] - if hasattr(thrift_resp, 'status'): + if hasattr(thrift_resp, "status"): msg = thrift_resp.status.errorMessage raise dbt.exceptions.RuntimeException(msg) else: @@ -320,10 +298,8 @@ def cancel(self, connection): @classmethod def get_response(cls, cursor) -> AdapterResponse: # https://github.com/dbt-labs/dbt-spark/issues/142 - message = 'OK' - return AdapterResponse( - _message=message - ) + message = "OK" + return AdapterResponse(_message=message) # No transactions on Spark.... def add_begin_query(self, *args, **kwargs): @@ -346,12 +322,13 @@ def validate_creds(cls, creds, required): if not hasattr(creds, key): raise dbt.exceptions.DbtProfileError( "The config '{}' is required when using the {} method" - " to connect to Spark".format(key, method)) + " to connect to Spark".format(key, method) + ) @classmethod def open(cls, connection): if connection.state == ConnectionState.OPEN: - logger.debug('Connection is already open, skipping open.') + logger.debug("Connection is already open, skipping open.") return connection creds = connection.credentials @@ -360,19 +337,18 @@ def open(cls, connection): for i in range(1 + creds.connect_retries): try: if creds.method == SparkConnectionMethod.HTTP: - cls.validate_creds(creds, ['token', 'host', 'port', - 'cluster', 'organization']) + cls.validate_creds(creds, ["token", "host", "port", "cluster", "organization"]) # Prepend https:// if it is missing host = creds.host - if not host.startswith('https://'): - host = 'https://' + creds.host + if not host.startswith("https://"): + host = "https://" + creds.host conn_url = cls.SPARK_CONNECTION_URL.format( host=host, port=creds.port, organization=creds.organization, - cluster=creds.cluster + cluster=creds.cluster, ) logger.debug("connection url: {}".format(conn_url)) @@ -381,15 +357,12 @@ def open(cls, connection): raw_token = "token:{}".format(creds.token).encode() token = base64.standard_b64encode(raw_token).decode() - transport.setCustomHeaders({ - 'Authorization': 'Basic {}'.format(token) - }) + transport.setCustomHeaders({"Authorization": "Basic {}".format(token)}) conn = hive.connect(thrift_transport=transport) handle = PyhiveConnectionWrapper(conn) elif creds.method == SparkConnectionMethod.THRIFT: - cls.validate_creds(creds, - ['host', 'port', 'user', 'schema']) + cls.validate_creds(creds, ["host", "port", "user", "schema"]) if creds.use_ssl: transport = build_ssl_transport( @@ -397,26 +370,33 @@ def open(cls, connection): port=creds.port, username=creds.user, auth=creds.auth, - kerberos_service_name=creds.kerberos_service_name) + kerberos_service_name=creds.kerberos_service_name, + ) conn = hive.connect(thrift_transport=transport) else: - conn = hive.connect(host=creds.host, - port=creds.port, - username=creds.user, - auth=creds.auth, - kerberos_service_name=creds.kerberos_service_name) # noqa + conn = hive.connect( + host=creds.host, + 
port=creds.port, + username=creds.user, + auth=creds.auth, + kerberos_service_name=creds.kerberos_service_name, + ) # noqa handle = PyhiveConnectionWrapper(conn) elif creds.method == SparkConnectionMethod.ODBC: if creds.cluster is not None: - required_fields = ['driver', 'host', 'port', 'token', - 'organization', 'cluster'] + required_fields = [ + "driver", + "host", + "port", + "token", + "organization", + "cluster", + ] http_path = cls.SPARK_CLUSTER_HTTP_PATH.format( - organization=creds.organization, - cluster=creds.cluster + organization=creds.organization, cluster=creds.cluster ) elif creds.endpoint is not None: - required_fields = ['driver', 'host', 'port', 'token', - 'endpoint'] + required_fields = ["driver", "host", "port", "token", "endpoint"] http_path = cls.SPARK_SQL_ENDPOINT_HTTP_PATH.format( endpoint=creds.endpoint ) @@ -429,13 +409,12 @@ def open(cls, connection): cls.validate_creds(creds, required_fields) dbt_spark_version = __version__.version - user_agent_entry = f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)" # noqa + user_agent_entry = ( + f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)" # noqa + ) # http://simba.wpengine.com/products/Spark/doc/ODBC_InstallGuide/unix/content/odbc/hi/configuring/serverside.htm - ssp = { - f"SSP_{k}": f"{{{v}}}" - for k, v in creds.server_side_parameters.items() - } + ssp = {f"SSP_{k}": f"{{{v}}}" for k, v in creds.server_side_parameters.items()} # https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm connection_str = _build_odbc_connnection_string( @@ -461,6 +440,7 @@ def open(cls, connection): Connection, SessionConnectionWrapper, ) + handle = SessionConnectionWrapper(Connection()) else: raise dbt.exceptions.DbtProfileError( @@ -472,9 +452,9 @@ def open(cls, connection): if isinstance(e, EOFError): # The user almost certainly has invalid credentials. # Perhaps a token expired, or something - msg = 'Failed to connect' + msg = "Failed to connect" if creds.token is not None: - msg += ', is your token valid?' + msg += ", is your token valid?" 
raise dbt.exceptions.FailedToConnectException(msg) from e retryable_message = _is_retryable_error(e) if retryable_message and creds.connect_retries > 0: @@ -496,9 +476,7 @@ def open(cls, connection): logger.warning(msg) time.sleep(creds.connect_timeout) else: - raise dbt.exceptions.FailedToConnectException( - 'failed to connect' - ) from e + raise dbt.exceptions.FailedToConnectException("failed to connect") from e else: raise exc @@ -507,56 +485,50 @@ def open(cls, connection): return connection -def build_ssl_transport(host, port, username, auth, - kerberos_service_name, password=None): +def build_ssl_transport(host, port, username, auth, kerberos_service_name, password=None): transport = None if port is None: port = 10000 if auth is None: - auth = 'NONE' + auth = "NONE" socket = TSSLSocket(host, port, cert_reqs=ssl.CERT_NONE) - if auth == 'NOSASL': + if auth == "NOSASL": # NOSASL corresponds to hive.server2.authentication=NOSASL # in hive-site.xml transport = thrift.transport.TTransport.TBufferedTransport(socket) - elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): + elif auth in ("LDAP", "KERBEROS", "NONE", "CUSTOM"): # Defer import so package dependency is optional - if auth == 'KERBEROS': + if auth == "KERBEROS": # KERBEROS mode in hive.server2.authentication is GSSAPI # in sasl library - sasl_auth = 'GSSAPI' + sasl_auth = "GSSAPI" else: - sasl_auth = 'PLAIN' + sasl_auth = "PLAIN" if password is None: # Password doesn't matter in NONE mode, just needs # to be nonempty. - password = 'x' + password = "x" def sasl_factory(): sasl_client = sasl.Client() - sasl_client.setAttr('host', host) - if sasl_auth == 'GSSAPI': - sasl_client.setAttr('service', kerberos_service_name) - elif sasl_auth == 'PLAIN': - sasl_client.setAttr('username', username) - sasl_client.setAttr('password', password) + sasl_client.setAttr("host", host) + if sasl_auth == "GSSAPI": + sasl_client.setAttr("service", kerberos_service_name) + elif sasl_auth == "PLAIN": + sasl_client.setAttr("username", username) + sasl_client.setAttr("password", password) else: raise AssertionError sasl_client.init() return sasl_client - transport = thrift_sasl.TSaslClientTransport(sasl_factory, - sasl_auth, socket) + transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) return transport -def _is_retryable_error(exc: Exception) -> Optional[str]: - message = getattr(exc, 'message', None) - if message is None: - return None - message = message.lower() - if 'pending' in message: - return exc.message - if 'temporarily_unavailable' in message: - return exc.message - return None +def _is_retryable_error(exc: Exception) -> str: + message = str(exc).lower() + if "pending" in message or "temporarily_unavailable" in message: + return str(exc) + else: + return "" diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index eb001fbc9..dd090a23b 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -1,7 +1,9 @@ import re from concurrent.futures import Future from dataclasses import dataclass -from typing import Optional, List, Dict, Any, Union, Iterable +from typing import Any, Dict, Iterable, List, Optional, Union +from typing_extensions import TypeAlias + import agate from dbt.contracts.relation import RelationType @@ -21,19 +23,19 @@ logger = AdapterLogger("Spark") -GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation' -LIST_SCHEMAS_MACRO_NAME = 'list_schemas' -LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching' -DROP_RELATION_MACRO_NAME = 'drop_relation' 
-FETCH_TBL_PROPERTIES_MACRO_NAME = 'fetch_tbl_properties' +GET_COLUMNS_IN_RELATION_MACRO_NAME = "get_columns_in_relation" +LIST_SCHEMAS_MACRO_NAME = "list_schemas" +LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching" +DROP_RELATION_MACRO_NAME = "drop_relation" +FETCH_TBL_PROPERTIES_MACRO_NAME = "fetch_tbl_properties" -KEY_TABLE_OWNER = 'Owner' -KEY_TABLE_STATISTICS = 'Statistics' +KEY_TABLE_OWNER = "Owner" +KEY_TABLE_STATISTICS = "Statistics" @dataclass class SparkConfig(AdapterConfig): - file_format: str = 'parquet' + file_format: str = "parquet" location_root: Optional[str] = None partition_by: Optional[Union[List[str], str]] = None clustered_by: Optional[Union[List[str], str]] = None @@ -44,48 +46,44 @@ class SparkConfig(AdapterConfig): class SparkAdapter(SQLAdapter): COLUMN_NAMES = ( - 'table_database', - 'table_schema', - 'table_name', - 'table_type', - 'table_comment', - 'table_owner', - 'column_name', - 'column_index', - 'column_type', - 'column_comment', - - 'stats:bytes:label', - 'stats:bytes:value', - 'stats:bytes:description', - 'stats:bytes:include', - - 'stats:rows:label', - 'stats:rows:value', - 'stats:rows:description', - 'stats:rows:include', + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "table_owner", + "column_name", + "column_index", + "column_type", + "column_comment", + "stats:bytes:label", + "stats:bytes:value", + "stats:bytes:description", + "stats:bytes:include", + "stats:rows:label", + "stats:rows:value", + "stats:rows:description", + "stats:rows:include", ) - INFORMATION_COLUMNS_REGEX = re.compile( - r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE) + INFORMATION_COLUMNS_REGEX = re.compile(r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE) INFORMATION_OWNER_REGEX = re.compile(r"^Owner: (.*)$", re.MULTILINE) - INFORMATION_STATISTICS_REGEX = re.compile( - r"^Statistics: (.*)$", re.MULTILINE) + INFORMATION_STATISTICS_REGEX = re.compile(r"^Statistics: (.*)$", re.MULTILINE) HUDI_METADATA_COLUMNS = [ - '_hoodie_commit_time', - '_hoodie_commit_seqno', - '_hoodie_record_key', - '_hoodie_partition_path', - '_hoodie_file_name' + "_hoodie_commit_time", + "_hoodie_commit_seqno", + "_hoodie_record_key", + "_hoodie_partition_path", + "_hoodie_file_name", ] - Relation = SparkRelation - Column = SparkColumn - ConnectionManager = SparkConnectionManager - AdapterSpecificConfigs = SparkConfig + Relation: TypeAlias = SparkRelation + Column: TypeAlias = SparkColumn + ConnectionManager: TypeAlias = SparkConnectionManager + AdapterSpecificConfigs: TypeAlias = SparkConfig @classmethod def date_function(cls) -> str: - return 'current_timestamp()' + return "current_timestamp()" @classmethod def convert_text_type(cls, agate_table, col_idx): @@ -109,31 +107,28 @@ def convert_datetime_type(cls, agate_table, col_idx): return "timestamp" def quote(self, identifier): - return '`{}`'.format(identifier) + return "`{}`".format(identifier) def add_schema_to_cache(self, schema) -> str: """Cache a new schema in dbt. 
It will show up in `list relations`.""" if schema is None: name = self.nice_connection_name() dbt.exceptions.raise_compiler_error( - 'Attempted to cache a null schema for {}'.format(name) + "Attempted to cache a null schema for {}".format(name) ) if dbt.flags.USE_CACHE: self.cache.add_schema(None, schema) # so jinja doesn't render things - return '' + return "" def list_relations_without_caching( self, schema_relation: SparkRelation ) -> List[SparkRelation]: - kwargs = {'schema_relation': schema_relation} + kwargs = {"schema_relation": schema_relation} try: - results = self.execute_macro( - LIST_RELATIONS_MACRO_NAME, - kwargs=kwargs - ) + results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs) except dbt.exceptions.RuntimeException as e: - errmsg = getattr(e, 'msg', '') + errmsg = getattr(e, "msg", "") if f"Database '{schema_relation}' not found" in errmsg: return [] else: @@ -146,13 +141,12 @@ def list_relations_without_caching( if len(row) != 4: raise dbt.exceptions.RuntimeException( f'Invalid value from "show table extended ...", ' - f'got {len(row)} values, expected 4' + f"got {len(row)} values, expected 4" ) _schema, name, _, information = row - rel_type = RelationType.View \ - if 'Type: VIEW' in information else RelationType.Table - is_delta = 'Provider: delta' in information - is_hudi = 'Provider: hudi' in information + rel_type = RelationType.View if "Type: VIEW" in information else RelationType.Table + is_delta = "Provider: delta" in information + is_hudi = "Provider: hudi" in information relation = self.Relation.create( schema=_schema, identifier=name, @@ -166,7 +160,7 @@ def list_relations_without_caching( return relations def get_relation( - self, database: str, schema: str, identifier: str + self, database: Optional[str], schema: str, identifier: str ) -> Optional[BaseRelation]: if not self.Relation.include_policy.database: database = None @@ -174,9 +168,7 @@ def get_relation( return super().get_relation(database, schema, identifier) def parse_describe_extended( - self, - relation: Relation, - raw_rows: List[agate.Row] + self, relation: Relation, raw_rows: List[agate.Row] ) -> List[SparkColumn]: # Convert the Row to a dict dict_rows = [dict(zip(row._keys, row._values)) for row in raw_rows] @@ -185,44 +177,45 @@ def parse_describe_extended( pos = self.find_table_information_separator(dict_rows) # Remove rows that start with a hash, they are comments - rows = [ - row for row in raw_rows[0:pos] - if not row['col_name'].startswith('#') - ] - metadata = { - col['col_name']: col['data_type'] for col in raw_rows[pos + 1:] - } + rows = [row for row in raw_rows[0:pos] if not row["col_name"].startswith("#")] + metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1 :]} raw_table_stats = metadata.get(KEY_TABLE_STATISTICS) table_stats = SparkColumn.convert_table_stats(raw_table_stats) - return [SparkColumn( - table_database=None, - table_schema=relation.schema, - table_name=relation.name, - table_type=relation.type, - table_owner=str(metadata.get(KEY_TABLE_OWNER)), - table_stats=table_stats, - column=column['col_name'], - column_index=idx, - dtype=column['data_type'], - ) for idx, column in enumerate(rows)] + return [ + SparkColumn( + table_database=None, + table_schema=relation.schema, + table_name=relation.name, + table_type=relation.type, + table_owner=str(metadata.get(KEY_TABLE_OWNER)), + table_stats=table_stats, + column=column["col_name"], + column_index=idx, + dtype=column["data_type"], + ) + for idx, column in enumerate(rows) + ] @staticmethod def 
find_table_information_separator(rows: List[dict]) -> int: pos = 0 for row in rows: - if not row['col_name'] or row['col_name'].startswith('#'): + if not row["col_name"] or row["col_name"].startswith("#"): break pos += 1 return pos def get_columns_in_relation(self, relation: Relation) -> List[SparkColumn]: - cached_relations = self.cache.get_relations( - relation.database, relation.schema) - cached_relation = next((cached_relation - for cached_relation in cached_relations - if str(cached_relation) == str(relation)), - None) + cached_relations = self.cache.get_relations(relation.database, relation.schema) + cached_relation = next( + ( + cached_relation + for cached_relation in cached_relations + if str(cached_relation) == str(relation) + ), + None, + ) columns = [] if cached_relation and cached_relation.information: columns = self.parse_columns_from_information(cached_relation) @@ -238,30 +231,21 @@ def get_columns_in_relation(self, relation: Relation) -> List[SparkColumn]: # spark would throw error when table doesn't exist, where other # CDW would just return and empty list, normalizing the behavior here errmsg = getattr(e, "msg", "") - if ( - "Table or view not found" in errmsg or - "NoSuchTableException" in errmsg - ): + if "Table or view not found" in errmsg or "NoSuchTableException" in errmsg: pass else: raise e # strip hudi metadata columns. - columns = [x for x in columns - if x.name not in self.HUDI_METADATA_COLUMNS] + columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS] return columns - def parse_columns_from_information( - self, relation: SparkRelation - ) -> List[SparkColumn]: - owner_match = re.findall( - self.INFORMATION_OWNER_REGEX, relation.information) + def parse_columns_from_information(self, relation: SparkRelation) -> List[SparkColumn]: + owner_match = re.findall(self.INFORMATION_OWNER_REGEX, relation.information) owner = owner_match[0] if owner_match else None - matches = re.finditer( - self.INFORMATION_COLUMNS_REGEX, relation.information) + matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, relation.information) columns = [] - stats_match = re.findall( - self.INFORMATION_STATISTICS_REGEX, relation.information) + stats_match = re.findall(self.INFORMATION_STATISTICS_REGEX, relation.information) raw_table_stats = stats_match[0] if stats_match else None table_stats = SparkColumn.convert_table_stats(raw_table_stats) for match_num, match in enumerate(matches): @@ -275,28 +259,25 @@ def parse_columns_from_information( table_owner=owner, column=column_name, dtype=column_type, - table_stats=table_stats + table_stats=table_stats, ) columns.append(column) return columns - def _get_columns_for_catalog( - self, relation: SparkRelation - ) -> Iterable[Dict[str, Any]]: + def _get_columns_for_catalog(self, relation: SparkRelation) -> Iterable[Dict[str, Any]]: columns = self.parse_columns_from_information(relation) for column in columns: # convert SparkColumns into catalog dicts as_dict = column.to_column_dict() - as_dict['column_name'] = as_dict.pop('column', None) - as_dict['column_type'] = as_dict.pop('dtype') - as_dict['table_database'] = None + as_dict["column_name"] = as_dict.pop("column", None) + as_dict["column_type"] = as_dict.pop("dtype") + as_dict["table_database"] = None yield as_dict def get_properties(self, relation: Relation) -> Dict[str, str]: properties = self.execute_macro( - FETCH_TBL_PROPERTIES_MACRO_NAME, - kwargs={'relation': relation} + FETCH_TBL_PROPERTIES_MACRO_NAME, kwargs={"relation": relation} ) return dict(properties) @@ -304,28 
+285,30 @@ def get_catalog(self, manifest): schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: dbt.exceptions.raise_compiler_error( - f'Expected only one database in get_catalog, found ' - f'{list(schema_map)}' + f"Expected only one database in get_catalog, found " f"{list(schema_map)}" ) with executor(self.config) as tpe: futures: List[Future[agate.Table]] = [] for info, schemas in schema_map.items(): for schema in schemas: - futures.append(tpe.submit_connected( - self, schema, - self._get_one_catalog, info, [schema], manifest - )) + futures.append( + tpe.submit_connected( + self, schema, self._get_one_catalog, info, [schema], manifest + ) + ) catalogs, exceptions = catch_as_completed(futures) return catalogs, exceptions def _get_one_catalog( - self, information_schema, schemas, manifest, + self, + information_schema, + schemas, + manifest, ) -> agate.Table: if len(schemas) != 1: dbt.exceptions.raise_compiler_error( - f'Expected only one schema in spark _get_one_catalog, found ' - f'{schemas}' + f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}" ) database = information_schema.database @@ -335,15 +318,10 @@ def _get_one_catalog( for relation in self.list_relations(database, schema): logger.debug("Getting table schema for relation {}", relation) columns.extend(self._get_columns_for_catalog(relation)) - return agate.Table.from_object( - columns, column_types=DEFAULT_TYPE_TESTER - ) + return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) def check_schema_exists(self, database, schema): - results = self.execute_macro( - LIST_SCHEMAS_MACRO_NAME, - kwargs={'database': database} - ) + results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}) exists = True if schema in [row[0] for row in results] else False return exists @@ -353,7 +331,7 @@ def get_rows_different_sql( relation_a: BaseRelation, relation_b: BaseRelation, column_names: Optional[List[str]] = None, - except_operator: str = 'EXCEPT', + except_operator: str = "EXCEPT", ) -> str: """Generate SQL for a query that returns a single row with a two columns: the number of rows that are different between the two @@ -366,7 +344,7 @@ def get_rows_different_sql( names = sorted((self.quote(c.name) for c in columns)) else: names = sorted((self.quote(n) for n in column_names)) - columns_csv = ', '.join(names) + columns_csv = ", ".join(names) sql = COLUMNS_EQUAL_SQL.format( columns=columns_csv, @@ -384,7 +362,7 @@ def run_sql_for_tests(self, sql, fetch, conn): try: cursor.execute(sql) if fetch == "one": - if hasattr(cursor, 'fetchone'): + if hasattr(cursor, "fetchone"): return cursor.fetchone() else: # AttributeError: 'PyhiveConnectionWrapper' object has no attribute 'fetchone' @@ -406,7 +384,7 @@ def run_sql_for_tests(self, sql, fetch, conn): # "trivial". Which is true, though it seems like an unreasonable cause for # failure! It also doesn't like the `from foo, bar` syntax as opposed to # `from foo cross join bar`. 
-COLUMNS_EQUAL_SQL = ''' +COLUMNS_EQUAL_SQL = """ with diff_count as ( SELECT 1 as id, @@ -433,4 +411,4 @@ def run_sql_for_tests(self, sql, fetch, conn): diff_count.num_missing as num_mismatched from row_count_diff cross join diff_count -'''.strip() +""".strip() diff --git a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py index 043cabfa0..249caf0d7 100644 --- a/dbt/adapters/spark/relation.py +++ b/dbt/adapters/spark/relation.py @@ -24,19 +24,19 @@ class SparkIncludePolicy(Policy): class SparkRelation(BaseRelation): quote_policy: SparkQuotePolicy = SparkQuotePolicy() include_policy: SparkIncludePolicy = SparkIncludePolicy() - quote_character: str = '`' + quote_character: str = "`" is_delta: Optional[bool] = None is_hudi: Optional[bool] = None - information: str = None + information: Optional[str] = None def __post_init__(self): if self.database != self.schema and self.database: - raise RuntimeException('Cannot set database in spark!') + raise RuntimeException("Cannot set database in spark!") def render(self): if self.include_policy.database and self.include_policy.schema: raise RuntimeException( - 'Got a spark relation with schema and database set to ' - 'include, but only one can be set' + "Got a spark relation with schema and database set to " + "include, but only one can be set" ) return super().render() diff --git a/dbt/adapters/spark/session.py b/dbt/adapters/spark/session.py index 6010df920..beb77d548 100644 --- a/dbt/adapters/spark/session.py +++ b/dbt/adapters/spark/session.py @@ -4,7 +4,7 @@ import datetime as dt from types import TracebackType -from typing import Any +from typing import Any, List, Optional, Tuple from dbt.events import AdapterLogger from dbt.utils import DECIMALS @@ -25,17 +25,17 @@ class Cursor: """ def __init__(self) -> None: - self._df: DataFrame | None = None - self._rows: list[Row] | None = None + self._df: Optional[DataFrame] = None + self._rows: Optional[List[Row]] = None def __enter__(self) -> Cursor: return self def __exit__( self, - exc_type: type[BaseException] | None, - exc_val: Exception | None, - exc_tb: TracebackType | None, + exc_type: Optional[BaseException], + exc_val: Optional[Exception], + exc_tb: Optional[TracebackType], ) -> bool: self.close() return True @@ -43,13 +43,13 @@ def __exit__( @property def description( self, - ) -> list[tuple[str, str, None, None, None, None, bool]]: + ) -> List[Tuple[str, str, None, None, None, None, bool]]: """ Get the description. Returns ------- - out : list[tuple[str, str, None, None, None, None, bool]] + out : List[Tuple[str, str, None, None, None, None, bool]] The description. Source @@ -109,13 +109,13 @@ def execute(self, sql: str, *parameters: Any) -> None: spark_session = SparkSession.builder.enableHiveSupport().getOrCreate() self._df = spark_session.sql(sql) - def fetchall(self) -> list[Row] | None: + def fetchall(self) -> Optional[List[Row]]: """ Fetch all data. Returns ------- - out : list[Row] | None + out : Optional[List[Row]] The rows. Source @@ -126,7 +126,7 @@ def fetchall(self) -> list[Row] | None: self._rows = self._df.collect() return self._rows - def fetchone(self) -> Row | None: + def fetchone(self) -> Optional[Row]: """ Fetch the first output. 
diff --git a/dbt/include/spark/__init__.py b/dbt/include/spark/__init__.py index 564a3d1e8..b177e5d49 100644 --- a/dbt/include/spark/__init__.py +++ b/dbt/include/spark/__init__.py @@ -1,2 +1,3 @@ import os + PACKAGE_PATH = os.path.dirname(__file__) diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index e96501c45..22381d9ea 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -231,7 +231,7 @@ {% set comment = column_dict[column_name]['description'] %} {% set escaped_comment = comment | replace('\'', '\\\'') %} {% set comment_query %} - alter table {{ relation }} change column + alter table {{ relation }} change column {{ adapter.quote(column_name) if column_dict[column_name]['quote'] else column_name }} comment '{{ escaped_comment }}'; {% endset %} @@ -260,25 +260,25 @@ {% macro spark__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %} - + {% if remove_columns %} {% set platform_name = 'Delta Lake' if relation.is_delta else 'Apache Spark' %} {{ exceptions.raise_compiler_error(platform_name + ' does not support dropping columns from tables') }} {% endif %} - + {% if add_columns is none %} {% set add_columns = [] %} {% endif %} - + {% set sql -%} - + alter {{ relation.type }} {{ relation }} - + {% if add_columns %} add columns {% endif %} {% for column in add_columns %} {{ column.name }} {{ column.data_type }}{{ ',' if not loop.last }} {% endfor %} - + {%- endset -%} {% do run_query(sql) %} diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index d0b6e89ba..8d8e69d93 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -1,17 +1,17 @@ {% materialization incremental, adapter='spark' -%} - + {#-- Validate early so we don't run SQL if the file_format + strategy combo is invalid --#} {%- set raw_file_format = config.get('file_format', default='parquet') -%} {%- set raw_strategy = config.get('incremental_strategy', default='append') -%} - + {%- set file_format = dbt_spark_validate_get_file_format(raw_file_format) -%} {%- set strategy = dbt_spark_validate_get_incremental_strategy(raw_strategy, file_format) -%} - + {%- set unique_key = config.get('unique_key', none) -%} {%- set partition_by = config.get('partition_by', none) -%} {%- set full_refresh_mode = (should_full_refresh()) -%} - + {% set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') %} {% set target_relation = this %} @@ -42,7 +42,7 @@ {%- endcall -%} {% do persist_docs(target_relation, model) %} - + {{ run_hooks(post_hooks) }} {{ return({'relations': [target_relation]}) }} diff --git a/dbt/include/spark/macros/materializations/incremental/strategies.sql b/dbt/include/spark/macros/materializations/incremental/strategies.sql index 215b5f3f9..28b8f2001 100644 --- a/dbt/include/spark/macros/materializations/incremental/strategies.sql +++ b/dbt/include/spark/macros/materializations/incremental/strategies.sql @@ -1,5 +1,5 @@ {% macro get_insert_overwrite_sql(source_relation, target_relation) %} - + {%- set dest_columns = adapter.get_columns_in_relation(target_relation) -%} {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} insert overwrite table {{ target_relation }} @@ -41,20 +41,20 @@ {% else %} {% do predicates.append('FALSE') %} {% 
endif %} - + {{ sql_header if sql_header is not none }} - + merge into {{ target }} as DBT_INTERNAL_DEST using {{ source.include(schema=false) }} as DBT_INTERNAL_SOURCE on {{ predicates | join(' and ') }} - + when matched then update set {% if update_columns -%}{%- for column_name in update_columns %} {{ column_name }} = DBT_INTERNAL_SOURCE.{{ column_name }} {%- if not loop.last %}, {%- endif %} {%- endfor %} {%- else %} * {% endif %} - + when not matched then insert * {% endmacro %} diff --git a/dbt/include/spark/macros/materializations/incremental/validate.sql b/dbt/include/spark/macros/materializations/incremental/validate.sql index 3e9de359b..ffd56f106 100644 --- a/dbt/include/spark/macros/materializations/incremental/validate.sql +++ b/dbt/include/spark/macros/materializations/incremental/validate.sql @@ -28,13 +28,13 @@ Invalid incremental strategy provided: {{ raw_strategy }} You can only choose this strategy when file_format is set to 'delta' or 'hudi' {%- endset %} - + {% set invalid_insert_overwrite_delta_msg -%} Invalid incremental strategy provided: {{ raw_strategy }} You cannot use this strategy when file_format is set to 'delta' Use the 'append' or 'merge' strategy instead {%- endset %} - + {% set invalid_insert_overwrite_endpoint_msg -%} Invalid incremental strategy provided: {{ raw_strategy }} You cannot use this strategy when connecting via endpoint diff --git a/dbt/include/spark/macros/materializations/snapshot.sql b/dbt/include/spark/macros/materializations/snapshot.sql index 82d186ce2..9c891ef04 100644 --- a/dbt/include/spark/macros/materializations/snapshot.sql +++ b/dbt/include/spark/macros/materializations/snapshot.sql @@ -32,7 +32,7 @@ {% macro spark_build_snapshot_staging_table(strategy, sql, target_relation) %} {% set tmp_identifier = target_relation.identifier ~ '__dbt_tmp' %} - + {%- set tmp_relation = api.Relation.create(identifier=tmp_identifier, schema=target_relation.schema, database=none, diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 3ae2df973..2eeb806fd 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -21,7 +21,7 @@ {% call statement('main') -%} {{ create_table_as(False, target_relation, sql) }} {%- endcall %} - + {% do persist_docs(target_relation, model) %} {{ run_hooks(post_hooks) }} diff --git a/dev-requirements.txt b/dev-requirements.txt index 0f84cbd5d..b94cb8b6b 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -3,18 +3,22 @@ git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-tests-adapter&subdirectory=tests/adapter +black==22.3.0 +bumpversion +click~=8.0.4 +flake8 +flaky freezegun==0.3.9 -pytest>=6.0.2 +ipdb mock>=1.3.0 -flake8 +mypy==0.950 +pre-commit +pytest-csv +pytest-dotenv +pytest-xdist +pytest>=6.0.2 pytz -bumpversion tox>=3.2.0 -ipdb -pytest-xdist -pytest-dotenv -pytest-csv -flaky # Test requirements sasl>=0.2.1 diff --git a/scripts/build-dist.sh b/scripts/build-dist.sh index 65e6dbc97..3c3808399 100755 --- a/scripts/build-dist.sh +++ b/scripts/build-dist.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash set -eo pipefail diff --git a/setup.py b/setup.py index 12ecbacde..836aeed43 100644 --- a/setup.py +++ b/setup.py @@ -5,41 +5,39 @@ # require python 3.7 or newer if sys.version_info < (3, 7): - print('Error: dbt does not support this version of Python.') - print('Please upgrade to Python 3.7 or 
higher.') + print("Error: dbt does not support this version of Python.") + print("Please upgrade to Python 3.7 or higher.") sys.exit(1) # require version of setuptools that supports find_namespace_packages from setuptools import setup + try: from setuptools import find_namespace_packages except ImportError: # the user has a downlevel version of setuptools. - print('Error: dbt requires setuptools v40.1.0 or higher.') - print('Please upgrade setuptools with "pip install --upgrade setuptools" ' - 'and try again') + print("Error: dbt requires setuptools v40.1.0 or higher.") + print('Please upgrade setuptools with "pip install --upgrade setuptools" ' "and try again") sys.exit(1) # pull long description from README this_directory = os.path.abspath(os.path.dirname(__file__)) -with open(os.path.join(this_directory, 'README.md'), 'r', encoding='utf8') as f: +with open(os.path.join(this_directory, "README.md"), "r", encoding="utf8") as f: long_description = f.read() # get this package's version from dbt/adapters//__version__.py def _get_plugin_version_dict(): - _version_path = os.path.join( - this_directory, 'dbt', 'adapters', 'spark', '__version__.py' - ) - _semver = r'''(?P\d+)\.(?P\d+)\.(?P\d+)''' - _pre = r'''((?Pa|b|rc)(?P
\d+))?'''
-    _version_pattern = fr'''version\s*=\s*["']{_semver}{_pre}["']'''
+    _version_path = os.path.join(this_directory, "dbt", "adapters", "spark", "__version__.py")
+    _semver = r"""(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"""
+    _pre = r"""((?P<prekind>a|b|rc)(?P<pre>\d+))?"""
+    _version_pattern = fr"""version\s*=\s*["']{_semver}{_pre}["']"""
     with open(_version_path) as f:
         match = re.search(_version_pattern, f.read().strip())
         if match is None:
-            raise ValueError(f'invalid version at {_version_path}')
+            raise ValueError(f"invalid version at {_version_path}")
         return match.groupdict()
 
 
@@ -47,7 +45,7 @@ def _get_plugin_version_dict():
 def _get_dbt_core_version():
     parts = _get_plugin_version_dict()
     minor = "{major}.{minor}.0".format(**parts)
-    pre = (parts["prekind"]+"1" if parts["prekind"] else "")
+    pre = parts["prekind"] + "1" if parts["prekind"] else ""
     return f"{minor}{pre}"
 
 
@@ -56,33 +54,28 @@ def _get_dbt_core_version():
 dbt_core_version = _get_dbt_core_version()
 description = """The Apache Spark adapter plugin for dbt"""
 
-odbc_extras = ['pyodbc>=4.0.30']
+odbc_extras = ["pyodbc>=4.0.30"]
 pyhive_extras = [
-    'PyHive[hive]>=0.6.0,<0.7.0',
-    'thrift>=0.11.0,<0.16.0',
-]
-session_extras = [
-    "pyspark>=3.0.0,<4.0.0"
+    "PyHive[hive]>=0.6.0,<0.7.0",
+    "thrift>=0.11.0,<0.16.0",
 ]
+session_extras = ["pyspark>=3.0.0,<4.0.0"]
 all_extras = odbc_extras + pyhive_extras + session_extras
 
 setup(
     name=package_name,
     version=package_version,
-
     description=description,
     long_description=long_description,
-    long_description_content_type='text/markdown',
-
-    author='dbt Labs',
-    author_email='info@dbtlabs.com',
-    url='https://github.com/dbt-labs/dbt-spark',
-
-    packages=find_namespace_packages(include=['dbt', 'dbt.*']),
+    long_description_content_type="text/markdown",
+    author="dbt Labs",
+    author_email="info@dbtlabs.com",
+    url="https://github.com/dbt-labs/dbt-spark",
+    packages=find_namespace_packages(include=["dbt", "dbt.*"]),
     include_package_data=True,
     install_requires=[
-        'dbt-core~={}'.format(dbt_core_version),
-        'sqlparams>=3.0.0',
+        "dbt-core~={}".format(dbt_core_version),
+        "sqlparams>=3.0.0",
     ],
     extras_require={
         "ODBC": odbc_extras,
@@ -92,17 +85,14 @@ def _get_dbt_core_version():
     },
     zip_safe=False,
     classifiers=[
-        'Development Status :: 5 - Production/Stable',
-        
-        'License :: OSI Approved :: Apache Software License',
-        
-        'Operating System :: Microsoft :: Windows',
-        'Operating System :: MacOS :: MacOS X',
-        'Operating System :: POSIX :: Linux',
-
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3.9',
+        "Development Status :: 5 - Production/Stable",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: Microsoft :: Windows",
+        "Operating System :: MacOS :: MacOS X",
+        "Operating System :: POSIX :: Linux",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
     ],
     python_requires=">=3.7",
 )
diff --git a/tox.ini b/tox.ini
index 59b931dca..a75e2a26a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,14 +2,6 @@
 skipsdist = True
 envlist = unit, flake8, integration-spark-thrift
 
-
-[testenv:flake8]
-basepython = python3.8
-commands = /bin/bash -c '$(which flake8) --max-line-length 99 --select=E,W,F --ignore=W504 dbt/'
-passenv = DBT_* PYTEST_ADDOPTS
-deps =
-     -r{toxinidir}/dev-requirements.txt
-
 [testenv:unit]
 basepython = python3.8
 commands = /bin/bash -c '{envpython} -m pytest -v {posargs} tests/unit'