diff --git a/CHANGELOG.md b/CHANGELOG.md index cc21d4f7..af418d24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - Removed pins for pandas and pydantic to ease user burdens ([874](https://github.com/databricks/dbt-databricks/pull/874)) - Add more relation types to make codegen happy ([875](https://github.com/databricks/dbt-databricks/pull/875)) +- add UP ruleset ([865](https://github.com/databricks/dbt-databricks/pull/865)) ## dbt-databricks 1.9.0 (December 9, 2024) diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py index cfabb235..57bc4ca3 100644 --- a/dbt/adapters/databricks/api_client.py +++ b/dbt/adapters/databricks/api_client.py @@ -246,7 +246,7 @@ def _poll_api( @dataclass(frozen=True, eq=True, unsafe_hash=True) -class CommandExecution(object): +class CommandExecution: command_id: str context_id: str cluster_id: str diff --git a/dbt/adapters/databricks/column.py b/dbt/adapters/databricks/column.py index df2cdb2d..373d6bd7 100644 --- a/dbt/adapters/databricks/column.py +++ b/dbt/adapters/databricks/column.py @@ -28,7 +28,7 @@ def data_type(self) -> str: return self.translate_type(self.dtype) def __repr__(self) -> str: - return "<DatabricksColumn {} ({})>".format(self.name, self.data_type) + return f"<DatabricksColumn {self.name} ({self.data_type})>" @staticmethod def get_name(column: dict[str, Any]) -> str: diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 33ce5760..509686d7 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -5,13 +5,13 @@ import time import uuid import warnings -from collections.abc import Callable, Iterator, Sequence +from collections.abc import Callable, Hashable, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass from multiprocessing.context import SpawnContext from numbers import Number from threading import get_ident -from typing import TYPE_CHECKING, Any, Hashable, Optional, cast +from typing import TYPE_CHECKING, Any, 
Optional, cast from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event @@ -488,7 +488,7 @@ def add_query( try: log_sql = redact_credentials(sql) if abridge_sql_log: - log_sql = "{}...".format(log_sql[:512]) + log_sql = f"{log_sql[:512]}..." fire_event( SQLQuery( diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index 0f729a5b..7a318cad 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -136,20 +136,16 @@ def __post_init__(self) -> None: def validate_creds(self) -> None: for key in ["host", "http_path"]: if not getattr(self, key): - raise DbtConfigError( - "The config '{}' is required to connect to Databricks".format(key) - ) + raise DbtConfigError(f"The config '{key}' is required to connect to Databricks") if not self.token and self.auth_type != "oauth": raise DbtConfigError( - ("The config `auth_type: oauth` is required when not using access token") + "The config `auth_type: oauth` is required when not using access token" ) if not self.client_id and self.client_secret: raise DbtConfigError( - ( - "The config 'client_id' is required to connect " - "to Databricks when 'client_secret' is present" - ) + "The config 'client_id' is required to connect " + "to Databricks when 'client_secret' is present" ) @classmethod diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index f68664e8..dce432c9 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -539,7 +539,7 @@ def _get_catalog_for_relation_map( used_schemas: frozenset[tuple[str, str]], ) -> tuple["Table", list[Exception]]: with executor(self.config) as tpe: - futures: list[Future["Table"]] = [] + futures: list[Future[Table]] = [] for schema, relations in relation_map.items(): if schema in used_schemas: identifier = get_identifier_list_string(relations) @@ -804,7 +804,7 @@ def get_from_relation( ) -> 
DatabricksRelationConfig: """Get the relation config from the relation.""" - relation_config = super(DeltaLiveTableAPIBase, cls).get_from_relation(adapter, relation) + relation_config = super().get_from_relation(adapter, relation) # Ensure any current refreshes are completed before returning the relation config tblproperties = cast(TblPropertiesConfig, relation_config.config["tblproperties"]) diff --git a/dbt/adapters/databricks/python_models/run_tracking.py b/dbt/adapters/databricks/python_models/run_tracking.py index e48dae7d..4b4fea41 100644 --- a/dbt/adapters/databricks/python_models/run_tracking.py +++ b/dbt/adapters/databricks/python_models/run_tracking.py @@ -6,7 +6,7 @@ from dbt.adapters.databricks.logging import logger -class PythonRunTracker(object): +class PythonRunTracker: _run_ids: set[str] = set() _commands: set[CommandExecution] = set() _lock = threading.Lock() diff --git a/dbt/adapters/databricks/relation.py b/dbt/adapters/databricks/relation.py index f69f02f5..cfac5125 100644 --- a/dbt/adapters/databricks/relation.py +++ b/dbt/adapters/databricks/relation.py @@ -1,6 +1,6 @@ from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Optional, Type +from typing import Any, Optional, Type # noqa from dbt_common.dataclass_schema import StrEnum from dbt_common.exceptions import DbtRuntimeError @@ -133,7 +133,7 @@ def matches( return match @classproperty - def get_relation_type(cls) -> Type[DatabricksRelationType]: + def get_relation_type(cls) -> Type[DatabricksRelationType]: # noqa return DatabricksRelationType def information_schema(self, view_name: Optional[str] = None) -> InformationSchema: diff --git a/pyproject.toml b/pyproject.toml index 13a95fd7..d2f728d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dependencies = [ "ruff", "types-requests", "debugpy", + "pydantic>=1.10.0, <2", ] path = ".hatch" python = "3.9" @@ -101,7 +102,7 @@ line-length = 100 target-version = 'py39' 
[tool.ruff.lint] -select = ["E", "W", "F", "I"] +select = ["E", "W", "F", "I", "UP"] ignore = ["E203"] [tool.pytest.ini_options] diff --git a/tests/functional/adapter/ephemeral/test_ephemeral.py b/tests/functional/adapter/ephemeral/test_ephemeral.py index c00585b8..52efdc0f 100644 --- a/tests/functional/adapter/ephemeral/test_ephemeral.py +++ b/tests/functional/adapter/ephemeral/test_ephemeral.py @@ -33,7 +33,7 @@ def test_ephemeral_nested(self, project): results = util.run_dbt(["run"]) assert len(results) == 2 assert os.path.exists("./target/run/test/models/root_view.sql") - with open("./target/run/test/models/root_view.sql", "r") as fp: + with open("./target/run/test/models/root_view.sql") as fp: sql_file = fp.read() sql_file = re.sub(r"\d+", "", sql_file) diff --git a/tests/functional/adapter/hooks/test_model_hooks.py b/tests/functional/adapter/hooks/test_model_hooks.py index 9a3ba61e..bf1a4e6c 100644 --- a/tests/functional/adapter/hooks/test_model_hooks.py +++ b/tests/functional/adapter/hooks/test_model_hooks.py @@ -49,7 +49,7 @@ def get_ctx_vars(self, state, count, project): "invocation_id", "thread_id", ] - field_list = ", ".join(["{}".format(f) for f in fields]) + field_list = ", ".join([f"{f}" for f in fields]) query = ( f"select {field_list} from {project.test_schema}.on_model_hook" f" where test_state = '{state}'" diff --git a/tests/functional/adapter/hooks/test_run_hooks.py b/tests/functional/adapter/hooks/test_run_hooks.py index 5c7dd5c2..1f133d86 100644 --- a/tests/functional/adapter/hooks/test_run_hooks.py +++ b/tests/functional/adapter/hooks/test_run_hooks.py @@ -65,7 +65,7 @@ def get_ctx_vars(self, state, project): "invocation_id", "thread_id", ] - field_list = ", ".join(["{}".format(f) for f in fields]) + field_list = ", ".join([f"{f}" for f in fields]) query = ( f"select {field_list} from {project.test_schema}.on_run_hook where test_state = " f"'{state}'" diff --git a/tests/functional/adapter/python_model/test_python_model.py 
b/tests/functional/adapter/python_model/test_python_model.py index 858214b7..726791df 100644 --- a/tests/functional/adapter/python_model/test_python_model.py +++ b/tests/functional/adapter/python_model/test_python_model.py @@ -58,7 +58,7 @@ def test_changing_schema_with_log_validation(self, project, logs_dir): ) util.run_dbt(["run"]) log_file = os.path.join(logs_dir, "dbt.log") - with open(log_file, "r") as f: + with open(log_file) as f: log = f.read() # validate #5510 log_code_execution works assert "On model.test.simple_python_model:" in log diff --git a/tests/unit/api_client/test_workspace_api.py b/tests/unit/api_client/test_workspace_api.py index 322a9172..208e3324 100644 --- a/tests/unit/api_client/test_workspace_api.py +++ b/tests/unit/api_client/test_workspace_api.py @@ -36,7 +36,7 @@ def test_upload_notebook__non_200(self, api, session): def test_upload_notebook__200(self, api, session, host): session.post.return_value.status_code = 200 - encoded = base64.b64encode("code".encode()).decode() + encoded = base64.b64encode(b"code").decode() api.upload_notebook("path", "code") session.post.assert_called_once_with( f"https://{host}/api/2.0/workspace/import", diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index f76ed182..424cd39f 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -165,7 +165,7 @@ def get_password(self, servicename, username): if not os.path.exists(file_path): return None - with open(file_path, "r") as file: + with open(file_path) as file: password = file.read() return password