Skip to content

Commit

Permalink
Add more parameters and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jochemvandooren committed Mar 18, 2024
1 parent d1ddd4b commit 5e4a4bc
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 54 deletions.
107 changes: 74 additions & 33 deletions src/dbt_score/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ class Test:
type: The type of the test, e.g. `unique`.
kwargs: The kwargs of the test.
tags: The list of tags attached to the test.
_raw_values: The raw values of the test in the manifest.
"""

name: str
type: str
kwargs: dict[str, Any] = field(default_factory=dict)
tags: list[str] = field(default_factory=list)
_raw_values: dict[str, Any] = field(default_factory=dict)

@classmethod
def from_node(cls, test_node: dict[str, Any]) -> "Test":
Expand All @@ -45,6 +47,7 @@ def from_node(cls, test_node: dict[str, Any]) -> "Test":
kwargs=test_node["test_metadata"].get("kwargs", {}),
tags=test_node.get("tags", []),
)
test._raw_values = test_node
return test


Expand All @@ -55,16 +58,51 @@ class Column:
Attributes:
name: The name of the column.
description: The description of the column.
data_type: The data type of the column.
meta: The metadata attached to the column.
constraints: The list of constraints attached to the column.
tags: The list of tags attached to the column.
tests: The list of tests attached to the column.
_raw_values: The raw values of the column as defined in the node.
_raw_test_values: The raw test values of the column as defined in the node.
"""

name: str
description: str
data_type: str | None = None
meta: dict[str, Any] = field(default_factory=dict)
constraints: list[Constraint] = field(default_factory=list)
tags: list[str] = field(default_factory=list)
tests: list[Test] = field(default_factory=list)
_raw_values: dict[str, Any] = field(default_factory=dict)
_raw_test_values: list[dict[str, Any]] = field(default_factory=list)

@classmethod
def from_node_values(
cls, values: dict[str, Any], test_values: list[dict[str, Any]]
) -> "Column":
"""Create a column object from raw values."""
column = cls(
name=values["name"],
description=values["description"],
data_type=values["data_type"],
meta=values["meta"],
constraints=[
Constraint(
name=constraint["name"],
type=constraint["type"],
expression=constraint["expression"],
)
for constraint in values["constraints"]
],
tags=values["tags"],
tests=[Test.from_node(test) for test in test_values],
)

column._raw_values = values
column._raw_test_values = test_values

return column


