Skip to content

Commit

Permalink
Move deferral resolution from merge_from_artifact to `RuntimeRefResolver` (#9199)
Browse files Browse the repository at this point in the history

* Move deferral from task to manifest loading + RefResolver

* dbt clone must specify --defer

* Fix deferral for unit test type detection

* Add changelog

* Move merge_from_artifact from end of parsing back to task before_run to reduce scope of refactor

* PR review. DeferRelation conforms to RelationConfig protocol

* Add test case for #10017

* Update manifest v12 in test_previous_version_state

---------

Co-authored-by: Michelle Ark <[email protected]>
  • Loading branch information
jtcohen6 and MichelleArk authored May 2, 2024
1 parent f884eb4 commit 487a532
Show file tree
Hide file tree
Showing 25 changed files with 130 additions and 122 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Under the Hood-20240201-003033.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Under the Hood
body: Split up deferral across parsing (adding 'defer_relation' from state manifest)
and runtime ref resolution
time: 2024-02-01T00:30:33.573665+01:00
custom:
Author: jtcohen6
Issue: "9199"
10 changes: 9 additions & 1 deletion core/dbt/artifacts/resources/v1/components.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
from dataclasses import dataclass, field
from dbt.artifacts.resources.base import GraphResource, FileHash, Docs
from dbt.artifacts.resources.types import NodeType
from dbt.artifacts.resources.v1.config import NodeConfig
from dbt_common.dataclass_schema import dbtClassMixin, ExtensibleDbtClassMixin
from dbt_common.contracts.config.properties import AdditionalPropertiesMixin
Expand Down Expand Up @@ -154,6 +155,14 @@ def quoting_dict(self) -> Dict[str, bool]:
class DeferRelation(HasRelationMetadata):
alias: str
relation_name: Optional[str]
# The rest of these fields match RelationConfig protocol exactly
resource_type: NodeType
name: str
description: str
compiled_code: Optional[str]
meta: Dict[str, Any]
tags: List[str]
config: Optional[NodeConfig]

@property
def identifier(self):
Expand Down Expand Up @@ -181,7 +190,6 @@ class ParsedResource(ParsedResourceMandatory):
docs: Docs = field(default_factory=Docs)
patch_path: Optional[str] = None
build_path: Optional[str] = None
deferred: bool = False
unrendered_config: Dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=lambda: time.time())
config_call_dict: Dict[str, Any] = field(default_factory=dict)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def wrapper(*args, **kwargs):

runtime_config = ctx.obj["runtime_config"]

