diff --git a/dbt_common/context.py b/dbt_common/context.py index a46b1dd2..d1775c55 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -2,6 +2,7 @@ from typing import List, Mapping, Optional from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX +from dbt_common.record import Recorder class InvocationContext: @@ -9,7 +10,7 @@ def __init__(self, env: Mapping[str, str]): self._env = {k: v for k, v in env.items() if not k.startswith(PRIVATE_ENV_PREFIX)} self._env_secrets: Optional[List[str]] = None self._env_private = {k: v for k, v in env.items() if k.startswith(PRIVATE_ENV_PREFIX)} - self.recorder = None + self.recorder: Optional[Recorder] = None # This class will also eventually manage the invocation_id, flags, event manager, etc. @property @@ -32,7 +33,7 @@ def env_secrets(self) -> List[str]: _INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR") -def reliably_get_invocation_var() -> ContextVar: +def reliably_get_invocation_var() -> ContextVar[InvocationContext]: invocation_var: Optional[ContextVar] = next( (cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None ) diff --git a/dbt_common/dataclass_schema.py b/dbt_common/dataclass_schema.py index 0bad081f..867d5a4c 100644 --- a/dbt_common/dataclass_schema.py +++ b/dbt_common/dataclass_schema.py @@ -1,4 +1,4 @@ -from typing import ClassVar, cast, get_type_hints, List, Tuple, Dict, Any, Optional +from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple import re import jsonschema from dataclasses import fields, Field @@ -26,7 +26,7 @@ class ValidationError(jsonschema.ValidationError): class DateTimeSerialization(SerializationStrategy): - def serialize(self, value) -> str: + def serialize(self, value: datetime) -> str: out = value.isoformat() # Assume UTC if timezone is missing if value.tzinfo is None: @@ -127,7 +127,7 @@ def _get_fields(cls) -> List[Tuple[Field, str]]: # copied from hologram. Used in tests @classmethod - def _get_field_names(cls): + def _get_field_names(cls) -> List[str]: return [element[1] for element in cls._get_fields()] @@ -152,7 +152,7 @@ def validate(cls, value): # These classes must be in this order or it doesn't work class StrEnum(str, SerializableType, Enum): - def __str__(self): + def __str__(self) -> str: return self.value # https://docs.python.org/3.6/library/enum.html#using-automatic-values diff --git a/dbt_common/exceptions/base.py b/dbt_common/exceptions/base.py index db619326..d966a28d 100644 --- a/dbt_common/exceptions/base.py +++ b/dbt_common/exceptions/base.py @@ -1,5 +1,5 @@ import builtins -from typing import List, Any, Optional +from typing import Any, List, Optional import os from dbt_common.constants import SECRET_ENV_PREFIX @@ -37,7 +37,7 @@ def __init__(self, msg: str): self.msg = scrub_secrets(msg, env_secrets()) @property - def type(self): + def type(self) -> str: return "Internal" def process_stack(self): @@ -59,7 +59,7 @@ def process_stack(self): return lines - def __str__(self): + def __str__(self) -> str: if hasattr(self.msg, "split"): split_msg = self.msg.split("\n") else: diff --git a/dbt_common/helper_types.py b/dbt_common/helper_types.py index 0ca435b7..8611f39f 100644 --- a/dbt_common/helper_types.py +++ b/dbt_common/helper_types.py @@ -19,7 +19,7 @@ class NVEnum(StrEnum): novalue = "novalue" - def __eq__(self, other): + def __eq__(self, other) -> bool: return isinstance(other, NVEnum) @@ -59,7 +59,7 @@ def includes(self, item_name: str) -> bool: item_name in self.include or self.include in self.INCLUDE_ALL ) and item_name not in self.exclude - def _validate_items(self, items: List[str]): + def _validate_items(self, items: List[str]) -> None: pass diff --git a/dbt_common/record.py b/dbt_common/record.py index c204faa6..8fe068bb 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -14,8 +14,6 @@ from enum import Enum from typing import Any, Callable, Dict, List, Mapping, Optional, Type -from dbt_common.context import get_invocation_context - class Record: """An instance of this abstract Record class represents a request made by dbt @@ -295,9 +293,11 @@ def record_function_inner(func_to_record): return func_to_record @functools.wraps(func_to_record) - def record_replay_wrapper(*args, **kwargs): - recorder: Recorder = None + def record_replay_wrapper(*args, **kwargs) -> Any: + recorder: Optional[Recorder] = None try: + from dbt_common.context import get_invocation_context + recorder = get_invocation_context().recorder except LookupError: pass diff --git a/dbt_common/semver.py b/dbt_common/semver.py index 951f4e8e..fbdcefa5 100644 --- a/dbt_common/semver.py +++ b/dbt_common/semver.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import re -from typing import List +from typing import List, Iterable import dbt_common.exceptions.base from dbt_common.exceptions import VersionsNotCompatibleError @@ -74,7 +74,7 @@ def _cmp(a, b): @dataclass class VersionSpecifier(VersionSpecification): - def to_version_string(self, skip_matcher=False): + def to_version_string(self, skip_matcher: bool = False) -> str: prerelease = "" build = "" matcher = "" @@ -92,7 +92,7 @@ def to_version_string(self, skip_matcher=False): ) @classmethod - def from_version_string(cls, version_string): + def from_version_string(cls, version_string: str) -> "VersionSpecifier": match = _VERSION_REGEX.match(version_string) if not match: @@ -104,7 +104,7 @@ def from_version_string(cls, version_string): return cls.from_dict(matched) - def __str__(self): + def __str__(self) -> str: return self.to_version_string() def to_range(self) -> "VersionRange": @@ -192,32 +192,32 @@ def compare(self, other): return 0 - def __lt__(self, other): + def __lt__(self, other) -> bool: return self.compare(other) == -1 - def __gt__(self, other): + def __gt__(self, other) -> bool: return self.compare(other) == 1 - def __eq___(self, other): + def __eq___(self, other) -> bool: return self.compare(other) == 0 def __cmp___(self, other): return self.compare(other) @property - def is_unbounded(self): + def is_unbounded(self) -> bool: return False @property - def is_lower_bound(self): + def is_lower_bound(self) -> bool: return self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL] @property - def is_upper_bound(self): + def is_upper_bound(self) -> bool: return self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL] @property - def is_exact(self): + def is_exact(self) -> bool: return self.matcher == Matchers.EXACT @classmethod @@ -418,7 +418,7 @@ def reduce_versions(*args): return to_return -def versions_compatible(*args): +def versions_compatible(*args) -> bool: if len(args) == 1: return True @@ -429,7 +429,7 @@ def versions_compatible(*args): return False -def find_possible_versions(requested_range, available_versions): +def find_possible_versions(requested_range, available_versions: Iterable[str]): possible_versions = [] for version_string in available_versions: @@ -442,7 +442,9 @@ def find_possible_versions(requested_range, available_versions): return [v.to_version_string(skip_matcher=True) for v in sorted_versions] -def resolve_to_specific_version(requested_range, available_versions): +def resolve_to_specific_version( + requested_range, available_versions: Iterable[str] +) -> Optional[str]: max_version = None max_version_string = None diff --git a/dbt_common/utils/casting.py b/dbt_common/utils/casting.py index 811ea376..f366db7f 100644 --- a/dbt_common/utils/casting.py +++ b/dbt_common/utils/casting.py @@ -1,7 +1,7 @@ # This is useful for proto generated classes in particular, since # the default for protobuf for strings is the empty string, so # Optional[str] types don't work for generated Python classes. -from typing import Optional +from typing import Any, Dict, Optional def cast_to_str(string: Optional[str]) -> str: @@ -18,8 +18,8 @@ def cast_to_int(integer: Optional[int]) -> int: return integer -def cast_dict_to_dict_of_strings(dct): - new_dct = {} +def cast_dict_to_dict_of_strings(dct: Dict[Any, Any]) -> Dict[str, str]: + new_dct: Dict[str, str] = {} for k, v in dct.items(): new_dct[str(k)] = str(v) return new_dct diff --git a/dbt_common/utils/executor.py b/dbt_common/utils/executor.py index 0dd8490c..529b02be 100644 --- a/dbt_common/utils/executor.py +++ b/dbt_common/utils/executor.py @@ -1,9 +1,12 @@ import concurrent.futures from contextlib import contextmanager -from contextvars import ContextVar from typing import Protocol, Optional -from dbt_common.context import get_invocation_context, reliably_get_invocation_var +from dbt_common.context import ( + get_invocation_context, + reliably_get_invocation_var, + InvocationContext, +) class ConnectingExecutor(concurrent.futures.Executor): @@ -63,7 +66,7 @@ class HasThreadingConfig(Protocol): threads: Optional[int] -def _thread_initializer(invocation_context: ContextVar) -> None: +def _thread_initializer(invocation_context: InvocationContext) -> None: invocation_var = reliably_get_invocation_var() invocation_var.set(invocation_context) diff --git a/dbt_common/utils/jinja.py b/dbt_common/utils/jinja.py index 36464cbe..260ccb6a 100644 --- a/dbt_common/utils/jinja.py +++ b/dbt_common/utils/jinja.py @@ -5,19 +5,21 @@ DOCS_PREFIX = "dbt_docs__" -def get_dbt_macro_name(name): +def get_dbt_macro_name(name) -> str: if name is None: raise DbtInternalError("Got None for a macro name!") return f"{MACRO_PREFIX}{name}" -def get_dbt_docs_name(name): +def get_dbt_docs_name(name) -> str: if name is None: raise DbtInternalError("Got None for a doc name!") return f"{DOCS_PREFIX}{name}" -def get_materialization_macro_name(materialization_name, adapter_type=None, with_prefix=True): +def get_materialization_macro_name( + materialization_name, adapter_type=None, with_prefix=True +) -> str: if adapter_type is None: adapter_type = "default" name = f"materialization_{materialization_name}_{adapter_type}" diff --git a/tests/unit/test_agate_helper.py b/tests/unit/test_agate_helper.py index 4c12bcd8..fff0d4c6 100644 --- a/tests/unit/test_agate_helper.py +++ b/tests/unit/test_agate_helper.py @@ -46,13 +46,13 @@ class TestAgateHelper(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = mkdtemp() - def tearDown(self): + def tearDown(self) -> None: rmtree(self.tempdir) - def test_from_csv(self): + def test_from_csv(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_DATA.encode("utf-8")) @@ -61,7 +61,7 @@ def test_from_csv(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_bom_from_csv(self): + def test_bom_from_csv(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_BOM_DATA.encode("utf-8")) @@ -70,7 +70,7 @@ def test_bom_from_csv(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_from_csv_all_reserved(self): + def test_from_csv_all_reserved(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_DATA.encode("utf-8")) @@ -79,7 +79,7 @@ def test_from_csv_all_reserved(self): for expected, row in zip(EXPECTED_STRINGS, tbl): self.assertEqual(list(row), expected) - def test_from_data(self): + def test_from_data(self) -> None: column_names = ["a", "b", "c", "d", "e", "f", "g"] data = [ { @@ -106,7 +106,7 @@ def test_from_data(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_datetime_formats(self): + def test_datetime_formats(self) -> None: path = os.path.join(self.tempdir, "input.csv") datetimes = [ "20180806T11:33:29.000Z", @@ -120,7 +120,7 @@ def test_datetime_formats(self): tbl = agate_helper.from_csv(path, ()) self.assertEqual(tbl[0][0], expected) - def test_merge_allnull(self): + def test_merge_allnull(self) -> None: t1 = agate_helper.table_from_rows([(1, "a", None), (2, "b", None)], ("a", "b", "c")) t2 = agate_helper.table_from_rows([(3, "c", None), (4, "d", None)], ("a", "b", "c")) result = agate_helper.merge_tables([t1, t2]) @@ -130,7 +130,7 @@ def test_merge_allnull(self): assert isinstance(result.column_types[2], agate_helper.Integer) self.assertEqual(len(result), 4) - def test_merge_mixed(self): + def test_merge_mixed(self) -> None: t1 = agate_helper.table_from_rows( [(1, "a", None, None), (2, "b", None, None)], ("a", "b", "c", "d") ) @@ -181,7 +181,7 @@ def test_merge_mixed(self): assert isinstance(result.column_types[3], agate.data_types.Number) self.assertEqual(len(result), 6) - def test_nocast_string_types(self): + def test_nocast_string_types(self) -> None: # String fields should not be coerced into a representative type # See: https://github.com/dbt-labs/dbt-core/issues/2984 @@ -202,7 +202,7 @@ def test_nocast_string_types(self): for i, row in enumerate(tbl): self.assertEqual(list(row), expected[i]) - def test_nocast_bool_01(self): + def test_nocast_bool_01(self) -> None: # True and False values should not be cast to 1 and 0, and vice versa # See: https://github.com/dbt-labs/dbt-core/issues/4511 diff --git a/tests/unit/test_connection_retries.py b/tests/unit/test_connection_retries.py index 817af7a2..44fc72f5 100644 --- a/tests/unit/test_connection_retries.py +++ b/tests/unit/test_connection_retries.py @@ -19,20 +19,23 @@ def test_no_retry(self): assert result == expected -def no_success_fn(): +def no_success_fn() -> str: raise RequestException("You'll never pass") return "failure" class TestMaxRetries: - def test_no_retry(self): + def test_no_retry(self) -> None: fn_to_retry = functools.partial(no_success_fn) with pytest.raises(ConnectionError): connection_exception_retry(fn_to_retry, 3) -def single_retry_fn(): +counter = 0 + + +def single_retry_fn() -> str: global counter if counter == 0: counter += 1 @@ -45,7 +48,7 @@ def single_retry_fn(): class TestSingleRetry: - def test_no_retry(self): + def test_no_retry(self) -> None: global counter counter = 0 diff --git a/tests/unit/test_contextvars.py b/tests/unit/test_contextvars.py index 4eb58e6c..1aa9425f 100644 --- a/tests/unit/test_contextvars.py +++ b/tests/unit/test_contextvars.py @@ -1,7 +1,7 @@ from dbt_common.events.contextvars import log_contextvars, get_node_info, set_log_contextvars -def test_contextvars(): +def test_contextvars() -> None: node_info = { "unique_id": "model.test.my_model", "started_at": None, diff --git a/tests/unit/test_contracts_util.py b/tests/unit/test_contracts_util.py index 2a620370..d2fc4493 100644 --- a/tests/unit/test_contracts_util.py +++ b/tests/unit/test_contracts_util.py @@ -13,7 +13,7 @@ class ExampleMergableClass(Mergeable): class TestMergableClass(unittest.TestCase): - def test_mergeability(self): + def test_mergeability(self) -> None: mergeable1 = ExampleMergableClass( attr_a="loses", attr_b=None, attr_c=["I'll", "still", "exist"] ) diff --git a/tests/unit/test_core_dbt_utils.py b/tests/unit/test_core_dbt_utils.py index 8a0e836e..7419cd8d 100644 --- a/tests/unit/test_core_dbt_utils.py +++ b/tests/unit/test_core_dbt_utils.py @@ -7,30 +7,30 @@ class TestCommonDbtUtils(unittest.TestCase): - def test_connection_exception_retry_none(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add(self), 5) + def test_connection_exception_retry_none(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add(), 5) self.assertEqual(1, counter) - def test_connection_exception_retry_success_requests_exception(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_requests_exception(self), 5) + def test_connection_exception_retry_success_requests_exception(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_requests_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned None, plus 1 retry - def test_connection_exception_retry_max(self): - Counter._reset(self) + def test_connection_exception_retry_max(self) -> None: + Counter._reset() with self.assertRaises(ConnectionError): - connection_exception_retry(lambda: Counter._add_with_exception(self), 5) + connection_exception_retry(lambda: Counter._add_with_exception(), 5) self.assertEqual(6, counter) # 6 = original attempt plus 5 retries - def test_connection_exception_retry_success_failed_untar(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_untar_exception(self), 5) + def test_connection_exception_retry_success_failed_untar(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_untar_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned ReadError, plus 1 retry - def test_connection_exception_retry_success_failed_eofexception(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_eof_exception(self), 5) + def test_connection_exception_retry_success_failed_eofexception(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_eof_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned EOFError, plus 1 retry @@ -38,36 +38,42 @@ def test_connection_exception_retry_success_failed_eofexception(self): class Counter: - def _add(self): + @classmethod + def _add(cls) -> None: global counter counter += 1 # All exceptions that Requests explicitly raises inherit from # requests.exceptions.RequestException so we want to make sure that raises plus one exception # that inherit from it for sanity - def _add_with_requests_exception(self): + @classmethod + def _add_with_requests_exception(cls) -> None: global counter counter += 1 if counter < 2: raise requests.exceptions.RequestException - def _add_with_exception(self): + @classmethod + def _add_with_exception(cls) -> None: global counter counter += 1 raise requests.exceptions.ConnectionError - def _add_with_untar_exception(self): + @classmethod + def _add_with_untar_exception(cls) -> None: global counter counter += 1 if counter < 2: raise tarfile.ReadError - def _add_with_eof_exception(self): + @classmethod + def _add_with_eof_exception(cls) -> None: global counter counter += 1 if counter < 2: raise EOFError - def _reset(self): + @classmethod + def _reset(cls) -> None: global counter counter = 0 diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py index 791263f3..54f735e3 100644 --- a/tests/unit/test_diff.py +++ b/tests/unit/test_diff.py @@ -1,4 +1,6 @@ import json +from typing import Any, Dict + import pytest from dbt_common.record import Diff @@ -191,7 +193,7 @@ def open_mock(file, *args, **kwargs): return open_mock -def test_calculate_diff_no_diff(monkeypatch): +def test_calculate_diff_no_diff(monkeypatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ @@ -251,11 +253,11 @@ def test_calculate_diff_no_diff(monkeypatch): previous_recording_path=previous_recording_path, ) result = diff_instance.calculate_diff() - expected_result = {"GetEnvRecord": {}, "DefaultKey": {}} + expected_result: Dict[str, Any] = {"GetEnvRecord": {}, "DefaultKey": {}} assert result == expected_result -def test_calculate_diff_with_diff(monkeypatch): +def test_calculate_diff_with_diff(monkeypatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ diff --git a/tests/unit/test_event_handler.py b/tests/unit/test_event_handler.py index 80d5ae2b..f38938b6 100644 --- a/tests/unit/test_event_handler.py +++ b/tests/unit/test_event_handler.py @@ -5,7 +5,7 @@ from dbt_common.events.event_manager import TestEventManager -def test_event_logging_handler_emits_records_correctly(): +def test_event_logging_handler_emits_records_correctly() -> None: event_manager = TestEventManager() handler = DbtEventLoggingHandler(event_manager=event_manager, level=logging.DEBUG) log = logging.getLogger("test") @@ -27,7 +27,7 @@ def test_event_logging_handler_emits_records_correctly(): assert event_manager.event_history[5][1] == EventLevel.ERROR -def test_set_package_logging_sets_level_correctly(): +def test_set_package_logging_sets_level_correctly() -> None: event_manager = TestEventManager() log = logging.getLogger("test") set_package_logging("test", logging.DEBUG, event_manager) diff --git a/tests/unit/test_helper_types.py b/tests/unit/test_helper_types.py index 1a9519de..ba98803c 100644 --- a/tests/unit/test_helper_types.py +++ b/tests/unit/test_helper_types.py @@ -1,11 +1,12 @@ import pytest +from typing import List, Union from dbt_common.helper_types import IncludeExclude, WarnErrorOptions from dbt_common.dataclass_schema import ValidationError class TestIncludeExclude: - def test_init_invalid(self): + def test_init_invalid(self) -> None: with pytest.raises(ValidationError): IncludeExclude(include="invalid") @@ -22,14 +23,16 @@ def test_init_invalid(self): (["ItemA", "ItemB"], [], True), ], ) - def test_includes(self, include, exclude, expected_includes): + def test_includes( + self, include: Union[str, List[str]], exclude: List[str], expected_includes: bool + ) -> None: include_exclude = IncludeExclude(include=include, exclude=exclude) assert include_exclude.includes("ItemA") == expected_includes class TestWarnErrorOptions: - def test_init_invalid_error(self): + def test_init_invalid_error(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"], valid_error_names=set(["ValidError"])) @@ -38,14 +41,14 @@ def test_init_invalid_error(self): include="*", exclude=["InvalidError"], valid_error_names=set(["ValidError"]) ) - def test_init_invalid_error_default_valid_error_names(self): + def test_init_invalid_error_default_valid_error_names(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"]) with pytest.raises(ValidationError): WarnErrorOptions(include="*", exclude=["InvalidError"]) - def test_init_valid_error(self): + def test_init_valid_error(self) -> None: warn_error_options = WarnErrorOptions( include=["ValidError"], valid_error_names=set(["ValidError"]) ) @@ -58,18 +61,18 @@ def test_init_valid_error(self): assert warn_error_options.include == "*" assert warn_error_options.exclude == ["ValidError"] - def test_init_default_silence(self): + def test_init_default_silence(self) -> None: my_options = WarnErrorOptions(include="*") assert my_options.silence == [] - def test_init_invalid_silence_event(self): + def test_init_invalid_silence_event(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include="*", silence=["InvalidError"]) - def test_init_valid_silence_event(self): + def test_init_valid_silence_event(self) -> None: all_events = ["MySilencedEvent"] my_options = WarnErrorOptions( - include="*", silence=all_events, valid_error_names=all_events + include="*", silence=all_events, valid_error_names=set(all_events) ) assert my_options.silence == all_events @@ -81,14 +84,16 @@ def test_init_valid_silence_event(self): ("*", ["ItemB"], True), ], ) - def test_includes(self, include, silence, expected_includes): + def test_includes( + self, include: Union[str, List[str]], silence: List[str], expected_includes: bool + ) -> None: include_exclude = WarnErrorOptions( include=include, silence=silence, valid_error_names={"ItemA", "ItemB"} ) assert include_exclude.includes("ItemA") == expected_includes - def test_silenced(self): + def test_silenced(self) -> None: my_options = WarnErrorOptions(include="*", silence=["ItemA"], valid_error_names={"ItemA"}) assert my_options.silenced("ItemA") assert not my_options.silenced("ItemB") diff --git a/tests/unit/test_invocation_context.py b/tests/unit/test_invocation_context.py index b6697f8e..3dc832d3 100644 --- a/tests/unit/test_invocation_context.py +++ b/tests/unit/test_invocation_context.py @@ -2,13 +2,13 @@ from dbt_common.context import InvocationContext -def test_invocation_context_env(): +def test_invocation_context_env() -> None: test_env = {"VAR_1": "value1", "VAR_2": "value2"} ic = InvocationContext(env=test_env) assert ic.env == test_env -def test_invocation_context_secrets(): +def test_invocation_context_secrets() -> None: test_env = { f"{SECRET_ENV_PREFIX}_VAR_1": "secret1", f"{SECRET_ENV_PREFIX}VAR_2": "secret2", @@ -16,10 +16,10 @@ def test_invocation_context_secrets(): f"foo{SECRET_ENV_PREFIX}": "non-secret", } ic = InvocationContext(env=test_env) - assert set(ic.env_secrets) == set(["secret1", "secret2"]) + assert set(ic.env_secrets) == {"secret1", "secret2"} -def test_invocation_context_private(): +def test_invocation_context_private() -> None: test_env = { f"{PRIVATE_ENV_PREFIX}_VAR_1": "private1", f"{PRIVATE_ENV_PREFIX}VAR_2": "private2", diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py index f038a1ec..e906a0ac 100644 --- a/tests/unit/test_jinja.py +++ b/tests/unit/test_jinja.py @@ -1,23 +1,26 @@ import unittest +from dbt_common.clients._jinja_blocks import BlockTag from dbt_common.clients.jinja import extract_toplevel_blocks from dbt_common.exceptions import CompilationError class TestBlockLexer(unittest.TestCase): - def test_basic(self): + def test_basic(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" blocks = extract_toplevel_blocks( block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_multiple(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_multiple(self) -> None: body_one = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' body_two = ( "{{ config(bar=1)}}\r\nselect * from {% if foo %} thing " @@ -37,7 +40,7 @@ def test_multiple(self): ) self.assertEqual(len(blocks), 2) - def test_comments(self): + def test_comments(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' comment = "{# my comment #}" block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" @@ -45,12 +48,14 @@ def test_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_evil_comments(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_evil_comments(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' comment = ( "{# external comment {% othertype bar %} select * from " @@ -61,12 +66,14 @@ def test_evil_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_nested_comments(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_nested_comments(self) -> None: body = ( '{# my comment #} {{ config(foo="bar") }}' "\r\nselect * from {# my other comment embedding {% endmytype %} #} this.that\r\n" @@ -80,33 +87,43 @@ def test_nested_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_complex_file(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_complex_file(self) -> None: blocks = extract_toplevel_blocks( complex_snapshot_file, allowed_blocks={"mytype", "myothertype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 3) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].full_block, "{% mytype foo %} some stuff {% endmytype %}") - self.assertEqual(blocks[0].contents, " some stuff ") - self.assertEqual(blocks[1].block_type_name, "mytype") - self.assertEqual(blocks[1].block_name, "bar") - self.assertEqual(blocks[1].full_block, bar_block) - self.assertEqual(blocks[1].contents, bar_block[16:-15].rstrip()) - self.assertEqual(blocks[2].block_type_name, "myothertype") - self.assertEqual(blocks[2].block_name, "x") - self.assertEqual(blocks[2].full_block, x_block.strip()) + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.full_block, "{% mytype foo %} some stuff {% endmytype %}") + self.assertEqual(b0.contents, " some stuff ") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "mytype") + self.assertEqual(b1.block_name, "bar") + self.assertEqual(b1.full_block, bar_block) + self.assertEqual(b1.contents, bar_block[16:-15].rstrip()) + + b2 = blocks[2] + assert isinstance(b2, BlockTag) + self.assertEqual(b2.block_type_name, "myothertype") + self.assertEqual(b2.block_name, "x") + self.assertEqual(b2.full_block, x_block.strip()) self.assertEqual( - blocks[2].contents, + b2.contents, x_block[len("\n{% myothertype x %}") : -len("{% endmyothertype %}\n")], ) - def test_peaceful_macro_coexistence(self): + def test_peaceful_macro_coexistence(self) -> None: body = ( "{# my macro #} {% macro foo(a, b) %} do a thing " "{%- endmacro %} {# my model #} {% a b %} test {% enda %}" @@ -116,15 +133,22 @@ def test_peaceful_macro_coexistence(self): ) self.assertEqual(len(blocks), 4) self.assertEqual(blocks[0].full_block, "{# my macro #} ") - self.assertEqual(blocks[1].block_type_name, "macro") - self.assertEqual(blocks[1].block_name, "foo") - self.assertEqual(blocks[1].contents, " do a thing") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "macro") + self.assertEqual(b1.block_name, "foo") + self.assertEqual(b1.contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") - self.assertEqual(blocks[3].block_type_name, "a") - self.assertEqual(blocks[3].block_name, "b") - self.assertEqual(blocks[3].contents, " test ") - def test_macro_with_trailing_data(self): + b3 = blocks[3] + assert isinstance(b3, BlockTag) + self.assertEqual(b3.block_type_name, "a") + self.assertEqual(b3.block_name, "b") + self.assertEqual(b3.contents, " test ") + + def test_macro_with_trailing_data(self) -> None: body = ( "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} " "{# my model #} {% a b %} test {% enda %} raw data so cool" @@ -134,16 +158,24 @@ def test_macro_with_trailing_data(self): ) self.assertEqual(len(blocks), 5) self.assertEqual(blocks[0].full_block, "{# my macro #} ") - self.assertEqual(blocks[1].block_type_name, "macro") - self.assertEqual(blocks[1].block_name, "foo") - self.assertEqual(blocks[1].contents, " do a thing") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "macro") + self.assertEqual(b1.block_name, "foo") + self.assertEqual(b1.contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") - self.assertEqual(blocks[3].block_type_name, "a") - self.assertEqual(blocks[3].block_name, "b") - self.assertEqual(blocks[3].contents, " test ") + + b3 = blocks[3] + assert isinstance(b3, BlockTag) + self.assertEqual(b3.block_type_name, "a") + self.assertEqual(b3.block_name, "b") + self.assertEqual(b3.contents, " test ") + self.assertEqual(blocks[4].full_block, " raw data so cool") - def test_macro_with_crazy_args(self): + def test_macro_with_crazy_args(self) -> None: body = ( """{% macro foo(a, b=asdf("cool this is 'embedded'" * 3) + external_var, c)%}""" "cool{# block comment with {% endmacro %} in it #} stuff here " @@ -151,38 +183,44 @@ def test_macro_with_crazy_args(self): ) blocks = extract_toplevel_blocks(body, allowed_blocks={"macro"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "macro") - self.assertEqual(blocks[0].block_name, "foo") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "macro") + self.assertEqual(b0.block_name, "foo") self.assertEqual( blocks[0].contents, "cool{# block comment with {% endmacro %} in it #} stuff here " ) - def test_materialization_parse(self): + def test_materialization_parse(self) -> None: body = "{% materialization xxx, default %} ... {% endmaterialization %}" blocks = extract_toplevel_blocks( body, allowed_blocks={"materialization"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "materialization") - self.assertEqual(blocks[0].block_name, "xxx") - self.assertEqual(blocks[0].full_block, body) + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "materialization") + self.assertEqual(b0.block_name, "xxx") + self.assertEqual(b0.full_block, body) body = '{% materialization xxx, adapter="other" %} ... {% endmaterialization %}' blocks = extract_toplevel_blocks( body, allowed_blocks={"materialization"}, collect_raw_data=False ) + b0 = blocks[0] + assert isinstance(b0, BlockTag) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "materialization") - self.assertEqual(blocks[0].block_name, "xxx") - self.assertEqual(blocks[0].full_block, body) + self.assertEqual(b0.block_type_name, "materialization") + self.assertEqual(b0.block_name, "xxx") + self.assertEqual(b0.full_block, body) - def test_nested_not_ok(self): + def test_nested_not_ok(self) -> None: # we don't allow nesting same blocks body = "{% myblock a %} {% myblock b %} {% endmyblock %} {% endmyblock %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock"}) - def test_incomplete_block_failure(self): + def test_incomplete_block_failure(self) -> None: fullbody = "{% myblock foo %} {% endmyblock %}" for length in range(len("{% myblock foo %}"), len(fullbody) - 1): body = fullbody[:length] @@ -194,45 +232,45 @@ def test_wrong_end_failure(self): with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"}) - def test_comment_no_end_failure(self): + def test_comment_no_end_failure(self) -> None: body = "{# " with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_comment_only(self): + def test_comment_only(self) -> None: body = "{# myblock #}" blocks = extract_toplevel_blocks(body) self.assertEqual(len(blocks), 1) blocks = extract_toplevel_blocks(body, collect_raw_data=False) self.assertEqual(len(blocks), 0) - def test_comment_block_self_closing(self): + def test_comment_block_self_closing(self) -> None: # test the case where a comment start looks a lot like it closes itself # (but it doesn't in jinja!) body = "{#} {% myblock foo %} {#}" blocks = extract_toplevel_blocks(body, collect_raw_data=False) self.assertEqual(len(blocks), 0) - def test_embedded_self_closing_comment_block(self): + def test_embedded_self_closing_comment_block(self) -> None: body = "{% myblock foo %} {#}{% endmyblock %} {#}{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) self.assertEqual(blocks[0].contents, " {#}{% endmyblock %} {#}") - def test_set_statement(self): + def test_set_statement(self) -> None: body = "{% set x = 1 %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_set_block(self): + def test_set_block(self) -> None: body = "{% set x %}1{% endset %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_crazy_set_statement(self): + def test_crazy_set_statement(self) -> None: body = ( '{% set x = (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}' '{% set y = otherthing("{% myblock foo %}") %}' @@ -244,19 +282,19 @@ def test_crazy_set_statement(self): self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}") self.assertEqual(blocks[0].block_type_name, "otherblock") - def test_do_statement(self): + def test_do_statement(self) -> None: body = "{% do thing.update() %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_deceptive_do_statement(self): + def test_deceptive_do_statement(self) -> None: body = "{% do thing %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_do_block(self): + def test_do_block(self) -> None: body = "{% do %}thing.update(){% enddo %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks( body, allowed_blocks={"do", "myblock"}, collect_raw_data=False @@ -266,7 +304,7 @@ def test_do_block(self): self.assertEqual(blocks[0].block_type_name, "do") self.assertEqual(blocks[1].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_crazy_do_statement(self): + def test_crazy_do_statement(self) -> None: body = ( '{% do (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}' '{% do otherthing("{% myblock foo %}") %}{% myblock x %}hi{% endmyblock %}' @@ -280,7 +318,7 @@ def test_crazy_do_statement(self): self.assertEqual(blocks[1].full_block, "{% myblock x %}hi{% endmyblock %}") self.assertEqual(blocks[1].block_type_name, "myblock") - def test_awful_jinja(self): + def test_awful_jinja(self) -> None: blocks = extract_toplevel_blocks( if_you_do_this_you_are_awful, allowed_blocks={"snapshot", "materialization"}, @@ -304,63 +342,71 @@ def test_awful_jinja(self): self.assertEqual(blocks[1].block_type_name, "materialization") self.assertEqual(blocks[1].contents, "\nhi\n") - def test_quoted_endblock_within_block(self): + def test_quoted_endblock_within_block(self) -> None: body = '{% myblock something -%} {% set x = ("{% endmyblock %}") %} {% endmyblock %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].block_type_name, "myblock") self.assertEqual(blocks[0].contents, '{% set x = ("{% endmyblock %}") %} ') - def test_docs_block(self): + def test_docs_block(self) -> None: body = ( "{% docs __my_doc__ %} asdf {# nope {% enddocs %}} #} {% enddocs %}" '{% docs __my_other_doc__ %} asdf "{% enddocs %}' ) blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) self.assertEqual(len(blocks), 2) - self.assertEqual(blocks[0].block_type_name, "docs") - self.assertEqual(blocks[0].contents, " asdf {# nope {% enddocs %}} #} ") - self.assertEqual(blocks[0].block_name, "__my_doc__") - self.assertEqual(blocks[1].block_type_name, "docs") - self.assertEqual(blocks[1].contents, ' asdf "') - self.assertEqual(blocks[1].block_name, "__my_other_doc__") - - def test_docs_block_expr(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "docs") + self.assertEqual(b0.contents, " asdf {# nope {% enddocs %}} #} ") + self.assertEqual(b0.block_name, "__my_doc__") + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "docs") + self.assertEqual(b1.contents, ' asdf "') + self.assertEqual(b1.block_name, "__my_other_doc__") + + def test_docs_block_expr(self) -> None: body = '{% docs more_doc %} asdf {{ "{% enddocs %}" ~ "}}" }}{% enddocs %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "docs") - self.assertEqual(blocks[0].contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') - self.assertEqual(blocks[0].block_name, "more_doc") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "docs") + self.assertEqual(b0.contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') + self.assertEqual(b0.block_name, "more_doc") - def test_unclosed_model_quotes(self): + def test_unclosed_model_quotes(self) -> None: # test case for https://github.com/dbt-labs/dbt-core/issues/1533 body = '{% model my_model -%} select * from "something"."something_else{% endmodel %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"model"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "model") - self.assertEqual(blocks[0].contents, 'select * from "something"."something_else') - self.assertEqual(blocks[0].block_name, "my_model") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "model") + self.assertEqual(b0.contents, 'select * from "something"."something_else') + self.assertEqual(b0.block_name, "my_model") - def test_if(self): + def test_if(self) -> None: # if you conditionally define your macros/models, don't body = "{% if true %}{% macro my_macro() %} adsf {% endmacro %}{% endif %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_if_innocuous(self): + def test_if_innocuous(self) -> None: body = "{% if true %}{% something %}asdfasd{% endsomething %}{% endif %}" blocks = extract_toplevel_blocks(body) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) - def test_for(self): + def test_for(self) -> None: # no for-loops over macros. body = "{% for x in range(10) %}{% macro my_macro() %} adsf {% endmacro %}{% endfor %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_for_innocuous(self): + def test_for_innocuous(self) -> None: # no for-loops over macros. body = ( "{% for x in range(10) %}{% something my_something %} adsf " @@ -370,7 +416,7 @@ def test_for_innocuous(self): self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) - def test_endif(self): + def test_endif(self) -> None: body = "{% snapshot foo %}select * from thing{% endsnapshot%}{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) @@ -382,7 +428,7 @@ def test_endif(self): str(err.exception), ) - def test_if_endfor(self): + def test_if_endfor(self) -> None: body = "{% if x %}...{% endfor %}{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) @@ -391,7 +437,7 @@ def test_if_endfor(self): str(err.exception), ) - def test_if_endfor_newlines(self): + def test_if_endfor_newlines(self) -> None: body = "{% if x %}\n ...\n {% endfor %}\n{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) diff --git a/tests/unit/test_model_config.py b/tests/unit/test_model_config.py index 0cc1e711..57a14438 100644 --- a/tests/unit/test_model_config.py +++ b/tests/unit/test_model_config.py @@ -14,7 +14,7 @@ class ThingWithMergeBehavior(dbtClassMixin): keysappended: Dict[str, int] = field(metadata={"merge": MergeBehavior.DictKeyAppend}) -def test_merge_behavior_meta(): +def test_merge_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(MergeBehavior) == { @@ -29,15 +29,14 @@ def test_merge_behavior_meta(): assert existing == initial_existing -def test_merge_behavior_from_field(): - fields = [f[0] for f in ThingWithMergeBehavior._get_fields()] - fields = {name: f for f, name in ThingWithMergeBehavior._get_fields()} - assert set(fields) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} - assert MergeBehavior.from_field(fields["default_behavior"]) == MergeBehavior.Clobber - assert MergeBehavior.from_field(fields["appended"]) == MergeBehavior.Append - assert MergeBehavior.from_field(fields["updated"]) == MergeBehavior.Update - assert MergeBehavior.from_field(fields["clobbered"]) == MergeBehavior.Clobber - assert MergeBehavior.from_field(fields["keysappended"]) == MergeBehavior.DictKeyAppend +def test_merge_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithMergeBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} + assert MergeBehavior.from_field(fields2["default_behavior"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields2["appended"]) == MergeBehavior.Append + assert MergeBehavior.from_field(fields2["updated"]) == MergeBehavior.Update + assert MergeBehavior.from_field(fields2["clobbered"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields2["keysappended"]) == MergeBehavior.DictKeyAppend @dataclass @@ -47,7 +46,7 @@ class ThingWithShowBehavior(dbtClassMixin): shown: float = field(metadata={"show_hide": ShowBehavior.Show}) -def test_show_behavior_meta(): +def test_show_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(ShowBehavior) == {ShowBehavior.Hide, ShowBehavior.Show} @@ -57,13 +56,12 @@ def test_show_behavior_meta(): assert existing == initial_existing -def test_show_behavior_from_field(): - fields = [f[0] for f in ThingWithShowBehavior._get_fields()] - fields = {name: f for f, name in ThingWithShowBehavior._get_fields()} - assert set(fields) == {"default_behavior", "hidden", "shown"} - assert ShowBehavior.from_field(fields["default_behavior"]) == ShowBehavior.Show - assert ShowBehavior.from_field(fields["hidden"]) == ShowBehavior.Hide - assert ShowBehavior.from_field(fields["shown"]) == ShowBehavior.Show +def test_show_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithShowBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "hidden", "shown"} + assert ShowBehavior.from_field(fields2["default_behavior"]) == ShowBehavior.Show + assert ShowBehavior.from_field(fields2["hidden"]) == ShowBehavior.Hide + assert ShowBehavior.from_field(fields2["shown"]) == ShowBehavior.Show @dataclass @@ -73,7 +71,7 @@ class ThingWithCompareBehavior(dbtClassMixin): excluded: str = field(metadata={"compare": CompareBehavior.Exclude}) -def test_compare_behavior_meta(): +def test_compare_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(CompareBehavior) == {CompareBehavior.Include, CompareBehavior.Exclude} @@ -83,10 +81,9 @@ def test_compare_behavior_meta(): assert existing == initial_existing -def test_compare_behavior_from_field(): - fields = [f[0] for f in ThingWithCompareBehavior._get_fields()] - fields = {name: f for f, name in ThingWithCompareBehavior._get_fields()} - assert set(fields) == {"default_behavior", "included", "excluded"} - assert CompareBehavior.from_field(fields["default_behavior"]) == CompareBehavior.Include - assert CompareBehavior.from_field(fields["included"]) == CompareBehavior.Include - assert CompareBehavior.from_field(fields["excluded"]) == CompareBehavior.Exclude +def test_compare_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithCompareBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "included", "excluded"} + assert CompareBehavior.from_field(fields2["default_behavior"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields2["included"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields2["excluded"]) == CompareBehavior.Exclude diff --git a/tests/unit/test_proto_events.py b/tests/unit/test_proto_events.py index 32eb08ae..d21b5062 100644 --- a/tests/unit/test_proto_events.py +++ b/tests/unit/test_proto_events.py @@ -18,7 +18,7 @@ } -def test_events(): +def test_events() -> None: # M020 event event_code = "M020" event = RetryExternalCall(attempt=3, max=5) @@ -45,7 +45,7 @@ def test_events(): assert new_msg.data.attempt == msg.data.attempt -def test_extra_dict_on_event(monkeypatch): +def test_extra_dict_on_event(monkeypatch) -> None: monkeypatch.setenv("DBT_ENV_CUSTOM_ENV_env_key", "env_value") reset_metadata_vars() diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index b0371498..6e02d710 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -69,7 +69,7 @@ def setup(): os.environ["DBT_RECORDER_FILE_PATH"] = prev_fp -def test_decorator_records(setup): +def test_decorator_records(setup) -> None: os.environ["DBT_RECORDER_MODE"] = "Record" recorder = Recorder(RecorderMode.RECORD, None) set_invocation_context({}) @@ -116,7 +116,7 @@ def not_test_func(a: int, b: str, c: Optional[str] = None) -> str: assert NotTestRecord not in recorder._records_by_type -def test_decorator_replays(setup): +def test_decorator_replays(setup) -> None: os.environ["DBT_RECORDER_MODE"] = "Replay" os.environ["DBT_RECORDER_FILE_PATH"] = "record.json" recorder = Recorder(RecorderMode.REPLAY, None) diff --git a/tests/unit/test_semver.py b/tests/unit/test_semver.py index ae48e592..383d3479 100644 --- a/tests/unit/test_semver.py +++ b/tests/unit/test_semver.py @@ -1,6 +1,6 @@ import itertools import unittest -from typing import List +from typing import List, Optional from dbt_common.exceptions import VersionsNotCompatibleError from dbt_common.semver import ( @@ -23,9 +23,11 @@ def semver_regex_versioning(versions: List[str]) -> bool: return True -def create_range(start_version_string, end_version_string): - start = UnboundedVersionSpecifier() - end = UnboundedVersionSpecifier() +def create_range( + start_version_string: Optional[str], end_version_string: Optional[str] +) -> VersionRange: + start: VersionSpecifier = UnboundedVersionSpecifier() + end: VersionSpecifier = UnboundedVersionSpecifier() if start_version_string is not None: start = VersionSpecifier.from_version_string(start_version_string) @@ -37,24 +39,24 @@ def create_range(start_version_string, end_version_string): class TestSemver(unittest.TestCase): - def assertVersionSetResult(self, inputs, output_range): + def assertVersionSetResult(self, inputs, output_range) -> None: expected = create_range(*output_range) for permutation in itertools.permutations(inputs): self.assertEqual(reduce_versions(*permutation), expected) - def assertInvalidVersionSet(self, inputs): + def assertInvalidVersionSet(self, inputs) -> None: for permutation in itertools.permutations(inputs): with self.assertRaises(VersionsNotCompatibleError): reduce_versions(*permutation) - def test__versions_compatible(self): + def test__versions_compatible(self) -> None: self.assertTrue(versions_compatible("0.0.1", "0.0.1")) self.assertFalse(versions_compatible("0.0.1", "0.0.2")) self.assertTrue(versions_compatible(">0.0.1", "0.0.2")) self.assertFalse(versions_compatible("0.4.5a1", "0.4.5a2")) - def test__semver_regex_versions(self): + def test__semver_regex_versions(self) -> None: self.assertTrue( semver_regex_versioning( [ @@ -140,7 +142,7 @@ def test__semver_regex_versions(self): ) ) - def test__reduce_versions(self): + def test__reduce_versions(self) -> None: self.assertVersionSetResult(["0.0.1", "0.0.1"], ["=0.0.1", "=0.0.1"]) self.assertVersionSetResult(["0.0.1"], ["=0.0.1", "=0.0.1"]) @@ -175,7 +177,7 @@ def test__reduce_versions(self): self.assertInvalidVersionSet(["<0.0.3", ">=0.0.3"]) self.assertInvalidVersionSet(["<0.0.3", ">0.0.3"]) - def test__resolve_to_specific_version(self): + def test__resolve_to_specific_version(self) -> None: self.assertEqual( resolve_to_specific_version(create_range(">0.0.1", None), ["0.0.1", "0.0.2"]), "0.0.2" ) @@ -253,7 +255,7 @@ def test__resolve_to_specific_version(self): "0.9.1", ) - def test__filter_installable(self): + def test__filter_installable(self) -> None: installable = filter_installable( [ "1.1.0", diff --git a/tests/unit/test_system_client.py b/tests/unit/test_system_client.py index a4dcc323..d2cf27ed 100644 --- a/tests/unit/test_system_client.py +++ b/tests/unit/test_system_client.py @@ -12,39 +12,39 @@ class SystemClient(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.tmp_dir = mkdtemp() self.profiles_path = "{}/profiles.yml".format(self.tmp_dir) - def set_up_profile(self): + def set_up_profile(self) -> None: with open(self.profiles_path, "w") as f: f.write("ORIGINAL_TEXT") - def get_profile_text(self): + def get_profile_text(self) -> str: with open(self.profiles_path, "r") as f: return f.read() - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.tmp_dir) except Exception as e: # noqa: F841 pass - def test__make_file_when_exists(self): + def test__make_file_when_exists(self) -> None: self.set_up_profile() written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") self.assertFalse(written) self.assertEqual(self.get_profile_text(), "ORIGINAL_TEXT") - def test__make_file_when_not_exists(self): + def test__make_file_when_not_exists(self) -> None: written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") self.assertTrue(written) self.assertEqual(self.get_profile_text(), "NEW_TEXT") - def test__make_file_with_overwrite(self): + def test__make_file_with_overwrite(self) -> None: self.set_up_profile() written = dbt_common.clients.system.make_file( self.profiles_path, contents="NEW_TEXT", overwrite=True @@ -53,12 +53,12 @@ def test__make_file_with_overwrite(self): self.assertTrue(written) self.assertEqual(self.get_profile_text(), "NEW_TEXT") - def test__make_dir_from_str(self): + def test__make_dir_from_str(self) -> None: test_dir_str = self.tmp_dir + "/test_make_from_str/sub_dir" dbt_common.clients.system.make_directory(test_dir_str) self.assertTrue(Path(test_dir_str).is_dir()) - def test__make_dir_from_pathobj(self): + def test__make_dir_from_pathobj(self) -> None: test_dir_pathobj = Path(self.tmp_dir + "/test_make_from_pathobj/sub_dir") dbt_common.clients.system.make_directory(test_dir_pathobj) self.assertTrue(test_dir_pathobj.is_dir()) @@ -72,7 +72,7 @@ class TestRunCmd(unittest.TestCase): not_a_file = "zzzbbfasdfasdfsdaq" - def setUp(self): + def setUp(self) -> None: self.tempdir = mkdtemp() self.run_dir = os.path.join(self.tempdir, "run_dir") self.does_not_exist = os.path.join(self.tempdir, "does_not_exist") @@ -86,10 +86,10 @@ def setUp(self): with open(self.empty_file, "w") as fp: # noqa: F841 pass # "touch" - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test__executable_does_not_exist(self): + def test__executable_does_not_exist(self) -> None: with self.assertRaises(ExecutableError) as exc: dbt_common.clients.system.run_cmd(self.run_dir, [self.does_not_exist]) @@ -99,7 +99,7 @@ def test__executable_does_not_exist(self): self.assertIn("could not find", msg) self.assertIn(self.does_not_exist.lower(), msg) - def test__not_exe(self): + def test__not_exe(self) -> None: with self.assertRaises(ExecutableError) as exc: dbt_common.clients.system.run_cmd(self.run_dir, [self.empty_file]) @@ -112,14 +112,14 @@ def test__not_exe(self): self.assertIn("permissions", msg) self.assertIn(self.empty_file.lower(), msg) - def test__cwd_does_not_exist(self): + def test__cwd_does_not_exist(self) -> None: with self.assertRaises(WorkingDirectoryError) as exc: dbt_common.clients.system.run_cmd(self.does_not_exist, self.exists_cmd) msg = str(exc.exception).lower() self.assertIn("does not exist", msg) self.assertIn(self.does_not_exist.lower(), msg) - def test__cwd_not_directory(self): + def test__cwd_not_directory(self) -> None: with self.assertRaises(WorkingDirectoryError) as exc: dbt_common.clients.system.run_cmd(self.empty_file, self.exists_cmd) @@ -127,7 +127,7 @@ def test__cwd_not_directory(self): self.assertIn("not a directory", msg) self.assertIn(self.empty_file.lower(), msg) - def test__cwd_no_permissions(self): + def test__cwd_no_permissions(self) -> None: # it would be nice to add a windows test. Possible path to that is via # `psexec` (to get SYSTEM privs), use `icacls` to set permissions on # the directory for the test user. I'm pretty sure windows users can't @@ -145,18 +145,18 @@ def test__cwd_no_permissions(self): self.assertIn("permissions", msg) self.assertIn(self.run_dir.lower(), msg) - def test__ok(self): + def test__ok(self) -> None: out, err = dbt_common.clients.system.run_cmd(self.run_dir, self.exists_cmd) self.assertEqual(out.strip(), b"hello") self.assertEqual(err.strip(), b"") class TestFindMatching(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.base_dir = mkdtemp() self.tempdir = mkdtemp(dir=self.base_dir) - def test_find_matching_lowercase_file_pattern(self): + def test_find_matching_lowercase_file_pattern(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir) as named_file: file_path = os.path.dirname(named_file.name) relative_path = os.path.basename(file_path) @@ -175,7 +175,7 @@ def test_find_matching_lowercase_file_pattern(self): ] self.assertEqual(out, expected_output) - def test_find_matching_uppercase_file_pattern(self): + def test_find_matching_uppercase_file_pattern(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".SQL", dir=self.tempdir) as named_file: file_path = os.path.dirname(named_file.name) relative_path = os.path.basename(file_path) @@ -190,12 +190,12 @@ def test_find_matching_uppercase_file_pattern(self): ] self.assertEqual(out, expected_output) - def test_find_matching_file_pattern_not_found(self): + def test_find_matching_file_pattern_not_found(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".SQLT", dir=self.tempdir): out = dbt_common.clients.system.find_matching(self.tempdir, [""], "*.sql") self.assertEqual(out, []) - def test_ignore_spec(self): + def test_ignore_spec(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir): out = dbt_common.clients.system.find_matching( self.tempdir, @@ -207,7 +207,7 @@ def test_ignore_spec(self): ) self.assertEqual(out, []) - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.base_dir) except Exception as e: # noqa: F841 @@ -215,18 +215,18 @@ def tearDown(self): class TestUntarPackage(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.base_dir = mkdtemp() self.tempdir = mkdtemp(dir=self.base_dir) self.tempdest = mkdtemp(dir=self.base_dir) - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.base_dir) except Exception as e: # noqa: F841 pass - def test_untar_package_success(self): + def test_untar_package_success(self) -> None: # set up a valid tarball to test against with NamedTemporaryFile( prefix="my-package.2", suffix=".tar.gz", dir=self.tempdir, delete=False @@ -244,7 +244,7 @@ def test_untar_package_success(self): path = Path(os.path.join(self.tempdest, relative_file_a)) assert path.is_file() - def test_untar_package_failure(self): + def test_untar_package_failure(self) -> None: # create a text file then rename it as a tar (so it's invalid) with NamedTemporaryFile( prefix="a", suffix=".txt", dir=self.tempdir, delete=False @@ -259,7 +259,7 @@ def test_untar_package_failure(self): with self.assertRaises(tarfile.ReadError) as exc: # noqa: F841 dbt_common.clients.system.untar_package(tar_file_path, self.tempdest) - def test_untar_package_empty(self): + def test_untar_package_empty(self) -> None: # create a tarball with nothing in it with NamedTemporaryFile( prefix="my-empty-package.2", suffix=".tar.gz", dir=self.tempdir diff --git a/tests/unit/test_ui.py b/tests/unit/test_ui.py index 22e431d5..5b70b1d1 100644 --- a/tests/unit/test_ui.py +++ b/tests/unit/test_ui.py @@ -1,11 +1,11 @@ from dbt_common.ui import warning_tag, error_tag -def test_warning_tag(): +def test_warning_tag() -> None: tagged = warning_tag("hi") assert "WARNING" in tagged -def test_error_tag(): +def test_error_tag() -> None: tagged = error_tag("hi") assert "ERROR" in tagged diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 250c20cc..93c57046 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -5,7 +5,7 @@ class TestDeepMerge(unittest.TestCase): - def test__simple_cases(self): + def test__simple_cases(self) -> None: cases = [ {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, { @@ -27,7 +27,7 @@ def test__simple_cases(self): class TestMerge(unittest.TestCase): - def test__simple_cases(self): + def test__simple_cases(self) -> None: cases = [ {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, { @@ -49,7 +49,7 @@ def test__simple_cases(self): class TestDeepMap(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.input_value = { "foo": { "bar": "hello", @@ -74,7 +74,7 @@ def intify_all(value, _): except (TypeError, ValueError): return -1 - def test__simple_cases(self): + def test__simple_cases(self) -> None: expected = { "foo": { "bar": -1, @@ -104,7 +104,7 @@ def special_keypath(value, keypath): else: return value - def test__keypath(self): + def test__keypath(self) -> None: expected = { "foo": { "bar": "hello", @@ -128,11 +128,11 @@ def test__keypath(self): actual = dbt_common.utils.dict.deep_map_render(self.special_keypath, expected) self.assertEqual(actual, expected) - def test__noop(self): + def test__noop(self) -> None: actual = dbt_common.utils.dict.deep_map_render(lambda x, _: x, self.input_value) self.assertEqual(actual, self.input_value) - def test_trivial(self): + def test_trivial(self) -> None: cases = [[], {}, 1, "abc", None, True] for case in cases: result = dbt_common.utils.dict.deep_map_render(lambda x, _: x, case)