@dataclass
Expand All @@ -74,34 +112,44 @@ class Model:
Attributes:
unique_id: The id of the model, e.g. `model.package.model_name`.
name: The name of the model.
relation_name: The relation name of the model, e.g. `db.schema.model_name`.
description: The full description of the model.
patch_path: The yml path of the model, e.g. `package://model_dir/dir/file.yml`.
original_file_path: The sql path of the model, `e.g. model_dir/dir/file.sql`.
config: The config of the model.
meta: The meta of the model.
columns: The list of columns of the model.
package_name: The package name of the model.
database: The database name of the model.
schema: The schema name of the model.
raw_code: The raw code of the model.
alias: The alias of the model.
patch_path: The yml path of the model, e.g. `package://model_dir/dir/file.yml`.
tags: The list of tags attached to the model.
tests: The list of tests attached to the model.
depends_on: Dictionary of models/sources/macros that the model depends on.
_node_values: The raw values of the model in the manifest.
_test_values: The raw test values of the model in the manifest.
"""

unique_id: str
name: str
relation_name: str
description: str
patch_path: str
original_file_path: str
config: dict[str, Any]
meta: dict[str, Any]
columns: list[Column]
package_name: str
database: str
schema: str
raw_code: str
alias: str | None = None
patch_path: str | None = None
tags: list[str] = field(default_factory=list)
tests: list[Test] = field(default_factory=list)
depends_on: dict[str, list[str]] = field(default_factory=dict)
_node_values: dict[str, Any] = field(default_factory=dict)
_test_values: list[dict[str, Any]] = field(default_factory=list)

def get_column(self, column_name: str) -> Column | None:
"""Get a column by name."""
Expand All @@ -113,59 +161,52 @@ def get_column(self, column_name: str) -> Column | None:

@staticmethod
def _get_columns(
node_values: dict[str, Any], tests_values: list[dict[str, Any]]
node_values: dict[str, Any], test_values: list[dict[str, Any]]
) -> list[Column]:
"""Get columns from a node and it's tests in the manifest."""
columns = [
Column(
name=values.get("name"),
description=values.get("description"),
constraints=[
Constraint(
name=constraint.get("name"),
type=constraint.get("type"),
expression=constraint.get("expression"),
)
for constraint in values.get("constraints", [])
],
tags=values.get("tags", []),
tests=[
Test.from_node(test)
for test in tests_values
if test["test_metadata"].get("kwargs", {}).get("column_name")
== values.get("name")
"""Get columns from a node and its tests in the manifest."""
return [
Column.from_node_values(
values,
[
test
for test in test_values
if test["test_metadata"]["kwargs"].get("column_name") == name
],
)
for name, values in node_values.get("columns", {}).items()
]
return columns

@classmethod
def from_node(
cls, node_values: dict[str, Any], tests_values: list[dict[str, Any]]
cls, node_values: dict[str, Any], test_values: list[dict[str, Any]]
) -> "Model":
"""Create a model object from a node and it's tests in the manifest."""
model = cls(
unique_id=node_values["unique_id"],
name=node_values["name"],
description=node_values.get("description", ""),
patch_path=node_values["patch_path"],
relation_name=node_values["relation_name"],
description=node_values["description"],
original_file_path=node_values["original_file_path"],
config=node_values.get("config", {}),
meta=node_values.get("meta", {}),
columns=cls._get_columns(node_values, tests_values),
config=node_values["config"],
meta=node_values["meta"],
columns=cls._get_columns(node_values, test_values),
package_name=node_values["package_name"],
database=node_values["database"],
schema=node_values["schema"],
tags=node_values.get("tags", []),
alias=node_values["alias"],
patch_path=node_values["patch_path"],
tags=node_values["tags"],
tests=[
Test.from_node(test)
for test in tests_values
if not test["test_metadata"].get("kwargs", {}).get("column_name")
for test in test_values
if not test["test_metadata"]["kwargs"].get("column_name")
],
depends_on=node_values.get("depends_on", {}),
depends_on=node_values["depends_on"],
)

model._node_values = node_values
model._test_values = test_values

return model


Expand Down
15 changes: 5 additions & 10 deletions src/dbt_score/rule.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Rule definitions."""

import functools
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Type
from typing import Callable, Type

from dbt_score.models import Model

Expand Down Expand Up @@ -36,10 +35,9 @@ def __init_subclass__(cls, **kwargs) -> None: # type: ignore
if not hasattr(cls, "description"):
raise TypeError("Subclass must define class attribute `description`.")

@classmethod
def evaluate(cls, model: Model) -> RuleViolation | None:
def evaluate(self, model: Model) -> RuleViolation | None:
"""Evaluates the rule."""
raise NotImplementedError("Subclass must implement class method `evaluate`.")
raise NotImplementedError("Subclass must implement method `evaluate`.")


def rule(
Expand All @@ -58,10 +56,7 @@ def rule(
def decorator_rule(
func: Callable[[Model], RuleViolation | None],
) -> Type[Rule]:
@functools.wraps(func)
def wrapper_rule(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

"""Decorator function."""
if func.__doc__ is None and description is None:
raise TypeError("Rule must define `description` or `func.__doc__`.")

Expand All @@ -77,7 +72,7 @@ def wrapper_rule(*args: Any, **kwargs: Any) -> Any:
{
"description": rule_description,
"severity": severity,
"evaluate": wrapper_rule,
"evaluate": func,
},
)

Expand Down
1 change: 1 addition & 0 deletions src/dbt_score/rules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Rules."""
20 changes: 9 additions & 11 deletions src/dbt_score/rules/example_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ class ComplexRule(Rule):

description = "Example of a complex rule."

@classmethod
def preprocess(cls) -> int:
def preprocess(self) -> int:
"""Preprocessing."""
return 1
return len(self.description)

@classmethod
def evaluate(cls, model: Model) -> RuleViolation | None:
def evaluate(self, model: Model) -> RuleViolation | None:
"""Evaluate model."""
x = cls.preprocess()
x = self.preprocess()

if x:
return RuleViolation(str(x))
Expand All @@ -26,7 +24,7 @@ def evaluate(cls, model: Model) -> RuleViolation | None:


@rule()
def has_owner(model: Model) -> RuleViolation | None:
def has_owner(self, model: Model) -> RuleViolation | None:
"""A model should have an owner defined."""
if "owner" not in model.meta:
return RuleViolation("Define the owner of the model in the meta section.")
Expand All @@ -35,7 +33,7 @@ def has_owner(model: Model) -> RuleViolation | None:


@rule()
def has_primary_key(model: Model) -> RuleViolation | None:
def has_primary_key(self, model: Model) -> RuleViolation | None:
"""A model should have a primary key defined, unless it's a view."""
if not model.config.get("materialized") == "picnic_view":
has_pk = False
Expand All @@ -51,7 +49,7 @@ def has_primary_key(model: Model) -> RuleViolation | None:


@rule()
def primary_key_has_uniqueness_test(model: Model) -> RuleViolation | None:
def primary_key_has_uniqueness_test(self, model: Model) -> RuleViolation | None:
"""Primary key columns should have a uniqueness test defined."""
columns_with_pk = []
if model.config.get("materialized") == "view":
Expand All @@ -68,7 +66,7 @@ def primary_key_has_uniqueness_test(model: Model) -> RuleViolation | None:


@rule()
def columns_have_description(model: Model) -> RuleViolation | None:
def columns_have_description(self, model: Model) -> RuleViolation | None:
"""All columns of a model should have a description."""
invalid_column_names = [
column.name for column in model.columns if not column.description
Expand All @@ -83,7 +81,7 @@ def columns_have_description(model: Model) -> RuleViolation | None:


@rule(description="A model should have at least one test defined.")
def has_test(model: Model) -> RuleViolation | None:
def has_test(self, model: Model) -> RuleViolation | None:
"""A model should have at least one model-level or column-level test defined.
This does not include singular tests, which are tests defined in a separate .sql
Expand Down

0 comments on commit 5e4a4bc

Please sign in to comment.