diff --git a/pyproject.toml b/pyproject.toml index ed07493..db05f50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,9 @@ max-complexity = 10 [tool.ruff.lint.pydocstyle] convention = "google" +[tool.ruff.lint.pylint] +max-args = 6 + ### Coverage ### [tool.coverage.run] diff --git a/src/dbt_score/rule.py b/src/dbt_score/rule.py index 2170b20..23f4f51 100644 --- a/src/dbt_score/rule.py +++ b/src/dbt_score/rule.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Type +from typing import Any, Callable, Type, TypeAlias, overload from dbt_score.models import Model @@ -23,6 +23,9 @@ class RuleViolation: message: str | None = None +RuleEvaluationType: TypeAlias = Callable[[Model], RuleViolation | None] + + class Rule: """The rule base class.""" @@ -40,21 +43,46 @@ def evaluate(self, model: Model) -> RuleViolation | None: raise NotImplementedError("Subclass must implement method `evaluate`.") +# Use @overload to have proper typing for both @rule and @rule(...). +# https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories + + +@overload +def rule(__func: RuleEvaluationType) -> Type[Rule]: + ... + + +@overload def rule( - description: str | None = None, + *, + description: str | RuleEvaluationType | None = None, severity: Severity = Severity.MEDIUM, -) -> Callable[[Callable[[Model], RuleViolation | None]], Type[Rule]]: +) -> Callable[[RuleEvaluationType], Type[Rule]]: + ... + + +def rule( + __func: RuleEvaluationType | None = None, + *, + description: str | RuleEvaluationType | None = None, + severity: Severity = Severity.MEDIUM, +) -> Type[Rule] | Callable[[RuleEvaluationType], Type[Rule]]: """Rule decorator. The rule decorator creates a rule class (subclass of Rule) and returns it. + Using arguments or not are both supported: + - ``@rule`` + - ``@rule(description="...")`` + Args: + __func: The rule evaluation function being decorated. description: The description of the rule. severity: The severity of the rule. """ def decorator_rule( - func: Callable[[Model], RuleViolation | None], + func: RuleEvaluationType, ) -> Type[Rule]: """Decorator function.""" if func.__doc__ is None and description is None: @@ -82,4 +110,9 @@ def wrapped_func(self: Rule, *args: Any, **kwargs: Any) -> RuleViolation | None: return rule_class - return decorator_rule + if __func is not None: + # The syntax @rule is used + return decorator_rule(__func) + else: + # The syntax @rule(...) is used + return decorator_rule diff --git a/src/dbt_score/rules/generic.py b/src/dbt_score/rules/generic.py index 0a0c6d8..e6a2cd7 100644 --- a/src/dbt_score/rules/generic.py +++ b/src/dbt_score/rules/generic.py @@ -5,14 +5,14 @@ # mypy: disable-error-code="return" -@rule() +@rule def has_description(model: Model) -> RuleViolation | None: """A model should have a description.""" if not model.description: return RuleViolation(message="Model lacks a description.") -@rule() +@rule def columns_have_description(model: Model) -> RuleViolation | None: """All columns of a model should have a description.""" invalid_column_names = [ diff --git a/tests/conftest.py b/tests/conftest.py index 782c622..588621d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,6 +51,33 @@ def example_rule(model: Model) -> RuleViolation | None: return example_rule +@fixture +def decorator_rule_no_parens() -> Type[Rule]: + """An example rule created with the rule decorator without parentheses.""" + + @rule + def example_rule(model: Model) -> RuleViolation | None: + """Description of the rule.""" + if model.name == "model1": + return RuleViolation(message="Model1 is a violation.") + return None + + return example_rule + + +@fixture +def decorator_rule_args() -> Type[Rule]: + """An example rule created with the rule decorator with arguments.""" + + @rule(description="Description of the rule.") + def example_rule(model: Model) -> RuleViolation | None: + if model.name == "model1": + return RuleViolation(message="Model1 is a violation.") + return None + + return example_rule + + @fixture def class_rule() -> Type[Rule]: """An example rule created with a class.""" diff --git a/tests/test_rule.py b/tests/test_rule.py index febdd7b..9819e75 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -5,9 +5,18 @@ from dbt_score.rule import Rule, RuleViolation, Severity, rule -def test_rule_decorator_and_class(decorator_rule, class_rule, model1, model2): +def test_rule_decorator_and_class( + decorator_rule, + decorator_rule_no_parens, + decorator_rule_args, + class_rule, + model1, + model2, +): """Test rule creation with the rule decorator and class.""" decorator_rule_instance = decorator_rule() + decorator_rule_no_parens_instance = decorator_rule_no_parens() + decorator_rule_args_instance = decorator_rule_args() class_rule_instance = class_rule() def assertions(rule_instance): @@ -20,6 +29,8 @@ def assertions(rule_instance): assert rule_instance.evaluate(model2) is None assertions(decorator_rule_instance) + assertions(decorator_rule_no_parens_instance) + assertions(decorator_rule_args_instance) assertions(class_rule_instance)