diff --git a/CHANGELOG.md b/CHANGELOG.md index 920e6d3..211565a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,16 @@ and this project adheres to ## [Unreleased] +- Support linting of sources. +- **Breaking**: Renamed modules: `dbt_score.model_filter` becomes + `dbt_score.rule_filter` +- **Breaking**: Renamed filter class and decorator: `@model_filter` becomes + `@rule_filter` and `ModelFilter` becomes `RuleFilter`. +- **Breaking**: Config option `model_filter_names` becomes `rule_filter_names`. +- **Breaking**: CLI flag naming fixes: `--fail_any_model_under` becomes + `--fail-any-item-under` and `--fail_project_under` becomes + `--fail-project-under`. + ## [0.7.1] - 2024-11-01 - Fix mkdocs. diff --git a/README.md b/README.md index dcea0eb..50a0853 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ ## What is `dbt-score`? -`dbt-score` is a linter for dbt model metadata. +`dbt-score` is a linter for dbt metadata. [dbt][dbt] (Data Build Tool) is a great framework for creating, building, organizing, testing and documenting _data models_, i.e. data sets living in a diff --git a/docs/configuration.md b/docs/configuration.md index 008dd17..a64df37 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -18,7 +18,7 @@ rule_namespaces = ["dbt_score.rules", "dbt_score_rules", "custom_rules"] disabled_rules = ["dbt_score.rules.generic.columns_have_description"] inject_cwd_in_python_path = true fail_project_under = 7.5 -fail_any_model_under = 8.0 +fail_any_item_under = 8.0 [tool.dbt-score.badges] first.threshold = 10.0 @@ -51,8 +51,8 @@ The following options can be set in the `pyproject.toml` file: - `disabled_rules`: A list of rules to disable. - `fail_project_under` (default: `5.0`): If the project score is below this value the command will fail with return code 1. -- `fail_any_model_under` (default: `5.0`): If any model scores below this value - the command will fail with return code 1. +- `fail_any_item_under` (default: `5.0`): If any model or source scores below + this value the command will fail with return code 1. #### Badges configuration @@ -70,7 +70,7 @@ All badges except `wip` can be configured with the following option: - `threshold`: The threshold for the badge. A decimal number between `0.0` and `10.0` that will be used to compare to the score. The threshold is the minimum - score required for a model to be rewarded with a certain badge. + score required for a model or source to be rewarded with a certain badge. The default values can be found in the [BadgeConfig](reference/config.md#dbt_score.config.BadgeConfig). @@ -86,7 +86,7 @@ Every rule can be configured with the following option: - `severity`: The severity of the rule. Rules have a default severity and can be overridden. It's an integer with a minimum value of 1 and a maximum value of 4. -- `model_filter_names`: Filters used by the rule. Takes a list of names that can +- `rule_filter_names`: Filters used by the rule. Takes a list of names that can be found in the same namespace as the rules (see [Package rules](package_rules.md)). diff --git a/docs/create_rules.md b/docs/create_rules.md index cd5e5b4..dc6c6df 100644 --- a/docs/create_rules.md +++ b/docs/create_rules.md @@ -1,9 +1,9 @@ # Create rules -In order to lint and score models, `dbt-score` uses a set of rules that are -applied to each model. A rule can pass or fail when it is run. Based on the -severity of the rule, models are scored with the weighted average of the rules -results. 
Note that `dbt-score` comes bundled with a +In order to lint and score models or sources, `dbt-score` uses a set of rules +that are applied to each item. A rule can pass or fail when it is run. Based on +the severity of the rule, items are scored with the weighted average of the +rules results. Note that `dbt-score` comes bundled with a [set of default rules](rules/generic.md). On top of the generic rules, it's possible to add your own rules. Two ways exist @@ -21,7 +21,7 @@ The `@rule` decorator can be used to easily create a new rule: from dbt_score import Model, rule, RuleViolation @rule -def has_description(model: Model) -> RuleViolation | None: +def model_has_description(model: Model) -> RuleViolation | None: """A model should have a description.""" if not model.description: return RuleViolation(message="Model lacks a description.") @@ -31,6 +31,21 @@ The name of the function is the name of the rule and the docstring of the function is its description. Therefore, it is important to use a self-explanatory name for the function and document it well. +The type annotation for the rule's argument dictates whether the rule should be +applied to dbt models or sources. + +Here is the same example rule, applied to sources: + +```python +from dbt_score import rule, RuleViolation, Source + +@rule +def source_has_description(source: Source) -> RuleViolation | None: + """A source should have a description.""" + if not source.description: + return RuleViolation(message="Source lacks a description.") +``` + The severity of a rule can be set using the `severity` argument: ```python @@ -45,15 +60,23 @@ For more advanced use cases, a rule can be created by inheriting from the `Rule` class: ```python -from dbt_score import Model, Rule, RuleViolation +from dbt_score import Model, Rule, RuleViolation, Source -class HasDescription(Rule): +class ModelHasDescription(Rule): description = "A model should have a description." def evaluate(self, model: Model) -> RuleViolation | None: """Evaluate the rule.""" if not model.description: return RuleViolation(message="Model lacks a description.") + +class SourceHasDescription(Rule): + description = "A source should have a description." + + def evaluate(self, source: Source) -> RuleViolation | None: + """Evaluate the rule.""" + if not source.description: + return RuleViolation(message="Source lacks a description.") ``` ### Rules location @@ -91,30 +114,48 @@ def sql_has_reasonable_number_of_lines(model: Model, max_lines: int = 200) -> Ru ) ``` -### Filtering models +### Filtering rules -Custom and standard rules can be configured to have model filters. Filters allow -models to be ignored by one or multiple rules. +Custom and standard rules can be configured to have filters. Filters allow +models or sources to be ignored by one or multiple rules if the item doesn't +satisfy the filter criteria. Filters are created using the same discovery mechanism and interface as custom rules, except they do not accept parameters. Similar to Python's built-in -`filter` function, when the filter evaluation returns `True` the model should be +`filter` function, when the filter evaluation returns `True` the item should be evaluated, otherwise it should be ignored. 
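Under the hood, a rule is only evaluated against an item when the item's resource type matches the rule's and every filter attached to the rule returns `True`. A simplified sketch of that check (mirroring `Rule.should_evaluate` in `rule.py`, not the exact implementation):

```python
def should_evaluate(rule, item) -> bool:
    # Resource types must match: a model-typed rule never runs on a source.
    types_match = rule.resource_type is type(item)
    # Every filter attached to the rule must allow the item.
    filters_pass = all(f.evaluate(item) for f in rule.rule_filters)
    return types_match and filters_pass
```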
```python -from dbt_score import ModelFilter, model_filter +from dbt_score import Model, RuleFilter, rule_filter -@model_filter +@rule_filter def only_schema_x(model: Model) -> bool: """Only applies a rule to schema X.""" return model.schema.lower() == 'x' -class SkipSchemaY(ModelFilter): +class SkipSchemaY(RuleFilter): description = "Applies a rule to every schema but Y." def evaluate(self, model: Model) -> bool: return model.schema.lower() != 'y' ``` +Filters also rely on type-annotations to dictate whether they apply to models or +sources: + +```python +from dbt_score import RuleFilter, rule_filter, Source + +@rule_filter +def only_from_source_a(source: Source) -> bool: + """Only applies a rule to source tables from source X.""" + return source.source_name.lower() == 'a' + +class SkipSourceDatabaseB(RuleFilter): + description = "Applies a rule to every source except Database B." + def evaluate(self, source: Source) -> bool: + return source.database.lower() != 'b' +``` + Similar to setting a rule severity, standard rules can have filters set in the [configuration file](configuration.md/#tooldbt-scorerulesrule_namespacerule_name), while custom rules accept the configuration file or a decorator parameter. @@ -123,7 +164,7 @@ while custom rules accept the configuration file or a decorator parameter. from dbt_score import Model, rule, RuleViolation from my_project import only_schema_x -@rule(model_filters={only_schema_x()}) +@rule(rule_filters={only_schema_x()}) def models_in_x_follow_naming_standard(model: Model) -> RuleViolation | None: """Models in schema X must follow the naming standard.""" if some_regex_fails(model.name): diff --git a/docs/get_started.md b/docs/get_started.md index 9c00b9f..d96d585 100644 --- a/docs/get_started.md +++ b/docs/get_started.md @@ -40,8 +40,8 @@ It's also possible to automatically run `dbt parse`, to generate the dbt-score lint --run-dbt-parse ``` -To lint only a selection of models, the argument `--select` can be used. It -accepts any +To lint only a selection of models or sources, the argument `--select` can be +used. It accepts any [dbt node selection syntax](https://docs.getdbt.com/reference/node-selection/syntax): ```shell diff --git a/docs/index.md b/docs/index.md index e74efc0..c100708 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,8 +2,9 @@ `dbt-score` is a linter for [dbt](https://www.getdbt.com/) metadata. -dbt allows data practitioners to organize their data in to _models_. Those -models have metadata associated with them: documentation, tests, types, etc. +dbt allows data practitioners to organize their data in to _models_ and +_sources_. Those models and sources have metadata associated with them: +documentation, tests, types, etc. `dbt-score` allows to lint and score this metadata, in order to enforce (or encourage) good practices. @@ -12,7 +13,7 @@ encourage) good practices. ``` > dbt-score lint -🥇 customers (score: 10.0) +🥇 M: customers (score: 10.0) OK dbt_score.rules.generic.has_description OK dbt_score.rules.generic.has_owner OK dbt_score.rules.generic.sql_has_reasonable_number_of_lines @@ -25,17 +26,17 @@ score. ## Philosophy -dbt models are often used as metadata containers: either in YAML files or -through the use of `{{ config() }}` blocks, they are associated with a lot of +dbt models/sources are often used as metadata containers: either in YAML files +or through the use of `{{ config() }}` blocks, they are associated with a lot of information. 
At scale, it becomes tedious to enforce good practices in large -data teams dealing with many models. +data teams dealing with many models/sources. To that end, `dbt-score` has 2 main features: -- It runs rules on models, and displays rule violations. Those can be used in - interactive environments or in CI. -- Using those run results, it scores models, as to give them a measure of their - maturity. This score can help gamify model metadata improvements, and be +- It runs rules on dbt models and sources, and displays any rule violations. + These can be used in interactive environments or in CI. +- Using those run results, it scores items, to ascribe them a measure of their + maturity. This score can help gamify metadata improvements/coverage, and be reflected in data catalogs. `dbt-score` aims to: diff --git a/docs/programmatic_invocations.md b/docs/programmatic_invocations.md index 8e2637c..5b95025 100644 --- a/docs/programmatic_invocations.md +++ b/docs/programmatic_invocations.md @@ -61,9 +61,9 @@ When `dbt-score` terminates, it exists with one of the following exit codes: project being linted either doesn't raise any warning, or the warnings are small enough to be above the thresholds. This generally means "successful linting". -- `1` in case of linting errors. This is the unhappy case: some models in the - project raise enough warnings to have a score below the defined thresholds. - This generally means "linting doesn't pass". +- `1` in case of linting errors. This is the unhappy case: some models or + sources in the project raise enough warnings to have a score below the defined + thresholds. This generally means "linting doesn't pass". - `2` in case of an unexpected error. This happens for example if something is misconfigured (for example a faulty dbt project), or the wrong parameters are given to the CLI. This generally means "setup needs to be fixed". diff --git a/pyproject.toml b/pyproject.toml index 1ce548c..a2f8caf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "pdm.backend" name = "dbt-score" dynamic = ["version"] -description = "Linter for dbt model metadata." +description = "Linter for dbt metadata." 
authors = [ {name = "Picnic Analyst Development Platform", email = "analyst-development-platform@teampicnic.com"} ] @@ -101,6 +101,7 @@ max-args = 9 [tool.ruff.lint.per-file-ignores] "tests/**/*.py" = [ "PLR2004", # Magic value comparisons + "PLR0913", # Too many args in func def ] ### Coverage ### @@ -114,3 +115,7 @@ source = [ [tool.coverage.report] show_missing = true fail_under = 80 +exclude_also = [ + "@overload" +] + diff --git a/src/dbt_score/__init__.py b/src/dbt_score/__init__.py index 3f4d3b2..2cbe499 100644 --- a/src/dbt_score/__init__.py +++ b/src/dbt_score/__init__.py @@ -1,15 +1,16 @@ """Init dbt_score package.""" -from dbt_score.model_filter import ModelFilter, model_filter -from dbt_score.models import Model +from dbt_score.models import Model, Source from dbt_score.rule import Rule, RuleViolation, Severity, rule +from dbt_score.rule_filter import RuleFilter, rule_filter __all__ = [ "Model", - "ModelFilter", + "Source", + "RuleFilter", "Rule", "RuleViolation", "Severity", - "model_filter", + "rule_filter", "rule", ] diff --git a/src/dbt_score/cli.py b/src/dbt_score/cli.py index 613163e..9585d95 100644 --- a/src/dbt_score/cli.py +++ b/src/dbt_score/cli.py @@ -81,15 +81,15 @@ def cli() -> None: default=False, ) @click.option( - "--fail_project_under", + "--fail-project-under", help="Fail if the project score is under this value.", type=float, is_flag=False, default=None, ) @click.option( - "--fail_any_model_under", - help="Fail if any model is under this value.", + "--fail-any-item-under", + help="Fail if any evaluable item is under this value.", type=float, is_flag=False, default=None, @@ -104,9 +104,9 @@ def lint( manifest: Path, run_dbt_parse: bool, fail_project_under: float, - fail_any_model_under: float, + fail_any_item_under: float, ) -> None: - """Lint dbt models metadata.""" + """Lint dbt metadata.""" manifest_provided = ( click.get_current_context().get_parameter_source("manifest") != ParameterSource.DEFAULT @@ -122,8 +122,8 @@ def lint( config.overload({"disabled_rules": disabled_rule}) if fail_project_under: config.overload({"fail_project_under": fail_project_under}) - if fail_any_model_under: - config.overload({"fail_any_model_under": fail_any_model_under}) + if fail_any_item_under: + config.overload({"fail_any_item_under": fail_any_item_under}) try: if run_dbt_parse: @@ -148,7 +148,7 @@ def lint( ctx.exit(2) if ( - any(x.value < config.fail_any_model_under for x in evaluation.scores.values()) + any(x.value < config.fail_any_item_under for x in evaluation.scores.values()) or evaluation.project_score.value < config.fail_project_under ): ctx.exit(1) diff --git a/src/dbt_score/config.py b/src/dbt_score/config.py index 4a4ddf5..a3e0b2a 100644 --- a/src/dbt_score/config.py +++ b/src/dbt_score/config.py @@ -56,7 +56,7 @@ class Config: "disabled_rules", "inject_cwd_in_python_path", "fail_project_under", - "fail_any_model_under", + "fail_any_item_under", ] _rules_section: Final[str] = "rules" _badges_section: Final[str] = "badges" @@ -70,7 +70,7 @@ def __init__(self) -> None: self.config_file: Path | None = None self.badge_config: BadgeConfig = BadgeConfig() self.fail_project_under: float = 5.0 - self.fail_any_model_under: float = 5.0 + self.fail_any_item_under: float = 5.0 def set_option(self, option: str, value: Any) -> None: """Set an option in the config.""" diff --git a/src/dbt_score/dbt_utils.py b/src/dbt_score/dbt_utils.py index f6ffc5b..fa89938 100644 --- a/src/dbt_score/dbt_utils.py +++ b/src/dbt_score/dbt_utils.py @@ -69,7 +69,7 @@ def dbt_parse() -> 
"dbtRunnerResult": @dbt_required def dbt_ls(select: Iterable[str] | None) -> Iterable[str]: """Run dbt ls.""" - cmd = ["ls", "--resource-type", "model", "--output", "name"] + cmd = ["ls", "--resource-types", "model", "source", "--output", "name"] if select: cmd += ["--select", *select] diff --git a/src/dbt_score/evaluation.py b/src/dbt_score/evaluation.py index c583d06..bb29f03 100644 --- a/src/dbt_score/evaluation.py +++ b/src/dbt_score/evaluation.py @@ -2,19 +2,20 @@ from __future__ import annotations -from typing import Type +from itertools import chain +from typing import Type, cast from dbt_score.formatters import Formatter -from dbt_score.models import ManifestLoader, Model +from dbt_score.models import Evaluable, ManifestLoader from dbt_score.rule import Rule, RuleViolation from dbt_score.rule_registry import RuleRegistry from dbt_score.scoring import Score, Scorer -# The results of a given model are stored in a dictionary, mapping rules to either: +# The results of a given evaluable are stored in a dictionary, mapping rules to either: # - None if there was no issue # - A RuleViolation if a linting error was found # - An Exception if the rule failed to run -ModelResultsType = dict[Type[Rule], None | RuleViolation | Exception] +EvaluableResultsType = dict[Type[Rule], None | RuleViolation | Exception] class Evaluation: @@ -31,7 +32,7 @@ def __init__( Args: rule_registry: A rule registry to access rules. - manifest_loader: A manifest loader to access model metadata. + manifest_loader: A manifest loader to access dbt metadata. formatter: A formatter to display results. scorer: A scorer to compute scores. """ @@ -40,11 +41,11 @@ def __init__( self._formatter = formatter self._scorer = scorer - # For each model, its results - self.results: dict[Model, ModelResultsType] = {} + # For each evaluable, its results + self.results: dict[Evaluable, EvaluableResultsType] = {} - # For each model, its computed score - self.scores: dict[Model, Score] = {} + # For each evaluable, its computed score + self.scores: dict[Evaluable, Score] = {} # The aggregated project score self.project_score: Score @@ -53,26 +54,33 @@ def evaluate(self) -> None: """Evaluate all rules.""" rules = self._rule_registry.rules.values() - for model in self._manifest_loader.models: - self.results[model] = {} + for evaluable in chain( + self._manifest_loader.models, self._manifest_loader.sources + ): + # type inference on elements from `chain` is wonky + # and resolves to superclass HasColumnsMixin + evaluable = cast(Evaluable, evaluable) + self.results[evaluable] = {} for rule in rules: try: - if rule.should_evaluate(model): # Consider model filter(s). 
- result = rule.evaluate(model, **rule.config) - self.results[model][rule.__class__] = result + if rule.should_evaluate(evaluable): + result = rule.evaluate(evaluable, **rule.config) + self.results[evaluable][rule.__class__] = result except Exception as e: - self.results[model][rule.__class__] = e + self.results[evaluable][rule.__class__] = e - self.scores[model] = self._scorer.score_model(self.results[model]) - self._formatter.model_evaluated( - model, self.results[model], self.scores[model] + self.scores[evaluable] = self._scorer.score_evaluable( + self.results[evaluable] + ) + self._formatter.evaluable_evaluated( + evaluable, self.results[evaluable], self.scores[evaluable] ) # Compute score for project - self.project_score = self._scorer.score_aggregate_models( + self.project_score = self._scorer.score_aggregate_evaluables( list(self.scores.values()) ) # Add null check before calling project_evaluated - if self._manifest_loader.models: + if self._manifest_loader.models or self._manifest_loader.sources: self._formatter.project_evaluated(self.project_score) diff --git a/src/dbt_score/formatters/__init__.py b/src/dbt_score/formatters/__init__.py index ff37429..f3a7aa0 100644 --- a/src/dbt_score/formatters/__init__.py +++ b/src/dbt_score/formatters/__init__.py @@ -9,8 +9,8 @@ from dbt_score.scoring import Score if typing.TYPE_CHECKING: - from dbt_score.evaluation import ModelResultsType -from dbt_score.models import ManifestLoader, Model + from dbt_score.evaluation import EvaluableResultsType +from dbt_score.models import Evaluable, ManifestLoader class Formatter(ABC): @@ -22,10 +22,10 @@ def __init__(self, manifest_loader: ManifestLoader, config: Config): self._config = config @abstractmethod - def model_evaluated( - self, model: Model, results: ModelResultsType, score: Score + def evaluable_evaluated( + self, evaluable: Evaluable, results: EvaluableResultsType, score: Score ) -> None: - """Callback when a model has been evaluated.""" + """Callback when an evaluable item has been evaluated.""" raise NotImplementedError @abstractmethod diff --git a/src/dbt_score/formatters/ascii_formatter.py b/src/dbt_score/formatters/ascii_formatter.py index 4035dd3..61cda3b 100644 --- a/src/dbt_score/formatters/ascii_formatter.py +++ b/src/dbt_score/formatters/ascii_formatter.py @@ -1,9 +1,9 @@ """ASCII formatter.""" -from dbt_score.evaluation import ModelResultsType +from dbt_score.evaluation import EvaluableResultsType from dbt_score.formatters import Formatter -from dbt_score.models import Model +from dbt_score.models import Evaluable from dbt_score.scoring import Score, Scorer # ruff: noqa: E501 [line-too-long] @@ -66,10 +66,10 @@ class ASCIIFormatter(Formatter): """Formatter for ASCII medals in the terminal.""" - def model_evaluated( - self, model: Model, results: ModelResultsType, score: Score + def evaluable_evaluated( + self, evaluable: Evaluable, results: EvaluableResultsType, score: Score ) -> None: - """Callback when a model has been evaluated.""" + """Callback when an evaluable item has been evaluated.""" pass def project_evaluated(self, score: Score) -> None: diff --git a/src/dbt_score/formatters/human_readable_formatter.py b/src/dbt_score/formatters/human_readable_formatter.py index ba49a53..ed2db7a 100644 --- a/src/dbt_score/formatters/human_readable_formatter.py +++ b/src/dbt_score/formatters/human_readable_formatter.py @@ -2,9 +2,9 @@ from typing import Any -from dbt_score.evaluation import ModelResultsType +from dbt_score.evaluation import EvaluableResultsType from dbt_score.formatters 
import Formatter -from dbt_score.models import Model +from dbt_score.models import Evaluable, Model, Source from dbt_score.rule import RuleViolation from dbt_score.scoring import Score @@ -20,20 +20,39 @@ class HumanReadableFormatter(Formatter): def __init__(self, *args: Any, **kwargs: Any): """Instantiate formatter.""" super().__init__(*args, **kwargs) - self._failed_models: list[tuple[Model, Score]] = [] + self._failed_evaluables: list[tuple[Evaluable, Score]] = [] @staticmethod def bold(text: str) -> str: """Return text in bold.""" return f"\033[1m{text}\033[0m" - def model_evaluated( - self, model: Model, results: ModelResultsType, score: Score + @staticmethod + def pretty_name(evaluable: Evaluable) -> str: + """Return the pretty name for an evaluable.""" + match evaluable: + case Model(): + return evaluable.name + case Source(): + return evaluable.selector_name + case _: + raise NotImplementedError + + def evaluable_evaluated( + self, evaluable: Evaluable, results: EvaluableResultsType, score: Score ) -> None: - """Callback when a model has been evaluated.""" - if score.value < self._config.fail_any_model_under: - self._failed_models.append((model, score)) - print(f"{score.badge} {self.bold(model.name)} (score: {score.rounded_value!s})") + """Callback when an evaluable item has been evaluated.""" + if score.value < self._config.fail_any_item_under: + self._failed_evaluables.append((evaluable, score)) + + resource_type = type(evaluable).__name__ + name_formatted = f"{resource_type[0]}: {self.pretty_name(evaluable)}" + header = ( + f"{score.badge} " + f"{self.bold(name_formatted)} (score: {score.rounded_value!s})" + ) + + print(header) for rule, result in results.items(): if result is None: print(f"{self.indent}{self.label_ok} {rule.source()}") @@ -50,14 +69,18 @@ def project_evaluated(self, score: Score) -> None: """Callback when a project has been evaluated.""" print(f"Project score: {self.bold(str(score.rounded_value))} {score.badge}") - if len(self._failed_models) > 0: + if len(self._failed_evaluables) > 0: print() print( - f"Error: model score too low, fail_any_model_under = " - f"{self._config.fail_any_model_under}" + f"Error: evaluable score too low, fail_any_item_under = " + f"{self._config.fail_any_item_under}" ) - for model, model_score in self._failed_models: - print(f"Model {model.name} scored {model_score.value}") + for evaluable, evaluable_score in self._failed_evaluables: + resource_type = type(evaluable) + print( + f"{resource_type.__name__} " + f"{self.pretty_name(evaluable)} scored {evaluable_score.value}" + ) elif score.value < self._config.fail_project_under: print() diff --git a/src/dbt_score/formatters/json_formatter.py b/src/dbt_score/formatters/json_formatter.py index 29f1fde..263af49 100644 --- a/src/dbt_score/formatters/json_formatter.py +++ b/src/dbt_score/formatters/json_formatter.py @@ -4,7 +4,7 @@ ```json { - "models": { + "evaluables": { "model_foo": { "score": 5.0, "badge": "🥈", @@ -47,9 +47,9 @@ import json from typing import Any -from dbt_score.evaluation import ModelResultsType +from dbt_score.evaluation import EvaluableResultsType from dbt_score.formatters import Formatter -from dbt_score.models import Model +from dbt_score.models import Evaluable from dbt_score.rule import RuleViolation from dbt_score.scoring import Score @@ -60,35 +60,35 @@ class JSONFormatter(Formatter): def __init__(self, *args: Any, **kwargs: Any): """Instantiate formatter.""" super().__init__(*args, **kwargs) - self._model_results: dict[str, dict[str, Any]] = {} + 
self.evaluable_results: dict[str, dict[str, Any]] = {} self._project_results: dict[str, Any] - def model_evaluated( - self, model: Model, results: ModelResultsType, score: Score + def evaluable_evaluated( + self, evaluable: Evaluable, results: EvaluableResultsType, score: Score ) -> None: - """Callback when a model has been evaluated.""" - self._model_results[model.name] = { + """Callback when an evaluable item has been evaluated.""" + self.evaluable_results[evaluable.name] = { "score": score.value, "badge": score.badge, - "pass": score.value >= self._config.fail_any_model_under, + "pass": score.value >= self._config.fail_any_item_under, "results": {}, } for rule, result in results.items(): severity = rule.severity.name.lower() if result is None: - self._model_results[model.name]["results"][rule.source()] = { + self.evaluable_results[evaluable.name]["results"][rule.source()] = { "result": "OK", "severity": severity, "message": None, } elif isinstance(result, RuleViolation): - self._model_results[model.name]["results"][rule.source()] = { + self.evaluable_results[evaluable.name]["results"][rule.source()] = { "result": "WARN", "severity": severity, "message": result.message, } else: - self._model_results[model.name]["results"][rule.source()] = { + self.evaluable_results[evaluable.name]["results"][rule.source()] = { "result": "ERR", "severity": severity, "message": str(result), @@ -102,7 +102,7 @@ def project_evaluated(self, score: Score) -> None: "pass": score.value >= self._config.fail_project_under, } document = { - "models": self._model_results, + "evaluables": self.evaluable_results, "project": self._project_results, } print(json.dumps(document, indent=2, ensure_ascii=False)) diff --git a/src/dbt_score/formatters/manifest_formatter.py b/src/dbt_score/formatters/manifest_formatter.py index a1914cb..0b49bd7 100644 --- a/src/dbt_score/formatters/manifest_formatter.py +++ b/src/dbt_score/formatters/manifest_formatter.py @@ -4,9 +4,9 @@ import json from typing import Any -from dbt_score.evaluation import ModelResultsType +from dbt_score.evaluation import EvaluableResultsType from dbt_score.formatters import Formatter -from dbt_score.models import Model +from dbt_score.models import Evaluable from dbt_score.scoring import Score @@ -15,19 +15,25 @@ class ManifestFormatter(Formatter): def __init__(self, *args: Any, **kwargs: Any) -> None: """Instantiate a manifest formatter.""" - self._model_scores: dict[str, Score] = {} + self._evaluable_scores: dict[str, Score] = {} super().__init__(*args, **kwargs) - def model_evaluated( - self, model: Model, results: ModelResultsType, score: Score + def evaluable_evaluated( + self, evaluable: Evaluable, results: EvaluableResultsType, score: Score ) -> None: - """Callback when a model has been evaluated.""" - self._model_scores[model.unique_id] = score + """Callback when an evaluable item has been evaluated.""" + self._evaluable_scores[evaluable.unique_id] = score def project_evaluated(self, score: Score) -> None: """Callback when a project has been evaluated.""" manifest = copy.copy(self._manifest_loader.raw_manifest) - for model_id, model_score in self._model_scores.items(): - manifest["nodes"][model_id]["meta"]["score"] = model_score.value - manifest["nodes"][model_id]["meta"]["badge"] = model_score.badge + for evaluable_id, evaluable_score in self._evaluable_scores.items(): + if evaluable_id.startswith("model"): + model_manifest = manifest["nodes"][evaluable_id] + model_manifest["meta"]["score"] = evaluable_score.value + model_manifest["meta"]["badge"] = 
evaluable_score.badge + if evaluable_id.startswith("source"): + source_manifest = manifest["sources"][evaluable_id] + source_manifest["meta"]["score"] = evaluable_score.value + source_manifest["meta"]["badge"] = evaluable_score.badge print(json.dumps(manifest, indent=2)) diff --git a/src/dbt_score/lint.py b/src/dbt_score/lint.py index 53d09f0..9a4dfcb 100644 --- a/src/dbt_score/lint.py +++ b/src/dbt_score/lint.py @@ -1,4 +1,4 @@ -"""Lint dbt models metadata.""" +"""Lint dbt metadata.""" from pathlib import Path from typing import Iterable, Literal diff --git a/src/dbt_score/model_filter.py b/src/dbt_score/model_filter.py deleted file mode 100644 index 102a44c..0000000 --- a/src/dbt_score/model_filter.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Model filtering to choose when to apply specific rules.""" - -from typing import Any, Callable, Type, TypeAlias, overload - -from dbt_score.models import Model - -FilterEvaluationType: TypeAlias = Callable[[Model], bool] - - -class ModelFilter: - """The Filter base class.""" - - description: str - - def __init__(self) -> None: - """Initialize the filter.""" - pass - - def __init_subclass__(cls, **kwargs) -> None: # type: ignore - """Initializes the subclass.""" - super().__init_subclass__(**kwargs) - if not hasattr(cls, "description"): - raise AttributeError("Subclass must define class attribute `description`.") - - def evaluate(self, model: Model) -> bool: - """Evaluates the filter.""" - raise NotImplementedError("Subclass must implement method `evaluate`.") - - @classmethod - def source(cls) -> str: - """Return the source of the filter, i.e. a fully qualified name.""" - return f"{cls.__module__}.{cls.__name__}" - - def __hash__(self) -> int: - """Compute a unique hash for a filter.""" - return hash(self.source()) - - -# Use @overload to have proper typing for both @model_filter and @model_filter(...) -# https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories - - -@overload -def model_filter(__func: FilterEvaluationType) -> Type[ModelFilter]: - ... - - -@overload -def model_filter( - *, - description: str | None = None, -) -> Callable[[FilterEvaluationType], Type[ModelFilter]]: - ... - - -def model_filter( - __func: FilterEvaluationType | None = None, - *, - description: str | None = None, -) -> Type[ModelFilter] | Callable[[FilterEvaluationType], Type[ModelFilter]]: - """Model-filter decorator. - - The model-filter decorator creates a filter class (subclass of ModelFilter) - and returns it. - - Using arguments or not are both supported: - - ``@model_filter`` - - ``@model_filter(description="...")`` - - Args: - __func: The filter evaluation function being decorated. - description: The description of the filter. - """ - - def decorator_filter( - func: FilterEvaluationType, - ) -> Type[ModelFilter]: - """Decorator function.""" - if func.__doc__ is None and description is None: - raise AttributeError( - "ModelFilter must define `description` or `func.__doc__`." 
- ) - - # Get description parameter, otherwise use the docstring - filter_description = description or ( - func.__doc__.split("\n")[0] if func.__doc__ else None - ) - - def wrapped_func(self: ModelFilter, *args: Any, **kwargs: Any) -> bool: - """Wrap func to add `self`.""" - return func(*args, **kwargs) - - # Create the filter class inheriting from ModelFilter - filter_class = type( - func.__name__, - (ModelFilter,), - { - "description": filter_description, - "evaluate": wrapped_func, - # Save provided evaluate function - "_orig_evaluate": func, - # Forward origin of the decorated function - "__qualname__": func.__qualname__, # https://peps.python.org/pep-3155/ - "__module__": func.__module__, - }, - ) - - return filter_class - - if __func is not None: - # The syntax @model_filter is used - return decorator_filter(__func) - else: - # The syntax @model_filter(...) is used - return decorator_filter diff --git a/src/dbt_score/models.py b/src/dbt_score/models.py index 910068c..79a76b1 100644 --- a/src/dbt_score/models.py +++ b/src/dbt_score/models.py @@ -6,7 +6,7 @@ from collections import defaultdict from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Iterable +from typing import Any, Iterable, Literal, TypeAlias from dbt_score.dbt_utils import dbt_ls @@ -42,7 +42,7 @@ def from_raw_values(cls, raw_values: dict[str, Any]) -> "Constraint": @dataclass class Test: - """Test for a column or model. + """Test for a column, model or source. Attributes: name: The name of the test. @@ -72,7 +72,7 @@ def from_node(cls, test_node: dict[str, Any]) -> "Test": @dataclass class Column: - """Represents a column in a model. + """Represents a column. Attributes: name: The name of the column. @@ -117,8 +117,39 @@ def from_node_values( ) +class HasColumnsMixin: + """Common methods for resource types that have columns.""" + + columns: list[Column] + + def get_column(self, column_name: str) -> Column | None: + """Get a column by name.""" + for column in self.columns: + if column.name == column_name: + return column + + return None + + @staticmethod + def _get_columns( + node_values: dict[str, Any], test_values: list[dict[str, Any]] + ) -> list[Column]: + """Get columns from a node and its tests in the manifest.""" + return [ + Column.from_node_values( + values, + [ + test + for test in test_values + if test["test_metadata"]["kwargs"].get("column_name") == name + ], + ) + for name, values in node_values.get("columns", {}).items() + ] + + @dataclass -class Model: +class Model(HasColumnsMixin): """Represents a dbt model. 
Attributes: @@ -167,31 +198,6 @@ class Model: _raw_values: dict[str, Any] = field(default_factory=dict) _raw_test_values: list[dict[str, Any]] = field(default_factory=list) - def get_column(self, column_name: str) -> Column | None: - """Get a column by name.""" - for column in self.columns: - if column.name == column_name: - return column - - return None - - @staticmethod - def _get_columns( - node_values: dict[str, Any], test_values: list[dict[str, Any]] - ) -> list[Column]: - """Get columns from a node and its tests in the manifest.""" - return [ - Column.from_node_values( - values, - [ - test - for test in test_values - if test["test_metadata"]["kwargs"].get("column_name") == name - ], - ) - for name, values in node_values.get("columns", {}).items() - ] - @classmethod def from_node( cls, node_values: dict[str, Any], test_values: list[dict[str, Any]] @@ -230,8 +236,145 @@ def __hash__(self) -> int: return hash(self.unique_id) +@dataclass +class Duration: + """Represents a duration used in SourceFreshness. + + This is referred to as `Time` in the dbt JSONSchema. + + Attributes: + count: a positive integer + period: "minute" | "hour" | "day" + """ + + count: int | None = None + period: Literal["minute", "hour", "day"] | None = None + + +@dataclass +class SourceFreshness: + """Represents a source freshness configuration. + + This is referred to as `FreshnessThreshold` in the dbt JSONSchema. + + Attributes: + warn_after: The threshold after which the dbt source freshness check should + soft-fail with a warning. + error_after: The threshold after which the dbt source freshness check should + fail. + filter: An optional filter to apply to the input data before running + source freshness check. + """ + + warn_after: Duration + error_after: Duration + filter: str | None = None + + +@dataclass +class Source(HasColumnsMixin): + """Represents a dbt source table. + + Attributes: + unique_id: The id of the source table, + e.g. 'source.package.source_name.source_table_name'. + name: The alias of the source table. + description: The full description of the source table. + source_name: The source namespace. + source_description: The description for the source namespace. + original_file_path: The yml path to the source definition. + config: The config of the source definition. + meta: Any meta-attributes on the source table. + source_meta: Any meta-attribuets on the source namespace. + columns: The list of columns for the source table. + package_name: The dbt package name for the source table. + database: The database name of the source table. + schema: The schema name of the source table. + identifier: The actual source table name, i.e. not an alias. + loader: The tool used to load the source table into the warehouse. + freshness: A set of time thresholds after which data may be considered stale. + patch_path: The yml path of the source definition. + tags: The list of tags attached to the source table. + tests: The list of tests attached to the source table. + _raw_values: The raw values of the source definition in the manifest. + _raw_test_values: The raw test values of the source definition in the manifest. 
+ """ + + unique_id: str + name: str + description: str + source_name: str + source_description: str + original_file_path: str + config: dict[str, Any] + meta: dict[str, Any] + source_meta: dict[str, Any] + columns: list[Column] + package_name: str + database: str + schema: str + identifier: str + loader: str + freshness: SourceFreshness + patch_path: str | None = None + tags: list[str] = field(default_factory=list) + tests: list[Test] = field(default_factory=list) + _raw_values: dict[str, Any] = field(default_factory=dict) + _raw_test_values: list[dict[str, Any]] = field(default_factory=list) + + @property + def selector_name(self) -> str: + """Returns the name used by the dbt `source` method selector. + + Note: This is also the format output by `dbt ls --output name` for sources. + + https://docs.getdbt.com/reference/node-selection/methods#the-source-method + """ + return f"{self.source_name}.{self.name}" + + @classmethod + def from_node( + cls, node_values: dict[str, Any], test_values: list[dict[str, Any]] + ) -> "Source": + """Create a source object from a node and it's tests in the manifest.""" + return cls( + unique_id=node_values["unique_id"], + name=node_values["name"], + description=node_values["description"], + source_name=node_values["source_name"], + source_description=node_values["source_description"], + original_file_path=node_values["original_file_path"], + config=node_values["config"], + meta=node_values["meta"], + source_meta=node_values["source_meta"], + columns=cls._get_columns(node_values, test_values), + package_name=node_values["package_name"], + database=node_values["database"], + schema=node_values["schema"], + identifier=node_values["identifier"], + loader=node_values["loader"], + freshness=node_values["freshness"], + patch_path=node_values["patch_path"], + tags=node_values["tags"], + tests=[ + Test.from_node(test) + for test in test_values + if not test["test_metadata"]["kwargs"].get("column_name") + ], + _raw_values=node_values, + _raw_test_values=test_values, + ) + + def __hash__(self) -> int: + """Compute a unique hash for a source.""" + return hash(self.unique_id) + + +Evaluable: TypeAlias = Model | Source + + class ManifestLoader: - """Load the models and tests from the manifest.""" + """Load the models, sources and tests from the manifest.""" def __init__(self, file_path: Path, select: Iterable[str] | None = None): """Initialize the ManifestLoader. 
@@ -247,17 +390,25 @@ def __init__(self, file_path: Path, select: Iterable[str] | None = None): for node_id, node_values in self.raw_manifest.get("nodes", {}).items() if node_values["package_name"] == self.project_name } + self.raw_sources = { + source_id: source_values + for source_id, source_values in self.raw_manifest.get("sources", {}).items() + if source_values["package_name"] == self.project_name + } + self.models: list[Model] = [] self.tests: dict[str, list[dict[str, Any]]] = defaultdict(list) + self.sources: list[Source] = [] self._reindex_tests() self._load_models() + self._load_sources() if select: - self._select_models(select) + self._filter_evaluables(select) - if len(self.models) == 0: - logger.warning("No model found.") + if (len(self.models) + len(self.sources)) == 0: + logger.warning("Nothing to evaluate!") def _load_models(self) -> None: """Load the models from the manifest.""" @@ -266,17 +417,34 @@ def _load_models(self) -> None: model = Model.from_node(node_values, self.tests.get(node_id, [])) self.models.append(model) + def _load_sources(self) -> None: + """Load the sources from the manifest.""" + for source_id, source_values in self.raw_sources.items(): + if source_values.get("resource_type") == "source": + source = Source.from_node(source_values, self.tests.get(source_id, [])) + self.sources.append(source) + def _reindex_tests(self) -> None: - """Index tests based on their model id.""" + """Index tests based on their associated evaluable.""" for node_values in self.raw_nodes.values(): - # Only include tests that are attached to a model - if node_values.get("resource_type") == "test" and ( - attached_node := node_values.get("attached_node") - ): - self.tests[attached_node].append(node_values) - - def _select_models(self, select: Iterable[str]) -> None: - """Filter models like dbt's --select.""" + if node_values.get("resource_type") == "test": + # tests for models have a non-null value for `attached_node` + if attached_node := node_values.get("attached_node"): + self.tests[attached_node].append(node_values) + + # Tests for sources will have a null `attached_node`, + # and a non-empty list for `sources`. + # They need to be attributed to the source id + # based on the `depends_on` field. + elif node_values.get("sources") and ( + source_unique_id := next( + iter(node_values.get("depends_on", {}).get("nodes", [])), None + ) + ): + self.tests[source_unique_id].append(node_values) + + def _filter_evaluables(self, select: Iterable[str]) -> None: + """Filter evaluables like dbt's --select.""" single_model_select = re.compile(r"[a-zA-Z0-9_]+") if all(single_model_select.fullmatch(x) for x in select): @@ -287,4 +455,5 @@ def _select_models(self, select: Iterable[str]) -> None: # Use dbt's implementation of --select selected = dbt_ls(select) - self.models = [x for x in self.models if x.name in selected] + self.models = [m for m in self.models if m.name in selected] + self.sources = [s for s in self.sources if s.selector_name in selected] diff --git a/src/dbt_score/more_itertools.py b/src/dbt_score/more_itertools.py new file mode 100644 index 0000000..e1d09a5 --- /dev/null +++ b/src/dbt_score/more_itertools.py @@ -0,0 +1,50 @@ +"""Vendored utility functions from https://github.com/more-itertools/more-itertools.""" +from typing import ( + Callable, + Iterable, + Optional, + TypeVar, + overload, +) + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +@overload +def first_true( + iterable: Iterable[_T], *, pred: Callable[[_T], object] | None = ... +) -> _T | None: + ... 
+ + +@overload +def first_true( + iterable: Iterable[_T], + default: _U, + pred: Callable[[_T], object] | None = ..., +) -> _T | _U: + ... + + +def first_true( + iterable: Iterable[_T], + default: Optional[_U] = None, + pred: Optional[Callable[[_T], object]] = None, +) -> _T | _U | None: + """Returns the first true value in the iterable. + + If no true value is found, returns *default* + + If *pred* is not None, returns the first item for which + ``pred(item) == True`` . + + >>> first_true(range(10)) + 1 + >>> first_true(range(10), pred=lambda x: x > 5) + 6 + >>> first_true(range(10), default='missing', pred=lambda x: x > 9) + 'missing' + + """ + return next(filter(pred, iterable), default) diff --git a/src/dbt_score/rule.py b/src/dbt_score/rule.py index 3d68fdc..e01ce55 100644 --- a/src/dbt_score/rule.py +++ b/src/dbt_score/rule.py @@ -4,10 +4,19 @@ import typing from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Iterable, Type, TypeAlias, overload +from typing import ( + Any, + Callable, + Iterable, + Type, + TypeAlias, + cast, + overload, +) -from dbt_score.model_filter import ModelFilter -from dbt_score.models import Model +from dbt_score.models import Evaluable, Model, Source +from dbt_score.more_itertools import first_true +from dbt_score.rule_filter import RuleFilter class Severity(Enum): @@ -25,7 +34,7 @@ class RuleConfig: severity: Severity | None = None config: dict[str, Any] = field(default_factory=dict) - model_filter_names: list[str] = field(default_factory=list) + rule_filter_names: list[str] = field(default_factory=list) @staticmethod def from_dict(rule_config: dict[str, Any]) -> "RuleConfig": @@ -37,13 +46,13 @@ def from_dict(rule_config: dict[str, Any]) -> "RuleConfig": else None ) filter_names = ( - config.pop("model_filter_names", None) - if "model_filter_names" in rule_config + config.pop("rule_filter_names", None) + if "rule_filter_names" in rule_config else [] ) return RuleConfig( - severity=severity, config=config, model_filter_names=filter_names + severity=severity, config=config, rule_filter_names=filter_names ) @@ -54,7 +63,9 @@ class RuleViolation: message: str | None = None -RuleEvaluationType: TypeAlias = Callable[[Model], RuleViolation | None] +ModelRuleEvaluationType: TypeAlias = Callable[[Model], RuleViolation | None] +SourceRuleEvaluationType: TypeAlias = Callable[[Source], RuleViolation | None] +RuleEvaluationType: TypeAlias = ModelRuleEvaluationType | SourceRuleEvaluationType class Rule: @@ -62,9 +73,10 @@ class Rule: description: str severity: Severity = Severity.MEDIUM - model_filter_names: list[str] - model_filters: frozenset[ModelFilter] = frozenset() + rule_filter_names: list[str] + rule_filters: frozenset[RuleFilter] = frozenset() default_config: typing.ClassVar[dict[str, Any]] = {} + resource_type: typing.ClassVar[type[Evaluable]] def __init__(self, rule_config: RuleConfig | None = None) -> None: """Initialize the rule.""" @@ -78,6 +90,40 @@ def __init_subclass__(cls, **kwargs) -> None: # type: ignore if not hasattr(cls, "description"): raise AttributeError("Subclass must define class attribute `description`.") + cls.resource_type = cls._introspect_resource_type() + + cls._validate_rule_filters() + + @classmethod + def _validate_rule_filters(cls) -> None: + for rule_filter in cls.rule_filters: + if rule_filter.resource_type != cls.resource_type: + raise TypeError( + f"Mismatched resource_type on filter " + f"{rule_filter.__class__.__name__}. 
" + f"Expected {cls.resource_type.__name__}, " + f"but got {rule_filter.resource_type.__name__}." + ) + + @classmethod + def _introspect_resource_type(cls) -> Type[Evaluable]: + evaluate_func = getattr(cls, "_orig_evaluate", cls.evaluate) + + sig = inspect.signature(evaluate_func) + resource_type_argument = first_true( + sig.parameters.values(), + pred=lambda arg: arg.annotation in typing.get_args(Evaluable), + ) + + if not resource_type_argument: + raise TypeError( + "Subclass must implement method `evaluate` with an " + "annotated Model or Source argument." + ) + + resource_type = cast(type[Evaluable], resource_type_argument.annotation) + return resource_type + def process_config(self, rule_config: RuleConfig) -> None: """Process the rule config.""" config = self.default_config.copy() @@ -94,19 +140,29 @@ def process_config(self, rule_config: RuleConfig) -> None: self.set_severity( rule_config.severity ) if rule_config.severity else rule_config.severity - self.model_filter_names = rule_config.model_filter_names + self.rule_filter_names = rule_config.rule_filter_names self.config = config - def evaluate(self, model: Model) -> RuleViolation | None: + def evaluate(self, evaluable: Evaluable) -> RuleViolation | None: """Evaluates the rule.""" raise NotImplementedError("Subclass must implement method `evaluate`.") @classmethod - def should_evaluate(cls, model: Model) -> bool: - """Checks if all filters in the rule allow evaluation.""" - if cls.model_filters: - return all(f.evaluate(model) for f in cls.model_filters) - return True + def should_evaluate(cls, evaluable: Evaluable) -> bool: + """Checks whether the rule should be applied against the evaluable. + + The evaluable must satisfy the following criteria: + - all filters in the rule allow evaluation + - the rule and evaluable have matching resource_types + """ + resource_types_match = cls.resource_type is type(evaluable) + + if cls.rule_filters: + return ( + all(f.evaluate(evaluable) for f in cls.rule_filters) + and resource_types_match + ) + return resource_types_match @classmethod def set_severity(cls, severity: Severity) -> None: @@ -114,9 +170,9 @@ def set_severity(cls, severity: Severity) -> None: cls.severity = severity @classmethod - def set_filters(cls, model_filters: Iterable[ModelFilter]) -> None: + def set_filters(cls, rule_filters: Iterable[RuleFilter]) -> None: """Set the filters of the rule.""" - cls.model_filters = frozenset(model_filters) + cls.rule_filters = frozenset(rule_filters) @classmethod def source(cls) -> str: @@ -133,7 +189,12 @@ def __hash__(self) -> int: @overload -def rule(__func: RuleEvaluationType) -> Type[Rule]: +def rule(__func: ModelRuleEvaluationType) -> Type[Rule]: + ... + + +@overload +def rule(__func: SourceRuleEvaluationType) -> Type[Rule]: ... @@ -142,7 +203,7 @@ def rule( *, description: str | None = None, severity: Severity = Severity.MEDIUM, - model_filters: set[ModelFilter] | None = None, + rule_filters: set[RuleFilter] | None = None, ) -> Callable[[RuleEvaluationType], Type[Rule]]: ... @@ -152,7 +213,7 @@ def rule( *, description: str | None = None, severity: Severity = Severity.MEDIUM, - model_filters: set[ModelFilter] | None = None, + rule_filters: set[RuleFilter] | None = None, ) -> Type[Rule] | Callable[[RuleEvaluationType], Type[Rule]]: """Rule decorator. @@ -166,12 +227,10 @@ def rule( __func: The rule evaluation function being decorated. description: The description of the rule. severity: The severity of the rule. - model_filters: Set of ModelFilter that filters the rule. 
+ rule_filters: Set of RuleFilter that filters the items that the rule applies to. """ - def decorator_rule( - func: RuleEvaluationType, - ) -> Type[Rule]: + def decorator_rule(func: RuleEvaluationType) -> Type[Rule]: """Decorator function.""" if func.__doc__ is None and description is None: raise AttributeError("Rule must define `description` or `func.__doc__`.") @@ -199,7 +258,7 @@ def wrapped_func(self: Rule, *args: Any, **kwargs: Any) -> RuleViolation | None: { "description": rule_description, "severity": severity, - "model_filters": model_filters or frozenset(), + "rule_filters": rule_filters or frozenset(), "default_config": default_config, "evaluate": wrapped_func, # Save provided evaluate function diff --git a/src/dbt_score/rule_filter.py b/src/dbt_score/rule_filter.py new file mode 100644 index 0000000..c8e0e46 --- /dev/null +++ b/src/dbt_score/rule_filter.py @@ -0,0 +1,145 @@ +"""Evaluable filtering to choose when to apply specific rules.""" + +import inspect +import typing +from typing import Any, Callable, Type, TypeAlias, cast, overload + +from dbt_score.models import Evaluable, Model, Source +from dbt_score.more_itertools import first_true + +ModelFilterEvaluationType: TypeAlias = Callable[[Model], bool] +SourceFilterEvaluationType: TypeAlias = Callable[[Source], bool] +FilterEvaluationType: TypeAlias = ModelFilterEvaluationType | SourceFilterEvaluationType + + +class RuleFilter: + """The Filter base class.""" + + description: str + resource_type: typing.ClassVar[type[Evaluable]] + + def __init__(self) -> None: + """Initialize the filter.""" + pass + + def __init_subclass__(cls, **kwargs) -> None: # type: ignore + """Initializes the subclass.""" + super().__init_subclass__(**kwargs) + if not hasattr(cls, "description"): + raise AttributeError("Subclass must define class attribute `description`.") + + cls.resource_type = cls._introspect_resource_type() + + @classmethod + def _introspect_resource_type(cls) -> Type[Evaluable]: + evaluate_func = getattr(cls, "_orig_evaluate", cls.evaluate) + + sig = inspect.signature(evaluate_func) + resource_type_argument = first_true( + sig.parameters.values(), + pred=lambda arg: arg.annotation in typing.get_args(Evaluable), + ) + + if not resource_type_argument: + raise TypeError( + "Subclass must implement method `evaluate` with an " + "annotated Model or Source argument." + ) + + resource_type = cast(type[Evaluable], resource_type_argument.annotation) + return resource_type + + def evaluate(self, evaluable: Evaluable) -> bool: + """Evaluates the filter.""" + raise NotImplementedError("Subclass must implement method `evaluate`.") + + @classmethod + def source(cls) -> str: + """Return the source of the filter, i.e. a fully qualified name.""" + return f"{cls.__module__}.{cls.__name__}" + + def __hash__(self) -> int: + """Compute a unique hash for a filter.""" + return hash(self.source()) + + +# Use @overload to have proper typing for both @rule_filter and @rule_filter(...) +# https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories + + +@overload +def rule_filter(__func: ModelFilterEvaluationType) -> Type[RuleFilter]: + ... + + +@overload +def rule_filter(__func: SourceFilterEvaluationType) -> Type[RuleFilter]: + ... + + +@overload +def rule_filter( + *, + description: str | None = None, +) -> Callable[[FilterEvaluationType], Type[RuleFilter]]: + ... 
+ + +def rule_filter( + __func: FilterEvaluationType | None = None, + *, + description: str | None = None, +) -> Type[RuleFilter] | Callable[[FilterEvaluationType], Type[RuleFilter]]: + """Rule-filter decorator. + + The rule_filter decorator creates a filter class (subclass of RuleFilter) + and returns it. + + Using arguments or not are both supported: + - ``@rule_filter`` + - ``@rule_filter(description="...")`` + + Args: + __func: The filter evaluation function being decorated. + description: The description of the filter. + """ + + def decorator_filter(func: FilterEvaluationType) -> Type[RuleFilter]: + """Decorator function.""" + if func.__doc__ is None and description is None: + raise AttributeError( + "RuleFilter must define `description` or `func.__doc__`." + ) + + # Get description parameter, otherwise use the docstring + filter_description = description or ( + func.__doc__.split("\n")[0] if func.__doc__ else None + ) + + def wrapped_func(self: RuleFilter, *args: Any, **kwargs: Any) -> bool: + """Wrap func to add `self`.""" + return func(*args, **kwargs) + + # Create the filter class inheriting from RuleFilter + filter_class = type( + func.__name__, + (RuleFilter,), + { + "description": filter_description, + "evaluate": wrapped_func, + # Save provided evaluate function + "_orig_evaluate": func, + # Forward origin of the decorated function + "__qualname__": func.__qualname__, # https://peps.python.org/pep-3155/ + "__module__": func.__module__, + }, + ) + + return filter_class + + if __func is not None: + # The syntax @rule_filter is used + return decorator_filter(__func) + else: + # The syntax @rule_filter(...) is used + return decorator_filter diff --git a/src/dbt_score/rule_registry.py b/src/dbt_score/rule_registry.py index 0e4557a..f60a87e 100644 --- a/src/dbt_score/rule_registry.py +++ b/src/dbt_score/rule_registry.py @@ -12,8 +12,8 @@ from dbt_score.config import Config from dbt_score.exceptions import DuplicatedRuleException -from dbt_score.model_filter import ModelFilter from dbt_score.rule import Rule, RuleConfig +from dbt_score.rule_filter import RuleFilter logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ def __init__(self, config: Config) -> None: """Instantiate a rule registry.""" self.config = config self._rules: dict[str, Rule] = {} - self._model_filters: dict[str, ModelFilter] = {} + self._rule_filters: dict[str, RuleFilter] = {} @property def rules(self) -> dict[str, Rule]: @@ -33,9 +33,9 @@ def rules(self) -> dict[str, Rule]: return self._rules @property - def model_filters(self) -> dict[str, ModelFilter]: + def rule_filters(self) -> dict[str, RuleFilter]: """Get all filters.""" - return self._model_filters + return self._rule_filters def _walk_packages(self, namespace_name: str) -> Iterator[str]: """Walk packages and sub-packages recursively.""" @@ -66,8 +66,8 @@ def _load(self, namespace_name: str) -> None: self._add_rule(obj) if ( type(obj) is type - and issubclass(obj, ModelFilter) - and obj is not ModelFilter + and issubclass(obj, RuleFilter) + and obj is not RuleFilter ): self._add_filter(obj) @@ -80,12 +80,12 @@ def _add_rule(self, rule: Type[Rule]) -> None: rule_config = self.config.rules_config.get(rule_name, RuleConfig()) self._rules[rule_name] = rule(rule_config=rule_config) - def _add_filter(self, model_filter: Type[ModelFilter]) -> None: + def _add_filter(self, rule_filter: Type[RuleFilter]) -> None: """Initialize and add a filter.""" - filter_name = model_filter.source() - if filter_name in self._model_filters: + filter_name = rule_filter.source() + 
if filter_name in self._rule_filters: raise DuplicatedRuleException(filter_name) - self._model_filters[filter_name] = model_filter() + self._rule_filters[filter_name] = rule_filter() def load_all(self) -> None: """Load all rules, core and third-party.""" @@ -103,17 +103,17 @@ def load_all(self) -> None: self._load_filters_into_rules() def _load_filters_into_rules(self) -> None: - """Loads ModelFilters into Rule objects. + """Loads RuleFilters into Rule objects. - If the config of the rule has filter names in the `model_filter_names` key, - load those filters from the rule registry into the actual `model_filters` field. + If the config of the rule has filter names in the `rule_filter_names` key, + load those filters from the rule registry into the actual `rule_filters` field. Configuration overwrites any pre-existing filters. """ for rule in self._rules.values(): - filter_names: list[str] = rule.model_filter_names or [] + filter_names: list[str] = rule.rule_filter_names or [] if len(filter_names) > 0: rule.set_filters( - model_filter - for name, model_filter in self.model_filters.items() + rule_filter + for name, rule_filter in self.rule_filters.items() if name in filter_names ) diff --git a/src/dbt_score/scoring.py b/src/dbt_score/scoring.py index 5d80306..bcf8efa 100644 --- a/src/dbt_score/scoring.py +++ b/src/dbt_score/scoring.py @@ -9,7 +9,7 @@ from dbt_score.config import Config if typing.TYPE_CHECKING: - from dbt_score.evaluation import ModelResultsType + from dbt_score.evaluation import EvaluableResultsType from dbt_score.rule import RuleViolation, Severity @@ -43,16 +43,16 @@ def __init__(self, config: Config) -> None: """Create a Scorer object.""" self._config = config - def score_model(self, model_results: ModelResultsType) -> Score: - """Compute the score of a given model.""" - rule_count = len(model_results) + def score_evaluable(self, evaluable_results: EvaluableResultsType) -> Score: + """Compute the score of a given evaluable.""" + rule_count = len(evaluable_results) if rule_count == 0: # No rule? 
No problem score = self.max_score elif any( rule.severity == Severity.CRITICAL and isinstance(result, RuleViolation) - for rule, result in model_results.items() + for rule, result in evaluable_results.items() ): # If there's a CRITICAL violation, the score is 0 score = self.min_score @@ -65,7 +65,7 @@ def score_model(self, model_results: ModelResultsType) -> Score: self.score_cardinality - rule.severity.value if isinstance(result, RuleViolation) # Either 0/3, 1/3 or 2/3 else self.score_cardinality # 3/3 - for rule, result in model_results.items() + for rule, result in evaluable_results.items() ] ) / (self.score_cardinality * rule_count) @@ -74,11 +74,11 @@ def score_model(self, model_results: ModelResultsType) -> Score: return Score(score, self._badge(score)) - def score_aggregate_models(self, scores: list[Score]) -> Score: - """Compute the score of a list of models.""" + def score_aggregate_evaluables(self, scores: list[Score]) -> Score: + """Compute the score of a list of evaluables.""" actual_scores = [s.value for s in scores] if 0.0 in actual_scores: - # Any model with a CRITICAL violation makes the project score 0 + # Any evaluable with a CRITICAL violation makes the project score 0 score = Score(self.min_score, self._badge(self.min_score)) elif len(actual_scores) == 0: score = Score(self.max_score, self._badge(self.max_score)) diff --git a/tests/conftest.py b/tests/conftest.py index 4704d24..b984398 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,10 +4,10 @@ from pathlib import Path from typing import Any, Type -from dbt_score import Model, Rule, RuleViolation, Severity, rule +from dbt_score import Model, Rule, RuleViolation, Severity, Source, rule from dbt_score.config import Config -from dbt_score.model_filter import ModelFilter, model_filter from dbt_score.models import ManifestLoader +from dbt_score.rule_filter import RuleFilter, rule_filter from pytest import fixture # Configuration @@ -73,6 +73,25 @@ def model2(raw_manifest) -> Model: return Model.from_node(raw_manifest["nodes"]["model.package.model2"], []) +# Sources + + +@fixture +def source1(raw_manifest) -> Source: + """Source 1.""" + return Source.from_node( + raw_manifest["sources"]["source.package.my_source.table1"], [] + ) + + +@fixture +def source2(raw_manifest) -> Source: + """Source 2.""" + return Source.from_node( + raw_manifest["sources"]["source.package.my_source.table2"], [] + ) + + # Multiple ways to create rules @@ -123,7 +142,7 @@ class ExampleRule(Rule): description = "Description of the rule." 
- def evaluate(self, model: Model) -> RuleViolation | None: + def evaluate(self, model: Model) -> RuleViolation | None: # type: ignore[override] """Evaluate model.""" if model.name == "model1": return RuleViolation(message="Model1 is a violation.") @@ -131,6 +150,61 @@ def evaluate(self, model: Model) -> RuleViolation | None: return ExampleRule +@fixture +def decorator_rule_source() -> Type[Rule]: + """An example rule created with the rule decorator.""" + + @rule() + def example_rule_source(source: Source) -> RuleViolation | None: + """Description of the rule.""" + if source.name == "table1": + return RuleViolation(message="Source1 is a violation.") + + return example_rule_source + + +@fixture +def decorator_rule_no_parens_source() -> Type[Rule]: + """An example rule created with the rule decorator without parentheses.""" + + @rule + def example_rule_source(source: Source) -> RuleViolation | None: + """Description of the rule.""" + if source.name == "table1": + return RuleViolation(message="Source1 is a violation.") + + return example_rule_source + + +@fixture +def decorator_rule_args_source() -> Type[Rule]: + """An example rule created with the rule decorator with arguments.""" + + @rule(description="Description of the rule.") + def example_rule_source(source: Source) -> RuleViolation | None: + if source.name == "table1": + return RuleViolation(message="Source1 is a violation.") + + return example_rule_source + + +@fixture +def class_rule_source() -> Type[Rule]: + """An example rule created with a class.""" + + class ExampleRuleSource(Rule): + """Example rule.""" + + description = "Description of the rule." + + def evaluate(self, source: Source) -> RuleViolation | None: # type: ignore[override] + """Evaluate source.""" + if source.name == "table1": + return RuleViolation(message="Source1 is a violation.") + + return ExampleRuleSource + + # Rules @@ -214,38 +288,76 @@ def rule_error(model: Model) -> RuleViolation | None: @fixture -def rule_with_filter() -> Type[Rule]: +def model_rule_with_filter() -> Type[Rule]: """An example rule that skips through a filter.""" - @model_filter + @rule_filter def skip_model1(model: Model) -> bool: """Skips for model1, passes for model2.""" return model.name != "model1" - @rule(model_filters={skip_model1()}) - def rule_with_filter(model: Model) -> RuleViolation | None: + @rule(rule_filters={skip_model1()}) + def model_rule_with_filter(model: Model) -> RuleViolation | None: """Rule that always fails when not filtered.""" return RuleViolation(message="I always fail.") - return rule_with_filter + return model_rule_with_filter @fixture -def class_rule_with_filter() -> Type[Rule]: +def source_rule_with_filter() -> Type[Rule]: + """An example rule that skips through a filter.""" + + @rule_filter + def skip_source1(source: Source) -> bool: + """Skips for source1, passes for source2.""" + return source.name != "table1" + + @rule(rule_filters={skip_source1()}) + def source_rule_with_filter(source: Source) -> RuleViolation | None: + """Rule that always fails when not filtered.""" + return RuleViolation(message="I always fail.") + + return source_rule_with_filter + + +@fixture +def model_class_rule_with_filter() -> Type[Rule]: """Using class definitions for filters and rules.""" - class SkipModel1(ModelFilter): + class SkipModel1(RuleFilter): description = "Filter defined by a class." 
- def evaluate(self, model: Model) -> bool: + def evaluate(self, model: Model) -> bool: # type: ignore[override] """Skips for model1, passes for model2.""" return model.name != "model1" - class RuleWithFilter(Rule): + class ModelRuleWithFilter(Rule): + description = "Filter defined by a class." + rule_filters = frozenset({SkipModel1()}) + + def evaluate(self, model: Model) -> RuleViolation | None: # type: ignore[override] + return RuleViolation(message="I always fail.") + + return ModelRuleWithFilter + + +@fixture +def source_class_rule_with_filter() -> Type[Rule]: + """Using class definitions for filters and rules.""" + + class SkipSource1(RuleFilter): + description = "Filter defined by a class." + + def evaluate(self, source: Source) -> bool: # type: ignore[override] + """Skips for source1, passes for source2.""" + return source.name != "table1" + + class SourceRuleWithFilter(Rule): description = "Filter defined by a class." - model_filters = frozenset({SkipModel1()}) + rule_filters = frozenset({SkipSource1()}) - def evaluate(self, model: Model) -> RuleViolation | None: + def evaluate(self, source: Source) -> RuleViolation | None: # type: ignore[override] return RuleViolation(message="I always fail.") - return RuleWithFilter + return SourceRuleWithFilter diff --git a/tests/formatters/test_ascii_formatter.py b/tests/formatters/test_ascii_formatter.py index 2ab4b41..f7c7dc9 100644 --- a/tests/formatters/test_ascii_formatter.py +++ b/tests/formatters/test_ascii_formatter.py @@ -1,7 +1,7 @@ """Unit tests for the ASCII formatter.""" -from dbt_score.evaluation import ModelResultsType +from dbt_score.evaluation import EvaluableResultsType from dbt_score.formatters.ascii_formatter import ASCIIFormatter from dbt_score.rule import RuleViolation from dbt_score.scoring import Score @@ -18,12 +18,12 @@ def test_ascii_formatter_model( ): """Ensure the formatter doesn't write anything after model evaluation.""" formatter = ASCIIFormatter(manifest_loader=manifest_loader, config=default_config) - results: ModelResultsType = { + results: EvaluableResultsType = { rule_severity_low: None, rule_severity_medium: Exception("Oh noes"), rule_severity_critical: RuleViolation("Error"), } - formatter.model_evaluated(model1, results, Score(10.0, "🥇")) + formatter.evaluable_evaluated(model1, results, Score(10.0, "🥇")) stdout = capsys.readouterr().out assert stdout == "" diff --git a/tests/formatters/test_human_readable_formatter.py b/tests/formatters/test_human_readable_formatter.py index 6a3438a..0f1f90e 100644 --- a/tests/formatters/test_human_readable_formatter.py +++ b/tests/formatters/test_human_readable_formatter.py @@ -1,6 +1,7 @@ """Unit tests for the human readable formatter.""" +from textwrap import dedent -from dbt_score.evaluation import ModelResultsType +from dbt_score.evaluation import EvaluableResultsType from dbt_score.formatters.human_readable_formatter import HumanReadableFormatter from dbt_score.rule import RuleViolation from dbt_score.scoring import Score @@ -19,22 +20,21 @@ def test_human_readable_formatter_model( formatter = HumanReadableFormatter( manifest_loader=manifest_loader, config=default_config ) - results: ModelResultsType = { + results: EvaluableResultsType = { rule_severity_low: None, rule_severity_medium: Exception("Oh noes"), rule_severity_critical: RuleViolation("Error"), } - formatter.model_evaluated(model1, results, Score(10.0, "🥇")) + formatter.evaluable_evaluated(model1, results, Score(10.0, "🥇")) stdout = capsys.readouterr().out - assert ( - stdout - == """🥇 
\x1B[1mmodel1\x1B[0m (score: 10.0) - \x1B[1;32mOK \x1B[0m tests.conftest.rule_severity_low - \x1B[1;31mERR \x1B[0m tests.conftest.rule_severity_medium: Oh noes - \x1B[1;33mWARN\x1B[0m (critical) tests.conftest.rule_severity_critical: Error - -""" - ) + expected = """\ + 🥇 \x1B[1mM: model1\x1B[0m (score: 10.0) + \x1B[1;32mOK \x1B[0m tests.conftest.rule_severity_low + \x1B[1;31mERR \x1B[0m tests.conftest.rule_severity_medium: Oh noes + \x1B[1;33mWARN\x1B[0m (critical) tests.conftest.rule_severity_critical: Error + + """ + assert stdout == dedent(expected) def test_human_readable_formatter_project(capsys, default_config, manifest_loader): @@ -60,22 +60,22 @@ def test_human_readable_formatter_near_perfect_model_score( formatter = HumanReadableFormatter( manifest_loader=manifest_loader, config=default_config ) - results: ModelResultsType = { + results: EvaluableResultsType = { rule_severity_low: None, rule_severity_medium: Exception("Oh noes"), rule_severity_critical: RuleViolation("Error"), } - formatter.model_evaluated(model1, results, Score(9.99, "🥈")) + formatter.evaluable_evaluated(model1, results, Score(9.99, "🥈")) stdout = capsys.readouterr().out - assert ( - stdout - == """🥈 \x1B[1mmodel1\x1B[0m (score: 9.9) - \x1B[1;32mOK \x1B[0m tests.conftest.rule_severity_low - \x1B[1;31mERR \x1B[0m tests.conftest.rule_severity_medium: Oh noes - \x1B[1;33mWARN\x1B[0m (critical) tests.conftest.rule_severity_critical: Error - -""" - ) + + expected = """\ + 🥈 \x1B[1mM: model1\x1B[0m (score: 9.9) + \x1B[1;32mOK \x1B[0m tests.conftest.rule_severity_low + \x1B[1;31mERR \x1B[0m tests.conftest.rule_severity_medium: Oh noes + \x1B[1;33mWARN\x1B[0m (critical) tests.conftest.rule_severity_critical: Error + + """ + assert stdout == dedent(expected) def test_human_readable_formatter_near_perfect_project_score( @@ -90,35 +90,40 @@ def test_human_readable_formatter_near_perfect_project_score( assert stdout == "Project score: \x1B[1m9.9\x1B[0m 🥈\n" -def test_human_readable_formatter_low_model_score( +def test_human_readable_formatter_low_evaluable_score( capsys, default_config, manifest_loader, model1, + source1, rule_severity_critical, ): """Ensure the formatter has the correct output when a model has a low score.""" formatter = HumanReadableFormatter( manifest_loader=manifest_loader, config=default_config ) - results: ModelResultsType = { + results: EvaluableResultsType = { rule_severity_critical: RuleViolation("Error"), } - formatter.model_evaluated(model1, results, Score(0.0, "🚧")) + formatter.evaluable_evaluated(model1, results, Score(0.0, "🚧")) + formatter.evaluable_evaluated(source1, results, Score(0.0, "🚧")) formatter.project_evaluated(Score(0.0, "🚧")) stdout = capsys.readouterr().out - print() - assert ( - stdout - == """🚧 \x1B[1mmodel1\x1B[0m (score: 0.0) - \x1B[1;33mWARN\x1B[0m (critical) tests.conftest.rule_severity_critical: Error -Project score: \x1B[1m0.0\x1B[0m 🚧 + expected = """\ + 🚧 \x1B[1mM: model1\x1B[0m (score: 0.0) + \x1B[1;33mWARN\x1B[0m (critical) tests.conftest.rule_severity_critical: Error -Error: model score too low, fail_any_model_under = 5.0 -Model model1 scored 0.0 -""" - ) + 🚧 \x1B[1mS: my_source.table1\x1B[0m (score: 0.0) + \x1B[1;33mWARN\x1B[0m (critical) tests.conftest.rule_severity_critical: Error + + Project score: \x1B[1m0.0\x1B[0m 🚧 + + Error: evaluable score too low, fail_any_item_under = 5.0 + Model model1 scored 0.0 + Source my_source.table1 scored 0.0 + """ + assert stdout == dedent(expected) def test_human_readable_formatter_low_project_score( @@ -132,20 +137,19 @@ def 
test_human_readable_formatter_low_project_score( formatter = HumanReadableFormatter( manifest_loader=manifest_loader, config=default_config ) - results: ModelResultsType = { + results: EvaluableResultsType = { rule_severity_critical: RuleViolation("Error"), } - formatter.model_evaluated(model1, results, Score(10.0, "🥇")) + formatter.evaluable_evaluated(model1, results, Score(10.0, "🥇")) formatter.project_evaluated(Score(0.0, "🚧")) stdout = capsys.readouterr().out - print() - assert ( - stdout - == """🥇 \x1B[1mmodel1\x1B[0m (score: 10.0) - \x1B[1;33mWARN\x1B[0m (critical) tests.conftest.rule_severity_critical: Error -Project score: \x1B[1m0.0\x1B[0m 🚧 + expected = """\ + 🥇 \x1B[1mM: model1\x1B[0m (score: 10.0) + \x1B[1;33mWARN\x1B[0m (critical) tests.conftest.rule_severity_critical: Error -Error: project score too low, fail_project_under = 5.0 -""" - ) + Project score: \x1B[1m0.0\x1B[0m 🚧 + + Error: project score too low, fail_project_under = 5.0 + """ + assert stdout == dedent(expected) diff --git a/tests/formatters/test_json_formatter.py b/tests/formatters/test_json_formatter.py index 1eec69c..5c707da 100644 --- a/tests/formatters/test_json_formatter.py +++ b/tests/formatters/test_json_formatter.py @@ -23,13 +23,13 @@ def test_json_formatter( rule_severity_medium: Exception("Oh noes"), rule_severity_critical: RuleViolation("Error"), } - formatter.model_evaluated(model1, results, Score(10.0, "🥇")) + formatter.evaluable_evaluated(model1, results, Score(10.0, "🥇")) formatter.project_evaluated(Score(10.0, "🥇")) stdout = capsys.readouterr().out assert ( stdout == """{ - "models": { + "evaluables": { "model1": { "score": 10.0, "badge": "🥇", diff --git a/tests/formatters/test_manifest_formatter.py b/tests/formatters/test_manifest_formatter.py index baaacc8..a7da2c4 100644 --- a/tests/formatters/test_manifest_formatter.py +++ b/tests/formatters/test_manifest_formatter.py @@ -2,7 +2,7 @@ import json -from dbt_score.evaluation import ModelResultsType +from dbt_score.evaluation import EvaluableResultsType from dbt_score.formatters.manifest_formatter import ManifestFormatter from dbt_score.rule import RuleViolation from dbt_score.scoring import Score @@ -21,12 +21,12 @@ def test_manifest_formatter_model( formatter = ManifestFormatter( manifest_loader=manifest_loader, config=default_config ) - results: ModelResultsType = { + results: EvaluableResultsType = { rule_severity_low: None, rule_severity_medium: Exception("Oh noes"), rule_severity_critical: RuleViolation("Error"), } - formatter.model_evaluated(model1, results, Score(10.0, "🥇")) + formatter.evaluable_evaluated(model1, results, Score(10.0, "🥇")) stdout = capsys.readouterr().out assert stdout == "" @@ -37,6 +37,8 @@ def test_manifest_formatter_project( manifest_loader, model1, model2, + source1, + source2, rule_severity_low, rule_severity_medium, rule_severity_critical, @@ -45,23 +47,43 @@ def test_manifest_formatter_project( formatter = ManifestFormatter( manifest_loader=manifest_loader, config=default_config ) - result1: ModelResultsType = { + result1: EvaluableResultsType = { rule_severity_low: None, rule_severity_medium: Exception("Oh noes"), rule_severity_critical: RuleViolation("Error"), } - result2: ModelResultsType = { + result2: EvaluableResultsType = { rule_severity_low: None, rule_severity_medium: None, rule_severity_critical: None, } - formatter.model_evaluated(model1, result1, Score(5.0, "🚧")) - formatter.model_evaluated(model2, result2, Score(10.0, "🥇")) + formatter.evaluable_evaluated(model1, result1, Score(5.0, "🚧")) + 
formatter.evaluable_evaluated(model2, result2, Score(10.0, "🥇")) + formatter.evaluable_evaluated(source1, result1, Score(5.0, "🚧")) + formatter.evaluable_evaluated(source2, result2, Score(10.0, "🥇")) formatter.project_evaluated(Score(7.5, "🥉")) + stdout = capsys.readouterr().out new_manifest = json.loads(stdout) assert new_manifest["nodes"]["model.package.model1"]["meta"]["score"] == 5.0 assert new_manifest["nodes"]["model.package.model1"]["meta"]["badge"] == "🚧" assert new_manifest["nodes"]["model.package.model2"]["meta"]["score"] == 10.0 assert new_manifest["nodes"]["model.package.model2"]["meta"]["badge"] == "🥇" + + assert ( + new_manifest["sources"]["source.package.my_source.table1"]["meta"]["score"] + == 5.0 + ) + assert ( + new_manifest["sources"]["source.package.my_source.table1"]["meta"]["badge"] + == "🚧" + ) + assert ( + new_manifest["sources"]["source.package.my_source.table2"]["meta"]["score"] + == 10.0 + ) + assert ( + new_manifest["sources"]["source.package.my_source.table2"]["meta"]["badge"] + == "🥇" + ) diff --git a/tests/resources/manifest.json b/tests/resources/manifest.json index 52a90ab..af5908b 100644 --- a/tests/resources/manifest.json +++ b/tests/resources/manifest.json @@ -124,6 +124,185 @@ "test.package.test3": { "resource_type": "test", "package_name": "package" + }, + "test.package.source_test1": { + "resource_type": "test", + "package_name": "package", + "name": "source_test1", + "attached_node": null, + "sources": [["my_source", "table1"]], + "depends_on": { + "nodes": ["source.package.my_source.table1"] + }, + "test_metadata": { + "name": "type", + "kwargs": {} + } + }, + "test.package.bad_source_test1": { + "resource_type": "test", + "package_name": "package", + "name": "source_test__malformed_missing_depends_on", + "attached_node": null, + "sources": [["my_source", "table1"]], + "test_metadata": { + "name": "type", + "kwargs": {} + } + }, + "test.package.bad_source_test2": { + "resource_type": "test", + "package_name": "package", + "name": "source_test__malformed_missing_depends_on_nodes", + "attached_node": null, + "sources": [["my_source", "table1"]], + "depends_on": {}, + "test_metadata": { + "name": "type", + "kwargs": {} + } + } + }, + "sources": { + "source.package.my_source.table1": { + "database": "source_db", + "schema": "source_schema", + "name": "table1", + "resource_type": "source", + "package_name": "package", + "path": "models/sources/sources.yml", + "original_file_path": "models/sources/sources.yml", + "unique_id": "source.package.my_source.table1", + "fqn": ["package", "my_source", "table1"], + "source_name": "my_source", + "source_description": "An important source table.", + "loader": "Fivetran", + "identifier": "table1", + "quoting": {}, + "loaded_at_field": null, + "freshness": { + "warn_after": { + "count": null, + "period": null + }, + "error_after": { + "count": null, + "period": null + }, + "filter": null + }, + "external": null, + "description": "", + "columns": {}, + "meta": {}, + "source_meta": {}, + "tags": [], + "config": { + "enabled": true + }, + "patch_path": null, + "unrendered_config": {}, + "relation_name": "\"package\".\"my_source\".\"table1\"", + "created_at": 1728529440.4206917 + }, + "source.package.my_source.table2": { + "database": "source_db", + "schema": "source_schema", + "name": "table2", + "resource_type": "source", + "package_name": "package", + "path": "models/sources/sources.yml", + "original_file_path": "models/sources/sources.yml", + "unique_id": "source.package.my_source.table2", + "fqn": ["package", 
"my_source", "table2"], + "source_name": "my_source", + "source_description": "Another table with some columns declared.", + "loader": "Fivetran", + "identifier": "table2", + "quoting": {}, + "loaded_at_field": null, + "freshness": { + "warn_after": { + "count": null, + "period": null + }, + "error_after": { + "count": null, + "period": null + }, + "filter": null + }, + "external": null, + "description": "", + "columns": { + "a": { + "name": "column_a", + "description": "Column A.", + "data_type": "string", + "meta": {}, + "constraints": [], + "tags": [] + }, + "b": { + "name": "column_b", + "description": "Column B.", + "data_type": "integer", + "meta": {}, + "constraints": [], + "tags": [] + } + }, + "meta": {}, + "source_meta": {}, + "tags": [], + "config": { + "enabled": true + }, + "patch_path": null, + "unrendered_config": {}, + "relation_name": "\"package\".\"my_source\".\"table2\"", + "created_at": 1728529440.4206917 + }, + "source.package.my_other_source.table1": { + "database": "source_db", + "schema": "alternate_schema", + "name": "table1", + "resource_type": "source", + "package_name": "package", + "path": "models/sources/sources.yml", + "original_file_path": "models/sources/sources.yml", + "unique_id": "source.package.my_other_source.table1", + "fqn": ["package", "my_other_source", "table1"], + "source_name": "my_other_source", + "source_description": "A source in a different schema.", + "loader": "Fivetran", + "identifier": "table1", + "quoting": {}, + "loaded_at_field": null, + "freshness": { + "warn_after": { + "count": null, + "period": null + }, + "error_after": { + "count": null, + "period": null + }, + "filter": null + }, + "external": null, + "description": "", + "columns": {}, + "meta": {}, + "source_meta": {}, + "tags": [], + "config": { + "enabled": true + }, + "patch_path": null, + "unrendered_config": {}, + "relation_name": "\"package\".\"my_other_source\".\"table1\"", + "created_at": 1728529440.4206917 } } } diff --git a/tests/resources/pyproject.toml b/tests/resources/pyproject.toml index cdd6cdb..14144b3 100644 --- a/tests/resources/pyproject.toml +++ b/tests/resources/pyproject.toml @@ -2,7 +2,7 @@ rule_namespaces = ["foo", "tests"] disabled_rules = ["foo.foo", "tests.bar"] fail_project_under = 7.5 -fail_any_model_under = 6.9 +fail_any_item_under = 6.9 [tool.dbt-score.badges] wip.icon = "🏗️" @@ -25,4 +25,4 @@ model_name="model2" [tool.dbt-score.rules."tests.rules.example.rule_test_example"] severity=4 -model_filter_names=["tests.rules.example.skip_model1"] +rule_filter_names=["tests.rules.example.skip_model1"] diff --git a/tests/rules/example.py b/tests/rules/example.py index 7c383ac..11b2f95 100644 --- a/tests/rules/example.py +++ b/tests/rules/example.py @@ -1,6 +1,6 @@ """Example rules.""" -from dbt_score import Model, RuleViolation, model_filter, rule +from dbt_score import Model, RuleViolation, rule, rule_filter @rule() @@ -8,7 +8,7 @@ def rule_test_example(model: Model) -> RuleViolation | None: """An example rule.""" -@model_filter +@rule_filter def skip_model1(model: Model) -> bool: """An example filter.""" return model.name != "model1" diff --git a/tests/test_cli.py b/tests/test_cli.py index 69f68f0..91d2bae 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -96,7 +96,7 @@ def test_fail_project_under(manifest_path): with patch("dbt_score.cli.Config._load_toml_file"): runner = CliRunner() result = runner.invoke( - lint, ["--manifest", manifest_path, "--fail_project_under", "10.0"] + lint, ["--manifest", manifest_path, "--fail-project-under", 
"10.0"] ) assert "model1" in result.output @@ -110,10 +110,10 @@ def test_fail_any_model_under(manifest_path): with patch("dbt_score.cli.Config._load_toml_file"): runner = CliRunner() result = runner.invoke( - lint, ["--manifest", manifest_path, "--fail_any_model_under", "10.0"] + lint, ["--manifest", manifest_path, "--fail-any-item-under", "10.0"] ) assert "model1" in result.output assert "model2" in result.output - assert "Error: model score too low, fail_any_model_under" in result.stdout + assert "Error: evaluable score too low, fail_any_item_under" in result.stdout assert result.exit_code == 1 diff --git a/tests/test_config.py b/tests/test_config.py index 634e99a..67a0e04 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -26,10 +26,10 @@ def test_load_valid_toml_file(valid_config_path): assert config.badge_config.second.icon == "2️⃣" assert config.badge_config.first.icon == "1️⃣" assert config.fail_project_under == 7.5 - assert config.fail_any_model_under == 6.9 + assert config.fail_any_item_under == 6.9 assert config.rules_config[ "tests.rules.example.rule_test_example" - ].model_filter_names == ["tests.rules.example.skip_model1"] + ].rule_filter_names == ["tests.rules.example.skip_model1"] def test_load_invalid_toml_file(caplog, invalid_config_path): diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index 569124b..73573f4 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -53,11 +53,11 @@ def test_evaluation_low_medium_high( assert isinstance(evaluation.results[model2][rule_severity_high], RuleViolation) assert isinstance(evaluation.results[model2][rule_error], Exception) - assert mock_formatter.model_evaluated.call_count == 2 + assert mock_formatter.evaluable_evaluated.call_count == 5 assert mock_formatter.project_evaluated.call_count == 1 - assert mock_scorer.score_model.call_count == 2 - assert mock_scorer.score_aggregate_models.call_count == 1 + assert mock_scorer.score_evaluable.call_count == 5 + assert mock_scorer.score_aggregate_evaluables.call_count == 1 def test_evaluation_critical( @@ -181,14 +181,17 @@ def test_evaluation_rule_with_config( assert evaluation.results[model2][rule_with_config] is None -def test_evaluation_with_filter(manifest_path, default_config, rule_with_filter): +def test_evaluation_with_filter( + manifest_path, default_config, model_rule_with_filter, source_rule_with_filter +): """Test rule with filter.""" manifest_loader = ManifestLoader(manifest_path) mock_formatter = Mock() mock_scorer = Mock() rule_registry = RuleRegistry(default_config) - rule_registry._add_rule(rule_with_filter) + rule_registry._add_rule(model_rule_with_filter) + rule_registry._add_rule(source_rule_with_filter) # Ensure we get a valid Score object from the Mock mock_scorer.score_model.return_value = Score(10, "🥇") @@ -203,13 +206,23 @@ def test_evaluation_with_filter(manifest_path, default_config, rule_with_filter) model1 = manifest_loader.models[0] model2 = manifest_loader.models[1] + source1 = manifest_loader.sources[0] + source2 = manifest_loader.sources[1] + + assert model_rule_with_filter not in evaluation.results[model1] + assert isinstance(evaluation.results[model2][model_rule_with_filter], RuleViolation) - assert rule_with_filter not in evaluation.results[model1] - assert isinstance(evaluation.results[model2][rule_with_filter], RuleViolation) + assert source_rule_with_filter not in evaluation.results[source1] + assert isinstance( + evaluation.results[source2][source_rule_with_filter], RuleViolation + ) def 
test_evaluation_with_class_filter( - manifest_path, default_config, class_rule_with_filter + manifest_path, + default_config, + model_class_rule_with_filter, + source_class_rule_with_filter, ): """Test rule with filters and filtered rules defined by classes.""" manifest_loader = ManifestLoader(manifest_path) @@ -217,7 +230,8 @@ def test_evaluation_with_class_filter( mock_scorer = Mock() rule_registry = RuleRegistry(default_config) - rule_registry._add_rule(class_rule_with_filter) + rule_registry._add_rule(model_class_rule_with_filter) + rule_registry._add_rule(source_class_rule_with_filter) # Ensure we get a valid Score object from the Mock mock_scorer.score_model.return_value = Score(10, "🥇") @@ -232,6 +246,48 @@ def test_evaluation_with_class_filter( model1 = manifest_loader.models[0] model2 = manifest_loader.models[1] + source1 = manifest_loader.sources[0] + source2 = manifest_loader.sources[1] + + assert model_class_rule_with_filter not in evaluation.results[model1] + assert isinstance( + evaluation.results[model2][model_class_rule_with_filter], RuleViolation + ) + + assert source_class_rule_with_filter not in evaluation.results[source1] + assert isinstance( + evaluation.results[source2][source_class_rule_with_filter], RuleViolation + ) + + +def test_evaluation_with_models_and_sources( + manifest_path, default_config, decorator_rule, decorator_rule_source +): + """Test that model rules apply to models and source rules apply to sources.""" + manifest_loader = ManifestLoader(manifest_path) + mock_formatter = Mock() + mock_scorer = Mock() + + rule_registry = RuleRegistry(default_config) + rule_registry._add_rule(decorator_rule) + rule_registry._add_rule(decorator_rule_source) + + # Ensure we get a valid Score object from the Mock + mock_scorer.score_model.return_value = Score(10, "🥇") + + evaluation = Evaluation( + rule_registry=rule_registry, + manifest_loader=manifest_loader, + formatter=mock_formatter, + scorer=mock_scorer, + ) + evaluation.evaluate() + + model1 = manifest_loader.models[0] + source1 = manifest_loader.sources[0] + + assert decorator_rule in evaluation.results[model1] + assert decorator_rule_source not in evaluation.results[model1] - assert class_rule_with_filter not in evaluation.results[model1] - assert isinstance(evaluation.results[model2][class_rule_with_filter], RuleViolation) + assert decorator_rule_source in evaluation.results[source1] + assert decorator_rule not in evaluation.results[source1] diff --git a/tests/test_model_filter.py b/tests/test_model_filter.py deleted file mode 100644 index 86f1bf1..0000000 --- a/tests/test_model_filter.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Test model filters.""" - -from dbt_score.model_filter import ModelFilter, model_filter -from dbt_score.models import Model - - -def test_basic_filter(model1, model2): - """Test basic filter testing for a specific model.""" - - @model_filter - def only_model1(model: Model) -> bool: - """Some description.""" - return model.name == "model1" - - instance = only_model1() # since the decorator returns a Type - assert instance.description == "Some description." - assert instance.evaluate(model1) - assert not instance.evaluate(model2) - - -def test_class_filter(model1, model2): - """Test basic filter using class.""" - - class OnlyModel1(ModelFilter): - description = "Some description." - - def evaluate(self, model: Model) -> bool: - return model.name == "model1" - - instance = OnlyModel1() - assert instance.description == "Some description." 
- assert instance.evaluate(model1) - assert not instance.evaluate(model2) diff --git a/tests/test_models.py b/tests/test_models.py index 4300de8..b0fac9d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -22,6 +22,15 @@ def test_manifest_load(mock_read_text, raw_manifest): assert loader.models[0].tests[0].name == "test2" assert loader.models[0].columns[0].tests[0].name == "test1" + assert len(loader.sources) == len( + [ + source + for source in raw_manifest["sources"].values() + if source["package_name"] == raw_manifest["metadata"]["project_name"] + ] + ) + assert loader.sources[0].tests[0].name == "source_test1" + @patch("dbt_score.models.Path.read_text") def test_manifest_select_models_simple(mock_read_text, raw_manifest): @@ -52,4 +61,4 @@ def test_manifest_no_model(mock_dbt_ls, mock_read_text, raw_manifest, caplog): manifest_loader = ManifestLoader(Path("some.json"), select=["non_existing"]) assert len(manifest_loader.models) == 0 - assert "No model found" in caplog.text + assert "Nothing to evaluate!" in caplog.text diff --git a/tests/test_more_itertools.py b/tests/test_more_itertools.py new file mode 100644 index 0000000..2f29036 --- /dev/null +++ b/tests/test_more_itertools.py @@ -0,0 +1,22 @@ +"""Tests for vendored more_itertools functions.""" +from dbt_score.more_itertools import first_true + + +class TestFirstTrue: + """Tests for `first_true`.""" + + def test_something_true(self): + """Test with no keyword arguments.""" + assert first_true(range(10), 1) + + def test_nothing_true(self): + """Test default return value.""" + assert first_true([0, 0, 0]) is None + + def test_default(self): + """Test with a default keyword.""" + assert first_true([0, 0, 0], default="!") == "!" + + def test_pred(self): + """Test with a custom predicate.""" + assert first_true([2, 4, 6], pred=lambda x: x % 3 == 0) == 6 diff --git a/tests/test_rule.py b/tests/test_rule.py index 1ed3077..ff9efb5 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -1,7 +1,8 @@ """Test rule.""" import pytest -from dbt_score import Model, Rule, RuleViolation, Severity, rule +from dbt_score import Model, Rule, RuleViolation, Severity, Source, rule +from dbt_score.rule_filter import RuleFilter, rule_filter def test_rule_decorator_and_class( @@ -49,20 +50,105 @@ def test_missing_description_rule_class(): class BadRule(Rule): """Bad example rule.""" - def evaluate(self, model: Model) -> RuleViolation | None: + def evaluate(self, model: Model) -> RuleViolation | None: # type: ignore[override] """Evaluate model.""" return None def test_missing_evaluate_rule_class(model1): """Test missing evaluate implementation in rule class.""" + with pytest.raises(TypeError): - class BadRule(Rule): - """Bad example rule.""" + class BadRule(Rule): + """Bad example rule.""" + + description = "Description of the rule." 
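The parametrized test added below relies on each rule knowing which resource type its `evaluate` method accepts. As a rough sketch of that idea only, not dbt-score's actual implementation, the target type can be read back from the annotations of `evaluate`; the rule class here is hypothetical and exists only for this illustration:

```python
import typing

from dbt_score import Model, Rule, RuleViolation


class IllustrativeRule(Rule):
    """Hypothetical rule, defined only for this illustration."""

    description = "Hypothetical rule."

    def evaluate(self, model: Model) -> RuleViolation | None:  # type: ignore[override]
        """Always passes."""
        return None


hints = typing.get_type_hints(IllustrativeRule.evaluate)
target = next(hint for name, hint in hints.items() if name != "return")
print(target is Model)  # True; the test below checks the same mechanism
                        # through the rule's `resource_type` attribute.
```
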
+ + +@pytest.mark.parametrize( + "rule_fixture,expected_type", + [ + ("decorator_rule", Model), + ("decorator_rule_no_parens", Model), + ("decorator_rule_args", Model), + ("class_rule", Model), + ("decorator_rule_source", Source), + ("decorator_rule_no_parens_source", Source), + ("decorator_rule_args_source", Source), + ("class_rule_source", Source), + ], +) +def test_rule_introspects_its_resource_type(request, rule_fixture, expected_type): + """Test that each rule is aware of the resource-type it is evaluated against.""" + rule = request.getfixturevalue(rule_fixture) + assert rule().resource_type is expected_type + + +class TestRuleFilterValidation: + """Tests that a rule filter matches resource-type to the rule it's attached to.""" + + @pytest.fixture + def source_filter_no_parens(self): + """Example source filter with bare decorator.""" + + @rule_filter + def source_filter(source: Source) -> bool: + """Description.""" + return False + + return source_filter() + + @pytest.fixture + def source_filter_parens(self): + """Example source filter with decorator and parens.""" + + @rule_filter() + def source_filter(source: Source) -> bool: + """Description.""" + return False + + return source_filter() + + @pytest.fixture + def source_filter_class(self): + """Example class-based source filter.""" + + class SourceFilter(RuleFilter): + description = "Description" + + def evaluate(self, source: Source) -> bool: # type: ignore[override] + return False + + return SourceFilter + + @pytest.mark.parametrize( + "rule_filter_fixture", + ["source_filter_no_parens", "source_filter_parens", "source_filter_class"], + ) + def test_rule_filter_must_match_resource_type_as_rule( + self, request, rule_filter_fixture + ): + """Tests that rules can't be created with filters of incorrect resource-type.""" + rule_filter = request.getfixturevalue(rule_filter_fixture) + + with pytest.raises(TypeError) as excinfo: + + @rule(rule_filters={rule_filter}) + def model_always_passes(model: Model) -> RuleViolation | None: + """Description.""" + pass + + assert "Mismatched resource_type on filter" in str(excinfo.value) + assert "Expected Model, but got Source" in str(excinfo.value) + + with pytest.raises(TypeError): - description = "Description of the rule." + class ModelAlwaysPasses(Rule): + description = "Description." + rule_filters = frozenset([rule_filter]) - rule = BadRule() + def evaluate(self, model: Model) -> RuleViolation | None: # type: ignore[override] + pass - with pytest.raises(NotImplementedError): - rule.evaluate(model1) + assert "Mismatched resource_type on filter" in str(excinfo.value) + assert "Expected Model, but got Source" in str(excinfo.value) diff --git a/tests/test_rule_filter.py b/tests/test_rule_filter.py new file mode 100644 index 0000000..ae6909e --- /dev/null +++ b/tests/test_rule_filter.py @@ -0,0 +1,93 @@ +"""Test rule filters.""" +import pytest +from dbt_score.models import Model, Source +from dbt_score.rule_filter import RuleFilter, rule_filter + + +def test_basic_filter(model1, model2): + """Test basic filter testing for a specific model.""" + + @rule_filter + def only_model1(model: Model) -> bool: + """Some description.""" + return model.name == "model1" + + instance = only_model1() # since the decorator returns a Type + assert instance.description == "Some description." 
+ assert instance.evaluate(model1) + assert not instance.evaluate(model2) + + +def test_basic_filter_with_sources(source1, source2): + """Test basic filter testing for a specific source.""" + + @rule_filter + def only_source1(source: Source) -> bool: + """Some description.""" + return source.name == "table1" + + instance = only_source1() # since the decorator returns a Type + assert instance.description == "Some description." + assert instance.evaluate(source1) + assert not instance.evaluate(source2) + + +def test_class_filter(model1, model2): + """Test basic filter using class.""" + + class OnlyModel1(RuleFilter): + description = "Some description." + + def evaluate(self, model: Model) -> bool: # type: ignore[override] + return model.name == "model1" + + instance = OnlyModel1() + assert instance.description == "Some description." + assert instance.evaluate(model1) + assert not instance.evaluate(model2) + + +def test_class_filter_with_sources(source1, source2): + """Test basic filter using class.""" + + class OnlySource1(RuleFilter): + description = "Some description." + + def evaluate(self, source: Source) -> bool: # type: ignore[override] + return source.name == "table1" + + instance = OnlySource1() + assert instance.description == "Some description." + assert instance.evaluate(source1) + assert not instance.evaluate(source2) + + +def test_missing_description_rule_filter(): + """Test missing description in filter decorator.""" + with pytest.raises(AttributeError): + + @rule_filter() + def example_filter(model: Model) -> bool: + return True + + +def test_missing_description_rule_class(): + """Test missing description in filter class.""" + with pytest.raises(AttributeError): + + class BadFilter(RuleFilter): + """Bad example filter.""" + + def evaluate(self, model: Model) -> bool: # type: ignore[override] + """Evaluate filter.""" + return True + + +def test_missing_evaluate_rule_class(model1): + """Test missing evaluate implementation in filter class.""" + with pytest.raises(TypeError): + + class BadFilter(RuleFilter): + """Bad example filter.""" + + description = "Description of the rule." 
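
Taken together, the new `tests/test_rule_filter.py` cases above reduce to a small usage pattern. The following sketch mirrors the `tests/conftest.py` fixtures in this patch; the filter and rule names are illustrative only and are not part of the package:

```python
from dbt_score import RuleViolation, Source, rule, rule_filter


@rule_filter
def skip_table1(source: Source) -> bool:
    """Evaluate rules only for sources that are not named table1."""
    return source.name != "table1"


@rule(rule_filters={skip_table1()})
def example_source_rule(source: Source) -> RuleViolation | None:
    """Rule that always fails unless filtered out."""
    return RuleViolation(message="I always fail.")
```

The filter is instantiated (`skip_table1()`) when attached, since the decorator returns a `RuleFilter` subclass rather than an instance, and its `Source` annotation has to match the rule's own resource type, which is exactly what `TestRuleFilterValidation` asserts.
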
diff --git a/tests/test_rule_registry.py b/tests/test_rule_registry.py index 4e7abb2..343d98b 100644 --- a/tests/test_rule_registry.py +++ b/tests/test_rule_registry.py @@ -15,7 +15,7 @@ def test_rule_registry_discovery(default_config): "tests.rules.example.rule_test_example", "tests.rules.nested.example.rule_test_nested_example", ] - assert list(r._model_filters.keys()) == ["tests.rules.example.skip_model1"] + assert list(r._rule_filters.keys()) == ["tests.rules.example.skip_model1"] def test_disabled_rule_registry_discovery(): @@ -55,7 +55,7 @@ def test_rule_registry_core_rules(default_config): assert len(r.rules) > 0 -def test_rule_registry_model_filters(valid_config_path, model1, model2): +def test_rule_registry_rule_filters(valid_config_path, model1, model2): """Test config filters are loaded.""" config = Config() config._load_toml_file(str(valid_config_path)) diff --git a/tests/test_scoring.py b/tests/test_scoring.py index e47a493..356992c 100644 --- a/tests/test_scoring.py +++ b/tests/test_scoring.py @@ -8,16 +8,18 @@ def test_scorer_model_no_results(default_config): """Test scorer with a model without any result.""" scorer = Scorer(config=default_config) - assert scorer.score_model({}).value == 10.0 + assert scorer.score_evaluable({}).value == 10.0 def test_scorer_model_severity_low(default_config, rule_severity_low): """Test scorer with a model and one low severity rule.""" scorer = Scorer(config=default_config) - assert scorer.score_model({rule_severity_low: None}).value == 10.0 - assert scorer.score_model({rule_severity_low: Exception()}).value == 10.0 + assert scorer.score_evaluable({rule_severity_low: None}).value == 10.0 + assert scorer.score_evaluable({rule_severity_low: Exception()}).value == 10.0 assert ( - round(scorer.score_model({rule_severity_low: RuleViolation("error")}).value, 2) + round( + scorer.score_evaluable({rule_severity_low: RuleViolation("error")}).value, 2 + ) == 6.67 ) @@ -25,11 +27,14 @@ def test_scorer_model_severity_low(default_config, rule_severity_low): def test_scorer_model_severity_medium(default_config, rule_severity_medium): """Test scorer with a model and one medium severity rule.""" scorer = Scorer(config=default_config) - assert scorer.score_model({rule_severity_medium: None}).value == 10.0 - assert scorer.score_model({rule_severity_medium: Exception()}).value == 10.0 + assert scorer.score_evaluable({rule_severity_medium: None}).value == 10.0 + assert scorer.score_evaluable({rule_severity_medium: Exception()}).value == 10.0 assert ( round( - scorer.score_model({rule_severity_medium: RuleViolation("error")}).value, 2 + scorer.score_evaluable( + {rule_severity_medium: RuleViolation("error")} + ).value, + 2, ) == 3.33 ) @@ -38,18 +43,21 @@ def test_scorer_model_severity_medium(default_config, rule_severity_medium): def test_scorer_model_severity_high(default_config, rule_severity_high): """Test scorer with a model and one high severity rule.""" scorer = Scorer(config=default_config) - assert scorer.score_model({rule_severity_high: None}).value == 10.0 - assert scorer.score_model({rule_severity_high: Exception()}).value == 10.0 - assert scorer.score_model({rule_severity_high: RuleViolation("error")}).value == 0.0 + assert scorer.score_evaluable({rule_severity_high: None}).value == 10.0 + assert scorer.score_evaluable({rule_severity_high: Exception()}).value == 10.0 + assert ( + scorer.score_evaluable({rule_severity_high: RuleViolation("error")}).value + == 0.0 + ) def test_scorer_model_severity_critical(default_config, rule_severity_critical): 
"""Test scorer with a model and one critical severity rule.""" scorer = Scorer(config=default_config) - assert scorer.score_model({rule_severity_critical: None}).value == 10.0 - assert scorer.score_model({rule_severity_critical: Exception()}).value == 10.0 + assert scorer.score_evaluable({rule_severity_critical: None}).value == 10.0 + assert scorer.score_evaluable({rule_severity_critical: Exception()}).value == 10.0 assert ( - scorer.score_model({rule_severity_critical: RuleViolation("error")}).value + scorer.score_evaluable({rule_severity_critical: RuleViolation("error")}).value == 0.0 ) @@ -60,7 +68,7 @@ def test_scorer_model_severity_critical_overwrites( """Test scorer with a model and multiple rules including one critical.""" scorer = Scorer(config=default_config) assert ( - scorer.score_model( + scorer.score_evaluable( {rule_severity_low: None, rule_severity_critical: RuleViolation("error")} ).value == 0.0 @@ -74,7 +82,7 @@ def test_scorer_model_multiple_rules( scorer = Scorer(config=default_config) assert ( round( - scorer.score_model( + scorer.score_evaluable( { rule_severity_low: None, rule_severity_medium: Exception(), @@ -88,7 +96,7 @@ def test_scorer_model_multiple_rules( assert ( round( - scorer.score_model( + scorer.score_evaluable( { rule_severity_low: Exception(), rule_severity_medium: RuleViolation("error"), @@ -102,7 +110,7 @@ def test_scorer_model_multiple_rules( assert ( round( - scorer.score_model( + scorer.score_evaluable( { rule_severity_low: RuleViolation("error"), rule_severity_medium: Exception(), @@ -118,39 +126,39 @@ def test_scorer_model_multiple_rules( def test_scorer_aggregate_empty(default_config): """Test scorer aggregation with no results.""" scorer = Scorer(config=default_config) - assert scorer.score_aggregate_models([]).value == 10.0 + assert scorer.score_aggregate_evaluables([]).value == 10.0 def test_scorer_aggregate_with_0(default_config): """Test scorer aggregation with one result that is 0.0.""" scorer = Scorer(config=default_config) scores = [Score(1.0, ""), Score(5.0, ""), Score(0.0, "")] - assert scorer.score_aggregate_models(scores).value == 0.0 + assert scorer.score_aggregate_evaluables(scores).value == 0.0 def test_scorer_aggregate_single(default_config): """Test scorer aggregation with a single results.""" scorer = Scorer(config=default_config) - assert scorer.score_aggregate_models([Score(4.2, "")]).value == 4.2 + assert scorer.score_aggregate_evaluables([Score(4.2, "")]).value == 4.2 def test_scorer_aggregate_multiple(default_config): """Test scorer aggregation with multiple results.""" scorer = Scorer(config=default_config) assert ( - scorer.score_aggregate_models( + scorer.score_aggregate_evaluables( [Score(1.0, ""), Score(1.0, ""), Score(1.0, "")] ).value == 1.0 ) assert ( - scorer.score_aggregate_models( + scorer.score_aggregate_evaluables( [Score(1.0, ""), Score(7.4, ""), Score(4.2, "")] ).value == 4.2 ) assert ( - scorer.score_aggregate_models( + scorer.score_aggregate_evaluables( [Score(0.0, ""), Score(0.0, ""), Score(0.0, "")] ).value == 0.0