diff --git a/.changes/unreleased/Features-20230906-234741.yaml b/.changes/unreleased/Features-20230906-234741.yaml new file mode 100644 index 00000000000..ca94f1fc6c5 --- /dev/null +++ b/.changes/unreleased/Features-20230906-234741.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support config with tags & meta for unit tests +time: 2023-09-06T23:47:41.059915-04:00 +custom: + Author: michelleark + Issue: "8294" diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index ecaa9427603..2de38b51d56 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -425,6 +425,7 @@ def create_project(self, rendered: RenderComponents) -> "Project": snapshots: Dict[str, Any] sources: Dict[str, Any] tests: Dict[str, Any] + unit_tests: Dict[str, Any] metrics: Dict[str, Any] semantic_models: Dict[str, Any] exposures: Dict[str, Any] @@ -436,6 +437,7 @@ def create_project(self, rendered: RenderComponents) -> "Project": snapshots = cfg.snapshots sources = cfg.sources tests = cfg.tests + unit_tests = cfg.unit_tests metrics = cfg.metrics semantic_models = cfg.semantic_models exposures = cfg.exposures @@ -493,6 +495,7 @@ def create_project(self, rendered: RenderComponents) -> "Project": query_comment=query_comment, sources=sources, tests=tests, + unit_tests=unit_tests, metrics=metrics, semantic_models=semantic_models, exposures=exposures, @@ -600,6 +603,7 @@ class Project: snapshots: Dict[str, Any] sources: Dict[str, Any] tests: Dict[str, Any] + unit_tests: Dict[str, Any] metrics: Dict[str, Any] semantic_models: Dict[str, Any] exposures: Dict[str, Any] @@ -676,6 +680,7 @@ def to_project_config(self, with_packages=False): "snapshots": self.snapshots, "sources": self.sources, "tests": self.tests, + "unit_tests": self.unit_tests, "metrics": self.metrics, "semantic-models": self.semantic_models, "exposures": self.exposures, diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index 3156aa31878..0226fe90d47 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -166,6 +166,7 @@ def from_parts( query_comment=project.query_comment, sources=project.sources, tests=project.tests, + unit_tests=project.unit_tests, metrics=project.metrics, semantic_models=project.semantic_models, exposures=project.exposures, @@ -322,6 +323,7 @@ def get_resource_config_paths(self) -> Dict[str, PathSet]: "snapshots": self._get_config_paths(self.snapshots), "sources": self._get_config_paths(self.sources), "tests": self._get_config_paths(self.tests), + "unit_tests": self._get_config_paths(self.unit_tests), "metrics": self._get_config_paths(self.metrics), "semantic_models": self._get_config_paths(self.semantic_models), "exposures": self._get_config_paths(self.exposures), diff --git a/core/dbt/context/context_config.py b/core/dbt/context/context_config.py index f5c32ff0f81..f766c2ce9dd 100644 --- a/core/dbt/context/context_config.py +++ b/core/dbt/context/context_config.py @@ -49,6 +49,8 @@ def get_config_dict(self, resource_type: NodeType) -> Dict[str, Any]: model_configs = unrendered.get("semantic_models") elif resource_type == NodeType.Exposure: model_configs = unrendered.get("exposures") + elif resource_type == NodeType.Unit: + model_configs = unrendered.get("unit_tests") else: model_configs = unrendered.get("models") if model_configs is None: @@ -76,6 +78,8 @@ def get_config_dict(self, resource_type: NodeType) -> Dict[str, Any]: model_configs = self.project.semantic_models elif resource_type == NodeType.Exposure: model_configs = self.project.exposures + elif resource_type == NodeType.Unit: + model_configs = self.project.unit_tests else: model_configs = self.project.models return model_configs diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 9304e0496f5..b2d10c8d32e 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -1504,7 +1504,7 @@ def defer_relation(self) -> Optional[RelationProxy]: class UnitTestContext(ModelContext): model: UnitTestNode - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: """The env_var() function. Return the overriden unit test environment variable named 'var'. diff --git a/core/dbt/contracts/graph/model_config.py b/core/dbt/contracts/graph/model_config.py index 5e2d5c2020d..b8a1e5f0aa4 100644 --- a/core/dbt/contracts/graph/model_config.py +++ b/core/dbt/contracts/graph/model_config.py @@ -640,6 +640,18 @@ def finalize_and_validate(self): return self.from_dict(data) +@dataclass +class UnitTestConfig(BaseConfig): + tags: Union[str, List[str]] = field( + default_factory=list_str, + metadata=metas(ShowBehavior.Hide, MergeBehavior.Append, CompareBehavior.Exclude), + ) + meta: Dict[str, Any] = field( + default_factory=dict, + metadata=MergeBehavior.Update.meta(), + ) + + RESOURCE_TYPES: Dict[NodeType, Type[BaseConfig]] = { NodeType.Metric: MetricConfig, NodeType.SemanticModel: SemanticModelConfig, @@ -650,6 +662,7 @@ def finalize_and_validate(self): NodeType.Unit: TestConfig, NodeType.Model: NodeConfig, NodeType.Snapshot: SnapshotConfig, + NodeType.Unit: UnitTestConfig, } diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 3eb3015677d..8023ceb3e5f 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -40,7 +40,10 @@ from dbt.contracts.graph.node_args import ModelNodeArgs from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin from dbt.events.functions import warn_or_error -from dbt.exceptions import ParsingError, ContractBreakingChangeError +from dbt.exceptions import ( + ParsingError, + ContractBreakingChangeError, +) from dbt.events.types import ( SeedIncreased, SeedExceedsLimitSamePath, @@ -73,6 +76,7 @@ EmptySnapshotConfig, SnapshotConfig, SemanticModelConfig, + UnitTestConfig, ) @@ -1062,17 +1066,22 @@ class UnitTestNode(CompiledNode): @dataclass class UnitTestDefinition(GraphNode): model: str - attached_node: str given: Sequence[InputFixture] expect: List[Dict[str, Any]] description: str = "" overrides: Optional[UnitTestOverrides] = None depends_on: DependsOn = field(default_factory=DependsOn) + config: UnitTestConfig = field(default_factory=UnitTestConfig) @property def depends_on_nodes(self): return self.depends_on.nodes + @property + def tags(self) -> List[str]: + tags = self.config.tags + return [tags] if isinstance(tags, str) else tags + # ==================================== # Snapshot node @@ -1699,10 +1708,6 @@ def primary_entity_reference(self) -> Optional[EntityReference]: else None ) - @property - def group(self): - return None - # ==================================== # Patches diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index f6024901a22..1ff359a7a7c 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -755,6 +755,7 @@ class UnparsedUnitTestDefinition(dbtClassMixin): expect: List[Dict[str, Any]] description: str = "" overrides: Optional[UnitTestOverrides] = None + config: Dict[str, Any] = field(default_factory=dict) @dataclass diff --git a/core/dbt/contracts/project.py b/core/dbt/contracts/project.py index 9e09fd56692..568aaabed9d 100644 --- a/core/dbt/contracts/project.py +++ b/core/dbt/contracts/project.py @@ -213,6 +213,7 @@ class Project(dbtClassMixin, Replaceable): analyses: Dict[str, Any] = field(default_factory=dict) sources: Dict[str, Any] = field(default_factory=dict) tests: Dict[str, Any] = field(default_factory=dict) + unit_tests: Dict[str, Any] = field(default_factory=dict) metrics: Dict[str, Any] = field(default_factory=dict) semantic_models: Dict[str, Any] = field(default_factory=dict) exposures: Dict[str, Any] = field(default_factory=dict) diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index 49bc6e2f9a4..b011c3f32ba 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -1,15 +1,23 @@ +from typing import List, Set, Dict, Any + +from dbt.config import RuntimeConfig +from dbt.context.context_config import ContextConfig +from dbt.context.providers import generate_parse_exposure, get_rendered +from dbt.contracts.files import FileHash +from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.model_config import NodeConfig -from dbt_extractor import py_extract_from_source # type: ignore -from dbt.contracts.graph.unparsed import UnparsedUnitTestSuite from dbt.contracts.graph.nodes import ( ModelNode, UnitTestNode, RefArgs, UnitTestDefinition, DependsOn, + UnitTestConfig, ) -from dbt.config import RuntimeConfig -from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.unparsed import UnparsedUnitTestSuite +from dbt.exceptions import ParsingError +from dbt.graph import UniqueId +from dbt.node_types import NodeType from dbt.parser.schemas import ( SchemaParser, YamlBlock, @@ -17,23 +25,10 @@ JSONValidationError, YamlParseDictError, YamlReader, + ParseResult, ) -from dbt.node_types import NodeType - -from dbt.exceptions import ( - ParsingError, -) - -from dbt.contracts.files import FileHash -from dbt.graph import UniqueId - -from dbt.context.providers import generate_parse_exposure, get_rendered -from typing import List, Set from dbt.utils import get_pseudo_test_path - - -def _is_model_node(node_id, manifest): - return manifest.nodes[node_id].resource_type == NodeType.Model +from dbt_extractor import py_extract_from_source # type: ignore class UnitTestManifestLoader: @@ -176,43 +171,77 @@ def _get_original_input_node(self, input: str): class UnitTestParser(YamlReader): - def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock): + def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None: super().__init__(schema_parser, yaml, "unit") self.schema_parser = schema_parser self.yaml = yaml - def parse(self): + def parse(self) -> ParseResult: for data in self.get_key_dicts(): - try: - UnparsedUnitTestSuite.validate(data) - unparsed = UnparsedUnitTestSuite.from_dict(data) - except (ValidationError, JSONValidationError) as exc: - raise YamlParseDictError(self.yaml.path, self.key, data, exc) - package_name = self.project.project_name - - actual_node = self.manifest.ref_lookup.perform_lookup( - f"model.{package_name}.{unparsed.model}", self.manifest - ) - if not actual_node: - raise ParsingError( - "Unable to find model {unparsed.model} for unit tests in {self.yaml.path.original_file_path}" + unit_test_suite = self._get_unit_test_suite(data) + model_name_split = unit_test_suite.model.split() + tested_model_node = self._find_tested_model_node(unit_test_suite) + + for test in unit_test_suite.tests: + unit_test_case_unique_id = ( + f"unit.{self.project.project_name}.{unit_test_suite.model}.{test.name}" ) - for test in unparsed.tests: - unit_test_case_unique_id = f"unit.{package_name}.{test.name}.{unparsed.model}" - unit_test_case = UnitTestDefinition( + unit_test_fqn = [self.project.project_name] + model_name_split + [test.name] + unit_test_config = self._build_unit_test_config(unit_test_fqn, test.config) + + unit_test_definition = UnitTestDefinition( name=test.name, - model=unparsed.model, + model=unit_test_suite.model, resource_type=NodeType.Unit, - package_name=package_name, + package_name=self.project.project_name, path=self.yaml.path.relative_path, original_file_path=self.yaml.path.original_file_path, unique_id=unit_test_case_unique_id, - attached_node=actual_node.unique_id, given=test.given, expect=test.expect, description=test.description, overrides=test.overrides, - depends_on=DependsOn(nodes=[actual_node.unique_id]), - fqn=[package_name, test.name], + depends_on=DependsOn(nodes=[tested_model_node.unique_id]), + fqn=unit_test_fqn, + config=unit_test_config, ) - self.manifest.add_unit_test(self.yaml.file, unit_test_case) + self.manifest.add_unit_test(self.yaml.file, unit_test_definition) + + return ParseResult() + + def _get_unit_test_suite(self, data: Dict[str, Any]) -> UnparsedUnitTestSuite: + try: + UnparsedUnitTestSuite.validate(data) + return UnparsedUnitTestSuite.from_dict(data) + except (ValidationError, JSONValidationError) as exc: + raise YamlParseDictError(self.yaml.path, self.key, data, exc) + + def _find_tested_model_node(self, unit_test_suite: UnparsedUnitTestSuite) -> ModelNode: + package_name = self.project.project_name + model_name_split = unit_test_suite.model.split() + model_name = model_name_split[0] + model_version = model_name_split[1] if len(model_name_split) == 2 else None + + tested_node = self.manifest.ref_lookup.find( + model_name, package_name, model_version, self.manifest + ) + if not tested_node: + raise ParsingError( + f"Unable to find model '{package_name}.{unit_test_suite.model}' for unit tests in {self.yaml.path.original_file_path}" + ) + + return tested_node + + def _build_unit_test_config( + self, unit_test_fqn: List[str], config_dict: Dict[str, Any] + ) -> UnitTestConfig: + config = ContextConfig( + self.schema_parser.root_project, + unit_test_fqn, + NodeType.Unit, + self.schema_parser.project.project_name, + ) + unit_test_config_dict = config.build_config_dict(patch_config_dict=config_dict) + unit_test_config_dict = self.render_entry(unit_test_config_dict) + + return UnitTestConfig.from_dict(unit_test_config_dict) diff --git a/schemas/dbt/manifest/v11.json b/schemas/dbt/manifest/v11.json index 25fe893a5fa..78e15252292 100644 --- a/schemas/dbt/manifest/v11.json +++ b/schemas/dbt/manifest/v11.json @@ -5756,6 +5756,38 @@ "input" ] }, + "UnitTestConfig": { + "type": "object", + "title": "UnitTestConfig", + "properties": { + "_extra": { + "type": "object", + "propertyNames": { + "type": "string" + } + }, + "tags": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ] + }, + "meta": { + "type": "object", + "propertyNames": { + "type": "string" + } + } + }, + "additionalProperties": true + }, "UnitTestDefinition": { "type": "object", "title": "UnitTestDefinition", @@ -5804,9 +5836,6 @@ "model": { "type": "string" }, - "attached_node": { - "type": "string" - }, "given": { "type": "array", "items": { @@ -5839,6 +5868,9 @@ }, "depends_on": { "$ref": "#/$defs/DependsOn" + }, + "config": { + "$ref": "#/$defs/UnitTestConfig" } }, "additionalProperties": false, @@ -5851,7 +5883,6 @@ "unique_id", "fqn", "model", - "attached_node", "given", "expect" ] diff --git a/tests/unit/test_parser.py b/tests/unit/test_parser.py index 64bed3825ce..e2ecf3ad4d6 100644 --- a/tests/unit/test_parser.py +++ b/tests/unit/test_parser.py @@ -176,13 +176,14 @@ def file_block_for(self, data: str, filename: str, searched: str): return FileBlock(file=source_file) def assert_has_manifest_lengths( - self, manifest, macros=3, nodes=0, sources=0, docs=0, disabled=0 + self, manifest, macros=3, nodes=0, sources=0, docs=0, disabled=0, unit_tests=0 ): self.assertEqual(len(manifest.macros), macros) self.assertEqual(len(manifest.nodes), nodes) self.assertEqual(len(manifest.sources), sources) self.assertEqual(len(manifest.docs), docs) self.assertEqual(len(manifest.disabled), disabled) + self.assertEqual(len(manifest.unit_tests), unit_tests) def assertEqualNodes(node_one, node_two): @@ -371,8 +372,8 @@ def setUp(self): manifest=self.manifest, ) - def file_block_for(self, data, filename): - return super().file_block_for(data, filename, "models") + def file_block_for(self, data, filename, searched="models"): + return super().file_block_for(data, filename, searched) def yaml_block_for(self, test_yml: str, filename: str): file_block = self.file_block_for(data=test_yml, filename=filename) diff --git a/tests/unit/test_unit_test_parser.py b/tests/unit/test_unit_test_parser.py new file mode 100644 index 00000000000..31d98c18b8e --- /dev/null +++ b/tests/unit/test_unit_test_parser.py @@ -0,0 +1,171 @@ +from dbt.contracts.graph.nodes import UnitTestDefinition, UnitTestConfig, DependsOn, NodeType +from dbt.exceptions import ParsingError +from dbt.parser import SchemaParser +from dbt.parser.unit_tests import UnitTestParser + +from .utils import MockNode +from .test_parser import SchemaParserTest, assertEqualNodes + +from unittest import mock + + +UNIT_TEST_MODEL_NOT_FOUND_SOURCE = """ +unit: + - model: my_model_doesnt_exist + tests: + - name: test_my_model_doesnt_exist + description: "unit test description" + given: [] + expect: [] +""" + + +UNIT_TEST_SOURCE = """ +unit: + - model: my_model + tests: + - name: test_my_model + description: "unit test description" + given: [] + expect: [] +""" + + +UNIT_TEST_VERSIONED_MODEL_SOURCE = """ +unit: + - model: my_model_versioned.v1 + tests: + - name: test_my_model_versioned + description: "unit test description" + given: [] + expect: [] +""" + + +UNIT_TEST_CONFIG_SOURCE = """ +unit: + - model: my_model + tests: + - name: test_my_model + config: + tags: "schema_tag" + meta: + meta_key: meta_value + meta_jinja_key: '{{ 1 + 1 }}' + description: "unit test description" + given: [] + expect: [] +""" + + +UNIT_TEST_MULTIPLE_SOURCE = """ +unit: + - model: my_model + tests: + - name: test_my_model + description: "unit test description" + given: [] + expect: [] + - name: test_my_model2 + description: "unit test description" + given: [] + expect: [] +""" + + +class UnitTestParserTest(SchemaParserTest): + def setUp(self): + super().setUp() + my_model_node = MockNode( + package="snowplow", + name="my_model", + config=mock.MagicMock(enabled=True), + refs=[], + sources=[], + patch_path=None, + ) + self.manifest.nodes = {my_model_node.unique_id: my_model_node} + self.parser = SchemaParser( + project=self.snowplow_project_config, + manifest=self.manifest, + root_project=self.root_project_config, + ) + + def file_block_for(self, data, filename): + return super().file_block_for(data, filename, "unit") + + def test_basic_model_not_found(self): + block = self.yaml_block_for(UNIT_TEST_MODEL_NOT_FOUND_SOURCE, "test_my_model.yml") + + with self.assertRaises(ParsingError): + UnitTestParser(self.parser, block).parse() + + def test_basic(self): + block = self.yaml_block_for(UNIT_TEST_SOURCE, "test_my_model.yml") + + UnitTestParser(self.parser, block).parse() + + self.assert_has_manifest_lengths(self.parser.manifest, nodes=1, unit_tests=1) + unit_test = list(self.parser.manifest.unit_tests.values())[0] + expected = UnitTestDefinition( + name="test_my_model", + model="my_model", + resource_type=NodeType.Unit, + package_name="snowplow", + path=block.path.relative_path, + original_file_path=block.path.original_file_path, + unique_id="unit.snowplow.my_model.test_my_model", + given=[], + expect=[], + description="unit test description", + overrides=None, + depends_on=DependsOn(nodes=["model.snowplow.my_model"]), + fqn=["snowplow", "my_model", "test_my_model"], + config=UnitTestConfig(), + ) + assertEqualNodes(unit_test, expected) + + def test_unit_test_config(self): + block = self.yaml_block_for(UNIT_TEST_CONFIG_SOURCE, "test_my_model.yml") + self.root_project_config.unit_tests = { + "snowplow": {"my_model": {"+tags": ["project_tag"]}} + } + + UnitTestParser(self.parser, block).parse() + + self.assert_has_manifest_lengths(self.parser.manifest, nodes=1, unit_tests=1) + unit_test = self.parser.manifest.unit_tests["unit.snowplow.my_model.test_my_model"] + self.assertEqual(sorted(unit_test.config.tags), sorted(["schema_tag", "project_tag"])) + self.assertEqual(unit_test.config.meta, {"meta_key": "meta_value", "meta_jinja_key": "2"}) + + def test_unit_test_versioned_model(self): + block = self.yaml_block_for(UNIT_TEST_VERSIONED_MODEL_SOURCE, "test_my_model.yml") + my_model_versioned_node = MockNode( + package="snowplow", + name="my_model_versioned", + config=mock.MagicMock(enabled=True), + refs=[], + sources=[], + patch_path=None, + version=1, + ) + self.manifest.nodes[my_model_versioned_node.unique_id] = my_model_versioned_node + + UnitTestParser(self.parser, block).parse() + + self.assert_has_manifest_lengths(self.parser.manifest, nodes=2, unit_tests=1) + unit_test = self.parser.manifest.unit_tests[ + "unit.snowplow.my_model_versioned.v1.test_my_model_versioned" + ] + self.assertEqual(len(unit_test.depends_on.nodes), 1) + self.assertEqual(unit_test.depends_on.nodes[0], "model.snowplow.my_model_versioned.v1") + + def test_multiple_unit_tests(self): + block = self.yaml_block_for(UNIT_TEST_MULTIPLE_SOURCE, "test_my_model.yml") + + UnitTestParser(self.parser, block).parse() + + self.assert_has_manifest_lengths(self.parser.manifest, nodes=1, unit_tests=2) + for unit_test in self.parser.manifest.unit_tests.values(): + self.assertEqual(len(unit_test.depends_on.nodes), 1) + self.assertEqual(unit_test.depends_on.nodes[0], "model.snowplow.my_model") diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 0f5c12ebbfd..827991d49c5 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -336,7 +336,7 @@ def MockNode(package, name, resource_type=None, **kwargs): version = kwargs.get("version") search_name = name if version is None else f"{name}.v{version}" - unique_id = f"{str(resource_type)}.{package}.{name}" + unique_id = f"{str(resource_type)}.{package}.{search_name}" node = mock.MagicMock( __class__=cls, resource_type=resource_type,