Skip to content

Commit

Permalink
more defensive node.all_constraints access (dbt-labs#10508)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored Aug 1, 2024
1 parent 014444d commit ff2726c
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 25 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240731-095152.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: fix all_constraints access, disabled node parsing of non-uniquely named resources
time: 2024-07-31T09:51:52.751135-04:00
custom:
Author: michelleark gshank
Issue: "10509"
3 changes: 2 additions & 1 deletion core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
InjectedCTE,
ManifestNode,
ManifestSQLNode,
ModelNode,
SeedNode,
UnitTestDefinition,
UnitTestNode,
Expand Down Expand Up @@ -441,7 +442,7 @@ def _compile_code(
node.relation_name = relation_name

# Compile 'ref' and 'source' expressions in foreign key constraints
if node.resource_type == NodeType.Model:
if isinstance(node, ModelNode):
for constraint in node.all_constraints:
if constraint.type == ConstraintType.foreign_key and constraint.to:
constraint.to = self._compile_relation_for_foreign_key_constraint_to(
Expand Down
36 changes: 29 additions & 7 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,11 +413,11 @@ def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, List[Any]]] = {}
self.populate(manifest)

def populate(self, manifest):
def populate(self, manifest: "Manifest"):
for node in list(chain.from_iterable(manifest.disabled.values())):
self.add_node(node)

def add_node(self, node):
def add_node(self, node: GraphMemberNode) -> None:
if node.search_name not in self.storage:
self.storage[node.search_name] = {}
if node.package_name not in self.storage[node.search_name]:
Expand All @@ -427,8 +427,12 @@ def add_node(self, node):
# This should return a list of disabled nodes. It's different from
# the other Lookup functions in that it returns full nodes, not just unique_ids
def find(
self, search_name, package: Optional[PackageName], version: Optional[NodeVersion] = None
):
self,
search_name,
package: Optional[PackageName],
version: Optional[NodeVersion] = None,
resource_types: Optional[List[NodeType]] = None,
) -> Optional[List[Any]]:
if version:
search_name = f"{search_name}.v{version}"

