From e19badc5f39cfcc43dfd30c66bf9339781150109 Mon Sep 17 00:00:00 2001 From: Oliver Tosky Date: Tue, 12 Nov 2024 00:41:51 -0500 Subject: [PATCH] address mypy errors --- src/dbt_score/evaluation.py | 5 ++++- src/dbt_score/more_itertools.py | 32 ++++++++++++++++++++++++++++- src/dbt_score/rule.py | 36 +++++++++++++++++++++++---------- src/dbt_score/rule_filter.py | 24 +++++++++++++--------- tests/conftest.py | 4 ++-- 5 files changed, 77 insertions(+), 24 deletions(-) diff --git a/src/dbt_score/evaluation.py b/src/dbt_score/evaluation.py index 98a7081..bb29f03 100644 --- a/src/dbt_score/evaluation.py +++ b/src/dbt_score/evaluation.py @@ -3,7 +3,7 @@ from __future__ import annotations from itertools import chain -from typing import Type +from typing import Type, cast from dbt_score.formatters import Formatter from dbt_score.models import Evaluable, ManifestLoader @@ -57,6 +57,9 @@ def evaluate(self) -> None: for evaluable in chain( self._manifest_loader.models, self._manifest_loader.sources ): + # type inference on elements from `chain` is wonky + # and resolves to superclass HasColumnsMixin + evaluable = cast(Evaluable, evaluable) self.results[evaluable] = {} for rule in rules: try: diff --git a/src/dbt_score/more_itertools.py b/src/dbt_score/more_itertools.py index f91ff6c..e1d09a5 100644 --- a/src/dbt_score/more_itertools.py +++ b/src/dbt_score/more_itertools.py @@ -1,7 +1,37 @@ """Vendored utility functions from https://github.com/more-itertools/more-itertools.""" +from typing import ( + Callable, + Iterable, + Optional, + TypeVar, + overload, +) +_T = TypeVar("_T") +_U = TypeVar("_U") -def first_true(iterable, default=None, pred=None): + +@overload +def first_true( + iterable: Iterable[_T], *, pred: Callable[[_T], object] | None = ... +) -> _T | None: + ... + + +@overload +def first_true( + iterable: Iterable[_T], + default: _U, + pred: Callable[[_T], object] | None = ..., +) -> _T | _U: + ... + + +def first_true( + iterable: Iterable[_T], + default: Optional[_U] = None, + pred: Optional[Callable[[_T], object]] = None, +) -> _T | _U | None: """Returns the first true value in the iterable. If no true value is found, returns *default* diff --git a/src/dbt_score/rule.py b/src/dbt_score/rule.py index d05e868..e01ce55 100644 --- a/src/dbt_score/rule.py +++ b/src/dbt_score/rule.py @@ -4,9 +4,17 @@ import typing from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Iterable, Type, TypeAlias, overload - -from dbt_score.models import Evaluable +from typing import ( + Any, + Callable, + Iterable, + Type, + TypeAlias, + cast, + overload, +) + +from dbt_score.models import Evaluable, Model, Source from dbt_score.more_itertools import first_true from dbt_score.rule_filter import RuleFilter @@ -55,7 +63,9 @@ class RuleViolation: message: str | None = None -RuleEvaluationType: TypeAlias = Callable[[Evaluable], RuleViolation | None] +ModelRuleEvaluationType: TypeAlias = Callable[[Model], RuleViolation | None] +SourceRuleEvaluationType: TypeAlias = Callable[[Source], RuleViolation | None] +RuleEvaluationType: TypeAlias = ModelRuleEvaluationType | SourceRuleEvaluationType class Rule: @@ -66,7 +76,7 @@ class Rule: rule_filter_names: list[str] rule_filters: frozenset[RuleFilter] = frozenset() default_config: typing.ClassVar[dict[str, Any]] = {} - resource_type: typing.ClassVar[Evaluable] + resource_type: typing.ClassVar[type[Evaluable]] def __init__(self, rule_config: RuleConfig | None = None) -> None: """Initialize the rule.""" @@ -85,7 +95,7 @@ def __init_subclass__(cls, **kwargs) -> None: # type: ignore cls._validate_rule_filters() @classmethod - def _validate_rule_filters(cls): + def _validate_rule_filters(cls) -> None: for rule_filter in cls.rule_filters: if rule_filter.resource_type != cls.resource_type: raise TypeError( @@ -111,7 +121,8 @@ def _introspect_resource_type(cls) -> Type[Evaluable]: "annotated Model or Source argument." ) - return resource_type_argument.annotation + resource_type = cast(type[Evaluable], resource_type_argument.annotation) + return resource_type def process_config(self, rule_config: RuleConfig) -> None: """Process the rule config.""" @@ -178,7 +189,12 @@ def __hash__(self) -> int: @overload -def rule(__func: RuleEvaluationType) -> Type[Rule]: +def rule(__func: ModelRuleEvaluationType) -> Type[Rule]: + ... + + +@overload +def rule(__func: SourceRuleEvaluationType) -> Type[Rule]: ... @@ -214,9 +230,7 @@ def rule( rule_filters: Set of RuleFilter that filters the items that the rule applies to. """ - def decorator_rule( - func: RuleEvaluationType, - ) -> Type[Rule]: + def decorator_rule(func: RuleEvaluationType) -> Type[Rule]: """Decorator function.""" if func.__doc__ is None and description is None: raise AttributeError("Rule must define `description` or `func.__doc__`.") diff --git a/src/dbt_score/rule_filter.py b/src/dbt_score/rule_filter.py index dd4964e..c8e0e46 100644 --- a/src/dbt_score/rule_filter.py +++ b/src/dbt_score/rule_filter.py @@ -2,19 +2,21 @@ import inspect import typing -from typing import Any, Callable, Type, TypeAlias, overload +from typing import Any, Callable, Type, TypeAlias, cast, overload -from dbt_score.models import Evaluable +from dbt_score.models import Evaluable, Model, Source from dbt_score.more_itertools import first_true -FilterEvaluationType: TypeAlias = Callable[[Evaluable], bool] +ModelFilterEvaluationType: TypeAlias = Callable[[Model], bool] +SourceFilterEvaluationType: TypeAlias = Callable[[Source], bool] +FilterEvaluationType: TypeAlias = ModelFilterEvaluationType | SourceFilterEvaluationType class RuleFilter: """The Filter base class.""" description: str - resource_type: typing.ClassVar[Evaluable] + resource_type: typing.ClassVar[type[Evaluable]] def __init__(self) -> None: """Initialize the filter.""" @@ -44,7 +46,8 @@ def _introspect_resource_type(cls) -> Type[Evaluable]: "annotated Model or Source argument." ) - return resource_type_argument.annotation + resource_type = cast(type[Evaluable], resource_type_argument.annotation) + return resource_type def evaluate(self, evaluable: Evaluable) -> bool: """Evaluates the filter.""" @@ -65,7 +68,12 @@ def __hash__(self) -> int: @overload -def rule_filter(__func: FilterEvaluationType) -> Type[RuleFilter]: +def rule_filter(__func: ModelFilterEvaluationType) -> Type[RuleFilter]: + ... + + +@overload +def rule_filter(__func: SourceFilterEvaluationType) -> Type[RuleFilter]: ... @@ -96,9 +104,7 @@ def rule_filter( description: The description of the filter. """ - def decorator_filter( - func: FilterEvaluationType, - ) -> Type[RuleFilter]: + def decorator_filter(func: FilterEvaluationType) -> Type[RuleFilter]: """Decorator function.""" if func.__doc__ is None and description is None: raise AttributeError( diff --git a/tests/conftest.py b/tests/conftest.py index 9022730..fec3163 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,7 +77,7 @@ def model2(raw_manifest) -> Model: @fixture -def source1(raw_manifest) -> Model: +def source1(raw_manifest) -> Source: """Source 1.""" return Source.from_node( raw_manifest["sources"]["source.package.my_source.table1"], [] @@ -85,7 +85,7 @@ def source1(raw_manifest) -> Model: @fixture -def source2(raw_manifest) -> Model: +def source2(raw_manifest) -> Source: """Source 2.""" return Source.from_node( raw_manifest["sources"]["source.package.my_source.table2"], []