diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index 0e60c90274b..72af226d5d7 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -108,7 +108,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): # unit_test_node now has a populated refs/sources self.unit_test_manifest.nodes[unit_test_node.unique_id] = unit_test_node - # Now create input_nodes for the test inputs """ given: @@ -130,7 +129,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): given.input, tested_node, test_case.name ) input_name = original_input_node.name - common_fields = { "resource_type": NodeType.Model, # root directory for input and output fixtures @@ -147,21 +145,25 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): "name": input_name, "path": f"{input_name}.sql", } + resource_type = original_input_node.resource_type - if original_input_node.resource_type in ( + if resource_type in ( NodeType.Model, NodeType.Seed, NodeType.Snapshot, ): + input_node = ModelNode( **common_fields, defer_relation=original_input_node.defer_relation, ) - if ( - original_input_node.resource_type == NodeType.Model - and original_input_node.version - ): + + if not resource_type == NodeType.Model: + continue + if original_input_node.version: input_node.version = original_input_node.version + if original_input_node.latest_version: + input_node.latest_version = original_input_node.latest_version elif original_input_node.resource_type == NodeType.Source: # We are reusing the database/schema/identifier from the original source, diff --git a/tests/functional/unit_testing/test_unit_testing.py b/tests/functional/unit_testing/test_unit_testing.py index 160f528787d..514e712a725 100644 --- a/tests/functional/unit_testing/test_unit_testing.py +++ b/tests/functional/unit_testing/test_unit_testing.py @@ -291,22 +291,139 @@ def test_basic(self, project): assert len(results) == 2 -class TestUnitTestIncrementalModelWithVersion: +schema_ref_with_version = """ +models: + - name: source + latest_version: {latest_version} + versions: + - v: 1 + - v: 2 + - name: model_to_test +unit_tests: + - name: ref_versioned + model: 'model_to_test' + given: + - input: {input} + rows: + - {{result: 3}} + expect: + rows: + - {{result: 3}} + +""" + + +class TestUnitTestRefWithVersion: @pytest.fixture(scope="class") def models(self): return { - "my_incremental_model.sql": my_incremental_model_sql, - "events.sql": event_sql, - "schema.yml": my_incremental_model_versioned_yml + test_my_model_incremental_yml_basic, + "model_to_test.sql": "select result from {{ ref('source')}}", + "source.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_version.format( + **{"latest_version": 1, "input": "ref('source')"} + ), } def test_basic(self, project): results = run_dbt(["run"]) - assert len(results) == 2 - # Select by model name - results = run_dbt(["test", "--select", "my_incremental_model"], expect_pass=True) - assert len(results) == 2 + results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True) + assert len(results) == 1 + + +class TestUnitTestRefMissingVersionModel: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source')}}", + "source_v1.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_version.format( + **{"latest_version": 1, "input": "ref('source', v=1)"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + + results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True) + assert len(results) == 1 + + +class TestUnitTestRefWithMissingVersionRef: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source', v=1)}}", + "source_v1.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_version.format( + **{"latest_version": 1, "input": "ref('source')"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + + results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True) + assert len(results) == 1 + + +class TestUnitTestRefWithVersionLatestSecond: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source')}}", + "source_v1.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_version.format( + **{"latest_version": 2, "input": "ref('source')"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + + results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True) + assert len(results) == 1 + + +class TestUnitTestRefWithVersionMissingRefTest: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source', v=2)}}", + "source_v1.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_version.format( + **{"latest_version": 1, "input": "ref('source')"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + assert len(results) == 3 + # TODO: How to capture an compilation Error? pytest.raises(CompilationError) not working + run_dbt(["test", "--select", "model_to_test"], expect_pass=False) + + +class TestUnitTestRefWithVersionDiffLatest: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source', v=2)}}", + "source_v1.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_version.format( + **{"latest_version": 1, "input": "ref('source', v=2)"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + assert len(results) == 3 + run_dbt(["test", "--select", "model_to_test"], expect_pass=True) class TestUnitTestExplicitSeed: