Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More cleanup work for forward compatibility #865

Merged
merged 6 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

- Removed pins for pandas and pydantic to ease user burdens ([874](https://github.com/databricks/dbt-databricks/pull/874))
- Add more relation types to make codegen happy ([875](https://github.com/databricks/dbt-databricks/pull/875))
- add UP ruleset ([865](https://github.com/databricks/dbt-databricks/pull/865))

## dbt-databricks 1.9.0 (December 9, 2024)

Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _poll_api(


@dataclass(frozen=True, eq=True, unsafe_hash=True)
class CommandExecution(object):
class CommandExecution:
command_id: str
context_id: str
cluster_id: str
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def data_type(self) -> str:
return self.translate_type(self.dtype)

def __repr__(self) -> str:
return "<DatabricksColumn {} ({})>".format(self.name, self.data_type)
return f"<DatabricksColumn {self.name} ({self.data_type})>"

@staticmethod
def get_name(column: dict[str, Any]) -> str:
Expand Down
6 changes: 3 additions & 3 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import time
import uuid
import warnings
from collections.abc import Callable, Iterator, Sequence
from collections.abc import Callable, Hashable, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from multiprocessing.context import SpawnContext
from numbers import Number
from threading import get_ident
from typing import TYPE_CHECKING, Any, Hashable, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from dbt_common.events.contextvars import get_node_info
from dbt_common.events.functions import fire_event
Expand Down Expand Up @@ -488,7 +488,7 @@ def add_query(
try:
log_sql = redact_credentials(sql)
if abridge_sql_log:
log_sql = "{}...".format(log_sql[:512])
log_sql = f"{log_sql[:512]}..."

fire_event(
SQLQuery(
Expand Down
12 changes: 4 additions & 8 deletions dbt/adapters/databricks/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,16 @@ def __post_init__(self) -> None:
def validate_creds(self) -> None:
for key in ["host", "http_path"]:
if not getattr(self, key):
raise DbtConfigError(
"The config '{}' is required to connect to Databricks".format(key)
)
raise DbtConfigError(f"The config '{key}' is required to connect to Databricks")
if not self.token and self.auth_type != "oauth":
raise DbtConfigError(
("The config `auth_type: oauth` is required when not using access token")
"The config `auth_type: oauth` is required when not using access token"
)

if not self.client_id and self.client_secret:
raise DbtConfigError(
(
"The config 'client_id' is required to connect "
"to Databricks when 'client_secret' is present"
)
"The config 'client_id' is required to connect "
"to Databricks when 'client_secret' is present"
)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def _get_catalog_for_relation_map(
used_schemas: frozenset[tuple[str, str]],
) -> tuple["Table", list[Exception]]:
with executor(self.config) as tpe:
futures: list[Future["Table"]] = []
futures: list[Future[Table]] = []
for schema, relations in relation_map.items():
if schema in used_schemas:
identifier = get_identifier_list_string(relations)
Expand Down Expand Up @@ -804,7 +804,7 @@ def get_from_relation(
) -> DatabricksRelationConfig:
"""Get the relation config from the relation."""

relation_config = super(DeltaLiveTableAPIBase, cls).get_from_relation(adapter, relation)
relation_config = super().get_from_relation(adapter, relation)

# Ensure any current refreshes are completed before returning the relation config
tblproperties = cast(TblPropertiesConfig, relation_config.config["tblproperties"])
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/python_models/run_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dbt.adapters.databricks.logging import logger


class PythonRunTracker(object):
class PythonRunTracker:
_run_ids: set[str] = set()
_commands: set[CommandExecution] = set()
_lock = threading.Lock()
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/databricks/relation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Optional, Type
from typing import Any, Optional, Type # noqa

from dbt_common.dataclass_schema import StrEnum
from dbt_common.exceptions import DbtRuntimeError
Expand Down Expand Up @@ -133,7 +133,7 @@ def matches(
return match

@classproperty
def get_relation_type(cls) -> Type[DatabricksRelationType]:
def get_relation_type(cls) -> Type[DatabricksRelationType]: # noqa
return DatabricksRelationType

def information_schema(self, view_name: Optional[str] = None) -> InformationSchema:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ dependencies = [
"ruff",
"types-requests",
"debugpy",
"pydantic>=1.10.0, <2",
]
path = ".hatch"
python = "3.9"
Expand All @@ -101,7 +102,7 @@ line-length = 100
target-version = 'py39'

[tool.ruff.lint]
select = ["E", "W", "F", "I"]
select = ["E", "W", "F", "I", "UP"]
ignore = ["E203"]

[tool.pytest.ini_options]
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/adapter/ephemeral/test_ephemeral.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_ephemeral_nested(self, project):
results = util.run_dbt(["run"])
assert len(results) == 2
assert os.path.exists("./target/run/test/models/root_view.sql")
with open("./target/run/test/models/root_view.sql", "r") as fp:
with open("./target/run/test/models/root_view.sql") as fp:
sql_file = fp.read()

sql_file = re.sub(r"\d+", "", sql_file)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/adapter/hooks/test_model_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_ctx_vars(self, state, count, project):
"invocation_id",
"thread_id",
]
field_list = ", ".join(["{}".format(f) for f in fields])
field_list = ", ".join([f"{f}" for f in fields])
query = (
f"select {field_list} from {project.test_schema}.on_model_hook"
f" where test_state = '{state}'"
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/adapter/hooks/test_run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_ctx_vars(self, state, project):
"invocation_id",
"thread_id",
]
field_list = ", ".join(["{}".format(f) for f in fields])
field_list = ", ".join([f"{f}" for f in fields])
query = (
f"select {field_list} from {project.test_schema}.on_run_hook where test_state = "
f"'{state}'"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_changing_schema_with_log_validation(self, project, logs_dir):
)
util.run_dbt(["run"])
log_file = os.path.join(logs_dir, "dbt.log")
with open(log_file, "r") as f:
with open(log_file) as f:
log = f.read()
# validate #5510 log_code_execution works
assert "On model.test.simple_python_model:" in log
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/api_client/test_workspace_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_upload_notebook__non_200(self, api, session):

def test_upload_notebook__200(self, api, session, host):
session.post.return_value.status_code = 200
encoded = base64.b64encode("code".encode()).decode()
encoded = base64.b64encode(b"code").decode()
api.upload_notebook("path", "code")
session.post.assert_called_once_with(
f"https://{host}/api/2.0/workspace/import",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_password(self, servicename, username):
if not os.path.exists(file_path):
return None

with open(file_path, "r") as file:
with open(file_path) as file:
password = file.read()

return password
Expand Down
Loading