# a manifest has already been set on the context, so don't overwrite it
# if a manifest has already been set on the context, don't overwrite it
if ctx.obj.get("manifest") is None:
ctx.obj["manifest"] = parse_manifest(
runtime_config, write_perf_info, write, ctx.obj["flags"].write_json
Expand Down
21 changes: 20 additions & 1 deletion core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def resolve(
self.model.package_name,
)

# Raise an error if the reference target is missing
if target_model is None or isinstance(target_model, Disabled):
raise TargetNotFoundError(
node=self.model,
Expand All @@ -513,6 +514,8 @@ def resolve(
target_version=target_version,
disabled=isinstance(target_model, Disabled),
)

# Raise error if trying to reference a 'private' resource outside its 'group'
elif self.manifest.is_invalid_private_ref(
self.model, target_model, self.config.dependencies
):
Expand All @@ -522,6 +525,7 @@ def resolve(
access=AccessType.Private,
scope=cast_to_str(target_model.group),
)
# Or a 'protected' resource outside its project/package namespace
elif self.manifest.is_invalid_protected_ref(
self.model, target_model, self.config.dependencies
):
Expand All @@ -531,14 +535,29 @@ def resolve(
access=AccessType.Protected,
scope=target_model.package_name,
)

self.validate(target_model, target_name, target_package, target_version)
return self.create_relation(target_model)

def create_relation(self, target_model: ManifestNode) -> RelationProxy:
    """Return the relation that a ref to ``target_model`` should resolve to.

    An ephemeral target is registered as a CTE on the referencing node and
    returned as an ephemeral relation. Otherwise the target's
    ``defer_relation`` (when it has a truthy one and ``args.defer`` is set)
    is used if ``args.favor_state`` is set or the adapter's ``get_relation``
    lookup for the target's own location returns nothing truthy. In every
    remaining case the relation is built from ``target_model`` itself.
    """
    if target_model.is_ephemeral_model:
        self.model.set_cte(target_model.unique_id, None)
        return self.Relation.create_ephemeral_from(target_model, limit=self.resolve_limit)

    relation_source = target_model
    # Not every node type carries 'defer_relation', hence the defaulted lookup.
    deferred = getattr(target_model, "defer_relation", None)
    if deferred and self.config.args.defer:
        # 'favor_state' is checked first so the adapter lookup is skipped
        # whenever the user has explicitly opted to prefer the deferred relation.
        use_deferred = self.config.args.favor_state or not get_adapter(
            self.config
        ).get_relation(target_model.database, target_model.schema, target_model.identifier)
        if use_deferred:
            relation_source = deferred

    return self.Relation.create_from(self.config, relation_source, limit=self.resolve_limit)

Expand Down
57 changes: 22 additions & 35 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
from collections import defaultdict
from dataclasses import dataclass, field, replace
from itertools import chain, islice
from itertools import chain
from mashumaro.mixins.msgpack import DataClassMessagePackMixin
from multiprocessing.synchronize import Lock
from typing import (
Expand All @@ -18,7 +18,6 @@
TypeVar,
Callable,
Generic,
AbstractSet,
ClassVar,
)
from typing_extensions import Protocol
Expand All @@ -39,6 +38,7 @@
ResultNode,
SavedQuery,
SemanticModel,
SeedNode,
SourceDefinition,
UnpatchedSourceDefinition,
UnitTestDefinition,
Expand All @@ -54,6 +54,7 @@
DeferRelation,
BaseResource,
)
from dbt.artifacts.resources.v1.config import NodeConfig
from dbt.artifacts.schemas.manifest import WritableManifest, ManifestMetadata, UniqueID
from dbt.contracts.files import (
SourceFile,
Expand All @@ -74,7 +75,7 @@
from dbt_common.helper_types import PathSet
from dbt_common.events.functions import fire_event
from dbt_common.events.contextvars import get_node_info
from dbt.events.types import MergedFromState, UnpinnedRefNewVersionAvailable
from dbt.events.types import UnpinnedRefNewVersionAvailable
from dbt.node_types import NodeType, AccessType, REFABLE_NODE_TYPES, VERSIONED_NODE_TYPES
from dbt.mp_context import get_mp_context
import dbt_common.utils
Expand Down Expand Up @@ -1466,50 +1467,36 @@ def is_invalid_protected_ref(
node.package_name != target_model.package_name and restrict_package_access
)

# Called by GraphRunnableTask.defer_to_manifest
def merge_from_artifact(
self,
adapter,
other: "Manifest",
selected: AbstractSet[UniqueID],
favor_state: bool = False,
) -> None:
"""Given the selected unique IDs and a writable manifest, update this
manifest by replacing any unselected nodes with their counterpart.
# Called in GraphRunnableTask.before_run, RunTask.before_run, CloneTask.before_run
def merge_from_artifact(self, other: "Manifest") -> None:
"""Update this manifest by adding the 'defer_relation' attribute to all nodes
with a counterpart in the stateful manifest used for deferral.
Only non-ephemeral refable nodes are examined.
"""
refables = set(REFABLE_NODE_TYPES)
merged = set()
for unique_id, node in other.nodes.items():
current = self.nodes.get(unique_id)
if current and (
node.resource_type in refables
and not node.is_ephemeral
and unique_id not in selected
and (
not adapter.get_relation(current.database, current.schema, current.identifier)
or favor_state
)
):
merged.add(unique_id)
self.nodes[unique_id] = replace(node, deferred=True)

# for all other nodes, add 'defer_relation'
elif current and node.resource_type in refables and not node.is_ephemeral:
if current and node.resource_type in refables and not node.is_ephemeral:
assert isinstance(node.config, NodeConfig) # this makes mypy happy
defer_relation = DeferRelation(
node.database, node.schema, node.alias, node.relation_name
database=node.database,
schema=node.schema,
alias=node.alias,
relation_name=node.relation_name,
resource_type=node.resource_type,
name=node.name,
description=node.description,
compiled_code=(node.compiled_code if not isinstance(node, SeedNode) else None),
meta=node.meta,
tags=node.tags,
config=node.config,
)
self.nodes[unique_id] = replace(current, defer_relation=defer_relation)

# Rebuild the flat_graph, which powers the 'graph' context variable,
# now that we've deferred some nodes
# Rebuild the flat_graph, which powers the 'graph' context variable
self.build_flat_graph()

# log up to 5 items
sample = list(islice(merged, 5))
fire_event(MergedFromState(num_merged=len(merged), sample=sample))

# Methods that were formerly in ParseResult
def add_macro(self, source_file: SourceFile, macro: Macro):
if macro.unique_id in self.macros:
Expand Down
5 changes: 5 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,11 @@ def language(self):
return "sql"


# @property
# def compiled_code(self):
# return None


# ====================================
# Singular Test node
# ====================================
Expand Down
7 changes: 1 addition & 6 deletions core/dbt/events/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ def message(self) -> str:
return f"Tracking: {self.user_state}"


class MergedFromState(DebugLevel):
def code(self) -> str:
return "A004"

def message(self) -> str:
return f"Merged {self.num_merged} items from state (sample: {self.sample})"
# Removed A004: MergedFromState


class MissingProfileTarget(InfoLevel):
Expand Down
8 changes: 7 additions & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,7 +1886,12 @@ def write_manifest(manifest: Manifest, target_path: str, which: Optional[str] =
write_semantic_manifest(manifest=manifest, target_path=target_path)


def parse_manifest(runtime_config, write_perf_info, write, write_json):
def parse_manifest(
runtime_config: RuntimeConfig,
write_perf_info: bool,
write: bool,
write_json: bool,
) -> Manifest:
register_adapter(runtime_config, get_mp_context())
adapter = get_adapter(runtime_config)
adapter.set_macro_context_generator(generate_runtime_macro_context)
Expand All @@ -1895,6 +1900,7 @@ def parse_manifest(runtime_config, write_perf_info, write, write_json):
write_perf_info=write_perf_info,
)

# If we should (over)write the manifest in the target path, do that now
if write and write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = plugins.get_plugin_manager(runtime_config.project_name)
Expand Down
5 changes: 4 additions & 1 deletion core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
NodeType.Seed,
NodeType.Snapshot,
):
input_node = ModelNode(**common_fields)
input_node = ModelNode(
**common_fields,
defer_relation=original_input_node.defer_relation,
)
if (
original_input_node.resource_type == NodeType.Model
and original_input_node.version
Expand Down
5 changes: 2 additions & 3 deletions core/dbt/task/clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,8 @@ def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRe

def before_run(self, adapter, selected_uids: AbstractSet[str]):
with adapter.connection_named("master"):
# unlike in other tasks, we want to add information from the --state manifest *before* caching!
self.defer_to_manifest(adapter, selected_uids)
# only create *our* schemas, but cache *other* schemas in addition
self.defer_to_manifest()
# only create target schemas, but also cache defer_relation schemas
schemas_to_create = super().get_model_schemas(adapter, selected_uids)
self.create_schemas(adapter, schemas_to_create)
schemas_to_cache = self.get_model_schemas(adapter, selected_uids)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,10 +454,10 @@ def print_results_line(self, results, execution_time) -> None:

def before_run(self, adapter, selected_uids: AbstractSet[str]) -> None:
with adapter.connection_named("master"):
self.defer_to_manifest()
required_schemas = self.get_model_schemas(adapter, selected_uids)
self.create_schemas(adapter, required_schemas)
self.populate_adapter_cache(adapter, required_schemas)
self.defer_to_manifest(adapter, selected_uids)
self.safe_run_hooks(adapter, RunHookType.Start, {})

def after_run(self, adapter, results) -> None:
Expand Down
15 changes: 3 additions & 12 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,24 +127,15 @@ def get_selection_spec(self) -> SelectionSpec:
def get_node_selector(self) -> NodeSelector:
raise NotImplementedError(f"get_node_selector not implemented for task {type(self)}")

def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]):
def defer_to_manifest(self):
deferred_manifest = self._get_deferred_manifest()
if deferred_manifest is None:
return
if self.manifest is None:
raise DbtInternalError(
"Expected to defer to manifest, but there is no runtime manifest to defer from!"
)
self.manifest.merge_from_artifact(
adapter=adapter,
other=deferred_manifest,
selected=selected_uids,
favor_state=bool(self.args.favor_state),
)
# We're rewriting the manifest because it's been mutated during merge_from_artifact.
# This is to reflect which nodes had been deferred to (= replaced with) their counterparts.
if self.args.write_json:
write_manifest(self.manifest, self.config.project_target_path)
self.manifest.merge_from_artifact(other=deferred_manifest)

def get_graph_queue(self) -> GraphQueue:
selector = self.get_node_selector()
Expand Down Expand Up @@ -479,8 +470,8 @@ def populate_adapter_cache(

def before_run(self, adapter, selected_uids: AbstractSet[str]):
with adapter.connection_named("master"):
self.defer_to_manifest()
self.populate_adapter_cache(adapter)
self.defer_to_manifest(adapter, selected_uids)

def after_run(self, adapter, results):
pass
Expand Down
4 changes: 1 addition & 3 deletions tests/functional/adapter/dbt_clone/test_dbt_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,14 @@ def copy_state(self, project_root):
def run_and_save_state(self, project_root, with_snapshot=False):
results = run_dbt(["seed"])
assert len(results) == 1
assert not any(r.node.deferred for r in results)
results = run_dbt(["run"])
assert len(results) == 2
assert not any(r.node.deferred for r in results)
results = run_dbt(["test"])
assert len(results) == 2

if with_snapshot:
results = run_dbt(["snapshot"])
assert len(results) == 1
assert not any(r.node.deferred for r in results)

# copy files
self.copy_state(project_root)
Expand Down Expand Up @@ -226,6 +223,7 @@ def test_clone_same_target_and_state(self, project, unique_schema, other_schema)

clone_args = [
"clone",
"--defer",
"--state",
"target",
]
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/artifacts/data/state/v12/manifest.json

Large diffs are not rendered by default.

Loading

0 comments on commit 487a532

Please sign in to comment.