diff --git a/.changes/unreleased/Features-20240719-161841.yaml b/.changes/unreleased/Features-20240719-161841.yaml new file mode 100644 index 00000000000..a84a9d45e9d --- /dev/null +++ b/.changes/unreleased/Features-20240719-161841.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support ref and source in foreign key constraint expressions +time: 2024-07-19T16:18:41.434278-04:00 +custom: + Author: michelleark + Issue: "8062" diff --git a/core/dbt/clients/jinja_static.py b/core/dbt/clients/jinja_static.py index 8e0c34df2e6..d8746a7607d 100644 --- a/core/dbt/clients/jinja_static.py +++ b/core/dbt/clients/jinja_static.py @@ -1,11 +1,13 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Union import jinja2 -from dbt.exceptions import MacroNamespaceNotStringError +from dbt.artifacts.resources import RefArgs +from dbt.exceptions import MacroNamespaceNotStringError, ParsingError from dbt_common.clients.jinja import get_environment from dbt_common.exceptions.macros import MacroNameNotStringError from dbt_common.tests import test_caching_enabled +from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore _TESTING_MACRO_CACHE: Optional[Dict[str, Any]] = {} @@ -153,3 +155,39 @@ def statically_parse_adapter_dispatch(func_call, ctx, db_wrapper): possible_macro_calls.append(f"{package_name}.{func_name}") return possible_macro_calls + + +def statically_parse_ref_or_source(expression: str) -> Union[RefArgs, List[str]]: + """ + Returns a RefArgs or List[str] object, corresponding to ref or source respectively, given an input jinja expression. + + input: str representing how input node is referenced in tested model sql + * examples: + - "ref('my_model_a')" + - "ref('my_model_a', version=3)" + - "ref('package', 'my_model_a', version=3)" + - "source('my_source_schema', 'my_source_name')" + + If input is not a well-formed jinja ref or source expression, a ParsingError is raised. + """ + ref_or_source: Union[RefArgs, List[str]] + + try: + statically_parsed = py_extract_from_source(f"{{{{ {expression} }}}}") + except ExtractionError: + raise ParsingError(f"Invalid jinja expression: {expression}") + + if statically_parsed.get("refs"): + raw_ref = list(statically_parsed["refs"])[0] + ref_or_source = RefArgs( + package=raw_ref.get("package"), + name=raw_ref.get("name"), + version=raw_ref.get("version"), + ) + elif statically_parsed.get("sources"): + source_name, source_table_name = list(statically_parsed["sources"])[0] + ref_or_source = [source_name, source_table_name] + else: + raise ParsingError(f"Invalid ref or source expression: {expression}") + + return ref_or_source diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index d03407b2a4c..47d7ffbdb51 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -29,12 +29,15 @@ from dbt.exceptions import ( DbtInternalError, DbtRuntimeError, + ForeignKeyConstraintToSyntaxError, GraphDependencyNotFoundError, + ParsingError, ) from dbt.flags import get_flags from dbt.graph import Graph from dbt.node_types import ModelLanguage, NodeType from dbt_common.clients.system import make_directory +from dbt_common.contracts.constraints import ConstraintType from dbt_common.events.contextvars import get_node_info from dbt_common.events.format import pluralize from dbt_common.events.functions import fire_event @@ -437,8 +440,31 @@ def _compile_code( relation_name = str(relation_cls.create_from(self.config, node)) node.relation_name = relation_name + # Compile 'ref' and 'source' expressions in foreign key constraints + if node.resource_type == NodeType.Model: + for constraint in node.all_constraints: + if constraint.type == ConstraintType.foreign_key and constraint.to: + constraint.to = self._compile_relation_for_foreign_key_constraint_to( + manifest, node, constraint.to + ) + return node + def _compile_relation_for_foreign_key_constraint_to( + self, manifest: Manifest, node: ManifestSQLNode, to_expression: str + ) -> str: + try: + foreign_key_node = manifest.find_node_from_ref_or_source(to_expression) + except ParsingError: + raise ForeignKeyConstraintToSyntaxError(node, to_expression) + + if not foreign_key_node: + raise GraphDependencyNotFoundError(node, to_expression) + + adapter = get_adapter(self.config) + relation_name = str(adapter.Relation.create_from(self.config, foreign_key_node)) + return relation_name + # This method doesn't actually "compile" any of the nodes. That is done by the # "compile_node" method. This creates a Linker and builds the networkx graph, # writes out the graph.gpickle file, and prints the stats, returning a Graph object. diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 358ef84db63..21c5571b74b 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -32,9 +32,10 @@ from dbt.adapters.factory import get_adapter_package_names # to preserve import paths -from dbt.artifacts.resources import BaseResource, DeferRelation, NodeVersion +from dbt.artifacts.resources import BaseResource, DeferRelation, NodeVersion, RefArgs from dbt.artifacts.resources.v1.config import NodeConfig from dbt.artifacts.schemas.manifest import ManifestMetadata, UniqueID, WritableManifest +from dbt.clients.jinja_static import statically_parse_ref_or_source from dbt.contracts.files import ( AnySourceFile, FileHash, @@ -1635,6 +1636,22 @@ def add_saved_query(self, source_file: SchemaSourceFile, saved_query: SavedQuery # end of methods formerly in ParseResult + def find_node_from_ref_or_source( + self, expression: str + ) -> Optional[Union[ModelNode, SourceDefinition]]: + ref_or_source = statically_parse_ref_or_source(expression) + + node = None + if isinstance(ref_or_source, RefArgs): + node = self.ref_lookup.find( + ref_or_source.name, ref_or_source.package, ref_or_source.version, self + ) + else: + source_name, source_table_name = ref_or_source[0], ref_or_source[1] + node = self.source_lookup.find(f"{source_name}.{source_table_name}", None, self) + + return node + # Provide support for copy.deepcopy() - we just need to avoid the lock! # pickle and deepcopy use this. It returns a callable object used to # create the initial version of the object and a tuple of arguments diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index da42fb7d766..42d19e2c8dd 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -85,7 +85,11 @@ NodeType, ) from dbt_common.clients.system import write_file -from dbt_common.contracts.constraints import ConstraintType +from dbt_common.contracts.constraints import ( + ColumnLevelConstraint, + ConstraintType, + ModelLevelConstraint, +) from dbt_common.events.contextvars import set_log_contextvars from dbt_common.events.functions import warn_or_error @@ -489,6 +493,18 @@ def search_name(self): def materialization_enforces_constraints(self) -> bool: return self.config.materialized in ["table", "incremental"] + @property + def all_constraints(self) -> List[Union[ModelLevelConstraint, ColumnLevelConstraint]]: + constraints: List[Union[ModelLevelConstraint, ColumnLevelConstraint]] = [] + for model_level_constraint in self.constraints: + constraints.append(model_level_constraint) + + for column in self.columns.values(): + for column_level_constraint in column.constraints: + constraints.append(column_level_constraint) + + return constraints + def infer_primary_key(self, data_tests: List["GenericTestNode"]) -> List[str]: """ Infers the columns that can be used as primary key of a model in the following order: diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index aec2b5e3826..27aa863fd17 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -136,6 +136,18 @@ def get_message(self) -> str: return msg +class ForeignKeyConstraintToSyntaxError(CompilationError): + def __init__(self, node, expression: str) -> None: + self.expression = expression + self.node = node + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + msg = f"'{self.node.unique_id}' defines a foreign key constraint 'to' expression which is not valid 'ref' or 'source' syntax: {self.expression}." + + return msg + + # client level exceptions diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index af912b455cb..5e269fd385c 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -5,6 +5,8 @@ from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Type, TypeVar from dbt import deprecations +from dbt.artifacts.resources import RefArgs +from dbt.clients.jinja_static import statically_parse_ref_or_source from dbt.clients.yaml_helper import load_yaml_text from dbt.config import RuntimeConfig from dbt.context.configured import SchemaYamlVars, generate_schema_yml_context @@ -915,7 +917,7 @@ def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None: self.patch_constraints(node, patch.constraints) node.build_contract_checksum() - def patch_constraints(self, node, constraints) -> None: + def patch_constraints(self, node, constraints: List[Dict[str, Any]]) -> None: contract_config = node.config.get("contract") if contract_config.enforced is True: self._validate_constraint_prerequisites(node) @@ -930,6 +932,26 @@ def patch_constraints(self, node, constraints) -> None: self._validate_pk_constraints(node, constraints) node.constraints = [ModelLevelConstraint.from_dict(c) for c in constraints] + self._process_constraints_refs_and_sources(node) + + def _process_constraints_refs_and_sources(self, model_node: ModelNode) -> None: + """ + Populate model_node.refs and model_node.sources based on foreign-key constraint references, + whether defined at the model-level or column-level. + """ + for constraint in model_node.all_constraints: + if constraint.type == ConstraintType.foreign_key and constraint.to: + try: + ref_or_source = statically_parse_ref_or_source(constraint.to) + except ParsingError: + raise ParsingError( + f"Invalid 'ref' or 'source' syntax on foreign key constraint 'to' on model {model_node.name}: {constraint.to}." + ) + + if isinstance(ref_or_source, RefArgs): + model_node.refs.append(ref_or_source) + else: + model_node.sources.append(ref_or_source) def _validate_pk_constraints( self, model_node: ModelNode, constraints: List[Dict[str, Any]] diff --git a/tests/functional/constraints/fixtures.py b/tests/functional/constraints/fixtures.py new file mode 100644 index 00000000000..de60963bfec --- /dev/null +++ b/tests/functional/constraints/fixtures.py @@ -0,0 +1,115 @@ +model_foreign_key_model_schema_yml = """ +models: + - name: my_model + constraints: + - type: foreign_key + columns: [id] + to: ref('my_model_to') + to_columns: [id] + columns: + - name: id + data_type: integer +""" + + +model_foreign_key_source_schema_yml = """ +sources: + - name: test_source + tables: + - name: test_table + +models: + - name: my_model + constraints: + - type: foreign_key + columns: [id] + to: source('test_source', 'test_table') + to_columns: [id] + columns: + - name: id + data_type: integer +""" + + +model_foreign_key_model_node_not_found_schema_yml = """ +models: + - name: my_model + constraints: + - type: foreign_key + columns: [id] + to: ref('doesnt_exist') + to_columns: [id] + columns: + - name: id + data_type: integer +""" + + +model_foreign_key_model_invalid_syntax_schema_yml = """ +models: + - name: my_model + constraints: + - type: foreign_key + columns: [id] + to: invalid + to_columns: [id] + columns: + - name: id + data_type: integer +""" + + +model_foreign_key_model_column_schema_yml = """ +models: + - name: my_model + columns: + - name: id + data_type: integer + constraints: + - type: foreign_key + to: ref('my_model_to') + to_columns: [id] +""" + + +model_foreign_key_column_invalid_syntax_schema_yml = """ +models: + - name: my_model + columns: + - name: id + data_type: integer + constraints: + - type: foreign_key + to: invalid + to_columns: [id] +""" + + +model_foreign_key_column_node_not_found_schema_yml = """ +models: + - name: my_model + columns: + - name: id + data_type: integer + constraints: + - type: foreign_key + to: ref('doesnt_exist') + to_columns: [id] +""" + +model_column_level_foreign_key_source_schema_yml = """ +sources: + - name: test_source + tables: + - name: test_table + +models: + - name: my_model + columns: + - name: id + data_type: integer + constraints: + - type: foreign_key + to: source('test_source', 'test_table') + to_columns: [id] +""" diff --git a/tests/functional/constraints/test_foreign_key_constraints.py b/tests/functional/constraints/test_foreign_key_constraints.py new file mode 100644 index 00000000000..2c02cfe7ad7 --- /dev/null +++ b/tests/functional/constraints/test_foreign_key_constraints.py @@ -0,0 +1,241 @@ +import pytest + +from dbt.artifacts.resources import RefArgs +from dbt.exceptions import CompilationError, ParsingError +from dbt.tests.util import get_artifact, run_dbt +from dbt_common.contracts.constraints import ( + ColumnLevelConstraint, + ConstraintType, + ModelLevelConstraint, +) +from tests.functional.constraints.fixtures import ( + model_column_level_foreign_key_source_schema_yml, + model_foreign_key_column_invalid_syntax_schema_yml, + model_foreign_key_column_node_not_found_schema_yml, + model_foreign_key_model_column_schema_yml, + model_foreign_key_model_invalid_syntax_schema_yml, + model_foreign_key_model_node_not_found_schema_yml, + model_foreign_key_model_schema_yml, + model_foreign_key_source_schema_yml, +) + + +class TestModelLevelForeignKeyConstraintToRef: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_model_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to(self, project, unique_schema): + manifest = run_dbt(["parse"]) + node_with_fk_constraint = manifest.nodes["model.test.my_model"] + assert len(node_with_fk_constraint.constraints) == 1 + + parsed_constraint = node_with_fk_constraint.constraints[0] + assert parsed_constraint == ModelLevelConstraint( + type=ConstraintType.foreign_key, + columns=["id"], + to="ref('my_model_to')", + to_columns=["id"], + ) + # Assert column-level constraint source included in node.depends_on + assert node_with_fk_constraint.refs == [RefArgs("my_model_to")] + assert node_with_fk_constraint.depends_on.nodes == ["model.test.my_model_to"] + assert node_with_fk_constraint.sources == [] + + # Assert compilation renders to from 'ref' to relation identifer + run_dbt(["compile"]) + manifest = get_artifact(project.project_root, "target", "manifest.json") + assert len(manifest["nodes"]["model.test.my_model"]["constraints"]) == 1 + + compiled_constraint = manifest["nodes"]["model.test.my_model"]["constraints"][0] + assert compiled_constraint["to"] == f'"dbt"."{unique_schema}"."my_model_to"' + # Other constraint fields should remain as parsed + assert compiled_constraint["to_columns"] == parsed_constraint.to_columns + assert compiled_constraint["columns"] == parsed_constraint.columns + assert compiled_constraint["type"] == parsed_constraint.type + + +class TestModelLevelForeignKeyConstraintToSource: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_source_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to(self, project, unique_schema): + manifest = run_dbt(["parse"]) + node_with_fk_constraint = manifest.nodes["model.test.my_model"] + assert len(node_with_fk_constraint.constraints) == 1 + + parsed_constraint = node_with_fk_constraint.constraints[0] + assert parsed_constraint == ModelLevelConstraint( + type=ConstraintType.foreign_key, + columns=["id"], + to="source('test_source', 'test_table')", + to_columns=["id"], + ) + # Assert column-level constraint source included in node.depends_on + assert node_with_fk_constraint.refs == [] + assert node_with_fk_constraint.depends_on.nodes == ["source.test.test_source.test_table"] + assert node_with_fk_constraint.sources == [["test_source", "test_table"]] + + # Assert compilation renders to from 'ref' to relation identifer + run_dbt(["compile"]) + manifest = get_artifact(project.project_root, "target", "manifest.json") + assert len(manifest["nodes"]["model.test.my_model"]["constraints"]) == 1 + + compiled_constraint = manifest["nodes"]["model.test.my_model"]["constraints"][0] + assert compiled_constraint["to"] == '"dbt"."test_source"."test_table"' + # Other constraint fields should remain as parsed + assert compiled_constraint["to_columns"] == parsed_constraint.to_columns + assert compiled_constraint["columns"] == parsed_constraint.columns + assert compiled_constraint["type"] == parsed_constraint.type + + +class TestModelLevelForeignKeyConstraintRefNotFoundError: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_model_node_not_found_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to_doesnt_exist(self, project): + with pytest.raises( + CompilationError, match="depends on a node named 'doesnt_exist' which was not found" + ): + run_dbt(["parse"]) + + +class TestModelLevelForeignKeyConstraintRefSyntaxError: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_model_invalid_syntax_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to(self, project): + with pytest.raises( + ParsingError, + match="Invalid 'ref' or 'source' syntax on foreign key constraint 'to' on model my_model: invalid", + ): + run_dbt(["parse"]) + + +class TestColumnLevelForeignKeyConstraintToRef: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_model_column_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_column_level_fk_to(self, project, unique_schema): + manifest = run_dbt(["parse"]) + node_with_fk_constraint = manifest.nodes["model.test.my_model"] + assert len(node_with_fk_constraint.columns["id"].constraints) == 1 + + parsed_constraint = node_with_fk_constraint.columns["id"].constraints[0] + # Assert column-level constraint parsed + assert parsed_constraint == ColumnLevelConstraint( + type=ConstraintType.foreign_key, to="ref('my_model_to')", to_columns=["id"] + ) + # Assert column-level constraint ref included in node.depends_on + assert node_with_fk_constraint.refs == [RefArgs(name="my_model_to")] + assert node_with_fk_constraint.sources == [] + assert node_with_fk_constraint.depends_on.nodes == ["model.test.my_model_to"] + + # Assert compilation renders to from 'ref' to relation identifer + run_dbt(["compile"]) + manifest = get_artifact(project.project_root, "target", "manifest.json") + assert len(manifest["nodes"]["model.test.my_model"]["columns"]["id"]["constraints"]) == 1 + + compiled_constraint = manifest["nodes"]["model.test.my_model"]["columns"]["id"][ + "constraints" + ][0] + assert compiled_constraint["to"] == f'"dbt"."{unique_schema}"."my_model_to"' + # Other constraint fields should remain as parsed + assert compiled_constraint["to_columns"] == parsed_constraint.to_columns + assert compiled_constraint["type"] == parsed_constraint.type + + +class TestColumnLevelForeignKeyConstraintToSource: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_column_level_foreign_key_source_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to(self, project, unique_schema): + manifest = run_dbt(["parse"]) + node_with_fk_constraint = manifest.nodes["model.test.my_model"] + assert len(node_with_fk_constraint.columns["id"].constraints) == 1 + + parsed_constraint = node_with_fk_constraint.columns["id"].constraints[0] + assert parsed_constraint == ColumnLevelConstraint( + type=ConstraintType.foreign_key, + to="source('test_source', 'test_table')", + to_columns=["id"], + ) + # Assert column-level constraint source included in node.depends_on + assert node_with_fk_constraint.refs == [] + assert node_with_fk_constraint.depends_on.nodes == ["source.test.test_source.test_table"] + assert node_with_fk_constraint.sources == [["test_source", "test_table"]] + + # Assert compilation renders to from 'ref' to relation identifer + run_dbt(["compile"]) + manifest = get_artifact(project.project_root, "target", "manifest.json") + assert len(manifest["nodes"]["model.test.my_model"]["columns"]["id"]["constraints"]) == 1 + + compiled_constraint = manifest["nodes"]["model.test.my_model"]["columns"]["id"][ + "constraints" + ][0] + assert compiled_constraint["to"] == '"dbt"."test_source"."test_table"' + # # Other constraint fields should remain as parsed + assert compiled_constraint["to_columns"] == parsed_constraint.to_columns + assert compiled_constraint["type"] == parsed_constraint.type + + +class TestColumnLevelForeignKeyConstraintRefNotFoundError: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_column_node_not_found_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to_doesnt_exist(self, project): + with pytest.raises( + CompilationError, match="depends on a node named 'doesnt_exist' which was not found" + ): + run_dbt(["parse"]) + + +class TestColumnLevelForeignKeyConstraintRefSyntaxError: + @pytest.fixture(scope="class") + def models(self): + return { + "constraints_schema.yml": model_foreign_key_column_invalid_syntax_schema_yml, + "my_model.sql": "select 1 as id", + "my_model_to.sql": "select 1 as id", + } + + def test_model_level_fk_to(self, project): + with pytest.raises( + ParsingError, + match="Invalid 'ref' or 'source' syntax on foreign key constraint 'to' on model my_model: invalid.", + ): + run_dbt(["parse"]) diff --git a/tests/unit/clients/test_jinja_static.py b/tests/unit/clients/test_jinja_static.py index d575cfb76e8..171976a6b50 100644 --- a/tests/unit/clients/test_jinja_static.py +++ b/tests/unit/clients/test_jinja_static.py @@ -1,44 +1,79 @@ -import unittest +import pytest -from dbt.clients.jinja_static import statically_extract_macro_calls +from dbt.artifacts.resources import RefArgs +from dbt.clients.jinja_static import ( + statically_extract_macro_calls, + statically_parse_ref_or_source, +) from dbt.context.base import generate_base_context +from dbt.exceptions import ParsingError -class MacroCalls(unittest.TestCase): - def setUp(self): - self.macro_strings = [ +@pytest.mark.parametrize( + "macro_string,expected_possible_macro_calls", + [ + ( "{% macro parent_macro() %} {% do return(nested_macro()) %} {% endmacro %}", - "{% macro lr_macro() %} {{ return(load_result('relations').table) }} {% endmacro %}", - "{% macro get_snapshot_unique_id() -%} {{ return(adapter.dispatch('get_snapshot_unique_id')()) }} {%- endmacro %}", - "{% macro get_columns_in_query(select_sql) -%} {{ return(adapter.dispatch('get_columns_in_query')(select_sql)) }} {% endmacro %}", - """{% macro test_mutually_exclusive_ranges(model) %} - with base as ( - select {{ get_snapshot_unique_id() }} as dbt_unique_id, - * - from {{ model }} ) - {% endmacro %}""", - "{% macro test_my_test(model) %} select {{ current_timestamp_backcompat() }} {% endmacro %}", - "{% macro some_test(model) -%} {{ return(adapter.dispatch('test_some_kind4', 'foo_utils4')) }} {%- endmacro %}", - "{% macro some_test(model) -%} {{ return(adapter.dispatch('test_some_kind5', macro_namespace = 'foo_utils5')) }} {%- endmacro %}", - ] - - self.possible_macro_calls = [ ["nested_macro"], + ), + ( + "{% macro lr_macro() %} {{ return(load_result('relations').table) }} {% endmacro %}", ["load_result"], + ), + ( + "{% macro get_snapshot_unique_id() -%} {{ return(adapter.dispatch('get_snapshot_unique_id')()) }} {%- endmacro %}", ["get_snapshot_unique_id"], + ), + ( + "{% macro get_columns_in_query(select_sql) -%} {{ return(adapter.dispatch('get_columns_in_query')(select_sql)) }} {% endmacro %}", ["get_columns_in_query"], + ), + ( + """{% macro test_mutually_exclusive_ranges(model) %} + with base as ( + select {{ get_snapshot_unique_id() }} as dbt_unique_id, + * + from {{ model }} ) + {% endmacro %}""", ["get_snapshot_unique_id"], + ), + ( + "{% macro test_my_test(model) %} select {{ current_timestamp_backcompat() }} {% endmacro %}", ["current_timestamp_backcompat"], + ), + ( + "{% macro some_test(model) -%} {{ return(adapter.dispatch('test_some_kind4', 'foo_utils4')) }} {%- endmacro %}", ["test_some_kind4", "foo_utils4.test_some_kind4"], + ), + ( + "{% macro some_test(model) -%} {{ return(adapter.dispatch('test_some_kind5', macro_namespace = 'foo_utils5')) }} {%- endmacro %}", ["test_some_kind5", "foo_utils5.test_some_kind5"], - ] + ), + ], +) +def test_extract_macro_calls(macro_string, expected_possible_macro_calls): + cli_vars = {"local_utils_dispatch_list": ["foo_utils4"]} + ctx = generate_base_context(cli_vars) + + possible_macro_calls = statically_extract_macro_calls(macro_string, ctx) + assert possible_macro_calls == expected_possible_macro_calls + - def test_macro_calls(self): - cli_vars = {"local_utils_dispatch_list": ["foo_utils4"]} - ctx = generate_base_context(cli_vars) +class TestStaticallyParseRefOrSource: + def test_invalid_expression(self): + with pytest.raises(ParsingError): + statically_parse_ref_or_source("invalid") - index = 0 - for macro_string in self.macro_strings: - possible_macro_calls = statically_extract_macro_calls(macro_string, ctx) - self.assertEqual(self.possible_macro_calls[index], possible_macro_calls) - index += 1 + @pytest.mark.parametrize( + "expression,expected_ref_or_source", + [ + ("ref('model')", RefArgs(name="model")), + ("ref('package','model')", RefArgs(name="model", package="package")), + ("ref('model',v=3)", RefArgs(name="model", version=3)), + ("ref('package','model',v=3)", RefArgs(name="model", package="package", version=3)), + ("source('schema', 'table')", ["schema", "table"]), + ], + ) + def test_valid_ref_expression(self, expression, expected_ref_or_source): + ref_or_source = statically_parse_ref_or_source(expression) + assert ref_or_source == expected_ref_or_source diff --git a/tests/unit/contracts/graph/test_manifest.py b/tests/unit/contracts/graph/test_manifest.py index 35e96308da7..dc81fa4b7dc 100644 --- a/tests/unit/contracts/graph/test_manifest.py +++ b/tests/unit/contracts/graph/test_manifest.py @@ -37,7 +37,7 @@ SeedNode, SourceDefinition, ) -from dbt.exceptions import AmbiguousResourceNameRefError +from dbt.exceptions import AmbiguousResourceNameRefError, ParsingError from dbt.flags import set_from_args from dbt.node_types import NodeType from dbt_common.events.functions import reset_metadata_vars @@ -1962,3 +1962,53 @@ def test_resolve_doc(docs, package, expected): expected_package, expected_name = expected assert result.name == expected_name assert result.package_name == expected_package + + +class TestManifestFindNodeFromRefOrSource: + @pytest.fixture + def mock_node(self): + return MockNode("my_package", "my_model") + + @pytest.fixture + def mock_disabled_node(self): + return MockNode("my_package", "disabled_node", config={"enabled": False}) + + @pytest.fixture + def mock_source(self): + return MockSource("root", "my_source", "source_table") + + @pytest.fixture + def mock_disabled_source(self): + return MockSource("root", "my_source", "disabled_source_table", config={"enabled": False}) + + @pytest.fixture + def mock_manifest(self, mock_node, mock_source, mock_disabled_node, mock_disabled_source): + return make_manifest( + nodes=[mock_node, mock_disabled_node], sources=[mock_source, mock_disabled_source] + ) + + @pytest.mark.parametrize( + "expression,expected_node", + [ + ("ref('my_package', 'my_model')", "mock_node"), + ("ref('my_package', 'doesnt_exist')", None), + ("ref('my_package', 'disabled_node')", "mock_disabled_node"), + ("source('my_source', 'source_table')", "mock_source"), + ("source('my_source', 'doesnt_exist')", None), + ("source('my_source', 'disabled_source_table')", "mock_disabled_source"), + ], + ) + def test_find_node_from_ref_or_source(self, expression, expected_node, mock_manifest, request): + node = mock_manifest.find_node_from_ref_or_source(expression) + + if expected_node is None: + assert node is None + else: + assert node == request.getfixturevalue(expected_node) + + @pytest.mark.parametrize("invalid_expression", ["invalid", "ref(')"]) + def test_find_node_from_ref_or_source_invalid_expression( + self, invalid_expression, mock_manifest + ): + with pytest.raises(ParsingError): + mock_manifest.find_node_from_ref_or_source(invalid_expression) diff --git a/tests/unit/graph/test_nodes.py b/tests/unit/graph/test_nodes.py index ff14874eb65..79522d06427 100644 --- a/tests/unit/graph/test_nodes.py +++ b/tests/unit/graph/test_nodes.py @@ -68,6 +68,48 @@ def test_is_past_deprecation_date( assert default_model_node.is_past_deprecation_date is expected_is_past_deprecation_date + @pytest.mark.parametrize( + "model_constraints,columns,expected_all_constraints", + [ + ([], {}, []), + ( + [ModelLevelConstraint(type=ConstraintType.foreign_key)], + {}, + [ModelLevelConstraint(type=ConstraintType.foreign_key)], + ), + ( + [], + { + "id": ColumnInfo( + name="id", + constraints=[ColumnLevelConstraint(type=ConstraintType.foreign_key)], + ) + }, + [ColumnLevelConstraint(type=ConstraintType.foreign_key)], + ), + ( + [ModelLevelConstraint(type=ConstraintType.foreign_key)], + { + "id": ColumnInfo( + name="id", + constraints=[ColumnLevelConstraint(type=ConstraintType.foreign_key)], + ) + }, + [ + ModelLevelConstraint(type=ConstraintType.foreign_key), + ColumnLevelConstraint(type=ConstraintType.foreign_key), + ], + ), + ], + ) + def test_all_constraints( + self, default_model_node, model_constraints, columns, expected_all_constraints + ): + default_model_node.constraints = model_constraints + default_model_node.columns = columns + + assert default_model_node.all_constraints == expected_all_constraints + class TestSemanticModel: @pytest.fixture(scope="function")