Expand All @@ -437,16 +441,29 @@ def find(

pkg_dct: Mapping[PackageName, List[Any]] = self.storage[search_name]

nodes = []
if package is None:
if not pkg_dct:
return None
else:
return next(iter(pkg_dct.values()))
nodes = next(iter(pkg_dct.values()))
elif package in pkg_dct:
return pkg_dct[package]
nodes = pkg_dct[package]
else:
return None

if resource_types is None:
return nodes
else:
new_nodes = []
for node in nodes:
if node.resource_type in resource_types:
new_nodes.append(node)
if not new_nodes:
return None
else:
return new_nodes


class AnalysisLookup(RefableLookup):
_lookup_types: ClassVar[set] = set([NodeType.Analysis])
Expand Down Expand Up @@ -1295,7 +1312,12 @@ def resolve_ref(

# it's possible that the node is disabled
if disabled is None:
disabled = self.disabled_lookup.find(target_model_name, pkg, target_model_version)
disabled = self.disabled_lookup.find(
target_model_name,
pkg,
version=target_model_version,
resource_types=REFABLE_NODE_TYPES,
)

if disabled:
return Disabled(disabled[0])
Expand Down
44 changes: 28 additions & 16 deletions core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,20 @@
from dbt_common.exceptions import DbtValidationError
from dbt_common.utils import deep_merge

schema_file_keys = (
"models",
"seeds",
"snapshots",
"sources",
"macros",
"analyses",
"exposures",
"metrics",
"semantic_models",
"saved_queries",
)
schema_file_keys_to_resource_types = {
"models": NodeType.Model,
"seeds": NodeType.Seed,
"snapshots": NodeType.Snapshot,
"sources": NodeType.Source,
"macros": NodeType.Macro,
"analyses": NodeType.Analysis,
"exposures": NodeType.Exposure,
"metrics": NodeType.Metric,
"semantic_models": NodeType.SemanticModel,
"saved_queries": NodeType.SavedQuery,
}

schema_file_keys = list(schema_file_keys_to_resource_types.keys())


# ===============================================================================
Expand Down Expand Up @@ -678,7 +680,10 @@ def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None:
# handle disabled nodes
if unique_id is None:
# Node might be disabled. Following call returns list of matching disabled nodes
found_nodes = self.manifest.disabled_lookup.find(patch.name, patch.package_name)
resource_type = schema_file_keys_to_resource_types[patch.yaml_key]
found_nodes = self.manifest.disabled_lookup.find(
patch.name, patch.package_name, resource_types=[resource_type]
)
if found_nodes:
if len(found_nodes) > 1 and patch.config.get("enabled"):
# There are multiple disabled nodes for this model and the schema file wants to enable one.
Expand Down Expand Up @@ -810,7 +815,9 @@ def parse_patch(self, block: TargetBlock[UnparsedModelUpdate], refs: ParserRef)

if versioned_model_unique_id is None:
# Node might be disabled. Following call returns list of matching disabled nodes
found_nodes = self.manifest.disabled_lookup.find(versioned_model_name, None)
found_nodes = self.manifest.disabled_lookup.find(
versioned_model_name, None, resource_types=[NodeType.Model]
)
if found_nodes:
if len(found_nodes) > 1 and target.config.get("enabled"):
# There are multiple disabled nodes for this model and the schema file wants to enable one.
Expand Down Expand Up @@ -911,6 +918,11 @@ def _target_type(self) -> Type[UnparsedModelUpdate]:

def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None:
super().patch_node_properties(node, patch)

# Remaining patch properties are only relevant to ModelNode objects
if not isinstance(node, ModelNode):
return

node.version = patch.version
node.latest_version = patch.latest_version
node.deprecation_date = patch.deprecation_date
Expand All @@ -927,7 +939,7 @@ def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None:
self.patch_time_spine(node, patch.time_spine)
node.build_contract_checksum()

def patch_constraints(self, node, constraints: List[Dict[str, Any]]) -> None:
def patch_constraints(self, node: ModelNode, constraints: List[Dict[str, Any]]) -> None:
contract_config = node.config.get("contract")
if contract_config.enforced is True:
self._validate_constraint_prerequisites(node)
Expand Down Expand Up @@ -963,7 +975,7 @@ def _process_constraints_refs_and_sources(self, model_node: ModelNode) -> None:
else:
model_node.sources.append(ref_or_source)

def patch_time_spine(self, node, time_spine: Optional[TimeSpine]) -> None:
def patch_time_spine(self, node: ModelNode, time_spine: Optional[TimeSpine]) -> None:
node.time_spine = time_spine

def _validate_pk_constraints(
Expand Down
44 changes: 44 additions & 0 deletions tests/functional/configs/test_disabled_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,47 @@ def test_conditional_model(self, project):
assert len(results) == 2
results = run_dbt(["test"])
assert len(results) == 5


my_analysis_sql = """
{{
config(enabled=False)
}}
select 1 as id
"""


schema_yml = """
models:
- name: my_analysis
description: "A Sample model"
config:
meta:
owner: Joe
analyses:
- name: my_analysis
description: "A sample analysis"
config:
enabled: false
"""


class TestDisabledConfigsSameName:
@pytest.fixture(scope="class")
def models(self):
return {
"my_analysis.sql": my_analysis_sql,
"schema.yml": schema_yml,
}

@pytest.fixture(scope="class")
def analyses(self):
return {
"my_analysis.sql": my_analysis_sql,
}

def test_disabled_analysis(self, project):
manifest = run_dbt(["parse"])
assert len(manifest.disabled) == 2
assert len(manifest.nodes) == 0
125 changes: 124 additions & 1 deletion tests/unit/contracts/graph/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
WhereFilterIntersection,
)
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import Manifest, ManifestMetadata
from dbt.contracts.graph.manifest import DisabledLookup, Manifest, ManifestMetadata
from dbt.contracts.graph.nodes import (
DependsOn,
Exposure,
Expand Down Expand Up @@ -2013,3 +2013,126 @@ def test_find_node_from_ref_or_source_invalid_expression(
):
with pytest.raises(ParsingError):
mock_manifest.find_node_from_ref_or_source(invalid_expression)


class TestDisabledLookup:
@pytest.fixture(scope="class")
def manifest(self):
return Manifest(
nodes={},
sources={},
macros={},
docs={},
disabled={},
files={},
exposures={},
selectors={},
)

@pytest.fixture(scope="class")
def mock_model(self):
return MockNode("package", "name", NodeType.Model)

@pytest.fixture(scope="class")
def mock_model_with_version(self):
return MockNode("package", "name", NodeType.Model, version=3)

@pytest.fixture(scope="class")
def mock_seed(self):
return MockNode("package", "name", NodeType.Seed)

def test_find(self, manifest, mock_model):
manifest.disabled = {"model.package.name": [mock_model]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "package") == [mock_model]

def test_find_wrong_name(self, manifest, mock_model):
manifest.disabled = {"model.package.name": [mock_model]}
lookup = DisabledLookup(manifest)

assert lookup.find("missing_name", "package") is None

def test_find_wrong_package(self, manifest, mock_model):
manifest.disabled = {"model.package.name": [mock_model]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "missing_package") is None

def test_find_wrong_version(self, manifest, mock_model):
manifest.disabled = {"model.package.name": [mock_model]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "package", version=3) is None

def test_find_wrong_resource_types(self, manifest, mock_model):
manifest.disabled = {"model.package.name": [mock_model]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "package", resource_types=[NodeType.Analysis]) is None

def test_find_no_package(self, manifest, mock_model):
manifest.disabled = {"model.package.name": [mock_model]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", None) == [mock_model]

def test_find_versioned_node(self, manifest, mock_model_with_version):
manifest.disabled = {"model.package.name": [mock_model_with_version]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "package", version=3) == [mock_model_with_version]

def test_find_versioned_node_no_package(self, manifest, mock_model_with_version):
manifest.disabled = {"model.package.name": [mock_model_with_version]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", None, version=3) == [mock_model_with_version]

def test_find_versioned_node_no_version(self, manifest, mock_model_with_version):
manifest.disabled = {"model.package.name": [mock_model_with_version]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "package") is None

def test_find_versioned_node_wrong_version(self, manifest, mock_model_with_version):
manifest.disabled = {"model.package.name": [mock_model_with_version]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "package", version=2) is None

def test_find_versioned_node_wrong_name(self, manifest, mock_model_with_version):
manifest.disabled = {"model.package.name": [mock_model_with_version]}
lookup = DisabledLookup(manifest)

assert lookup.find("wrong_name", "package", version=3) is None

def test_find_versioned_node_wrong_package(self, manifest, mock_model_with_version):
manifest.disabled = {"model.package.name": [mock_model_with_version]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "wrong_package", version=3) is None

def test_find_multiple_nodes(self, manifest, mock_model, mock_seed):
manifest.disabled = {"model.package.name": [mock_model, mock_seed]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "package") == [mock_model, mock_seed]

def test_find_multiple_nodes_with_resource_types(self, manifest, mock_model, mock_seed):
manifest.disabled = {"model.package.name": [mock_model, mock_seed]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "package", resource_types=[NodeType.Model]) == [mock_model]

def test_find_multiple_nodes_with_wrong_resource_types(self, manifest, mock_model, mock_seed):
manifest.disabled = {"model.package.name": [mock_model, mock_seed]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "package", resource_types=[NodeType.Analysis]) is None

def test_find_multiple_nodes_with_resource_types_empty(self, manifest, mock_model, mock_seed):
manifest.disabled = {"model.package.name": [mock_model, mock_seed]}
lookup = DisabledLookup(manifest)

assert lookup.find("name", "package", resource_types=[]) is None

0 comments on commit ff2726c

Please sign in to comment.