Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move deferral resolution from merge_from_artifact to RuntimeRefResolver #9199

Merged
merged 10 commits into from
May 2, 2024
7 changes: 7 additions & 0 deletions .changes/unreleased/Under the Hood-20240201-003033.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Under the Hood
body: Split up deferral across parsing (adding 'defer_relation' from state manifest)
and runtime ref resolution
time: 2024-02-01T00:30:33.573665+01:00
custom:
Author: jtcohen6
Issue: "9199"
1 change: 0 additions & 1 deletion core/dbt/artifacts/resources/v1/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ class ParsedResource(ParsedResourceMandatory):
docs: Docs = field(default_factory=Docs)
patch_path: Optional[str] = None
build_path: Optional[str] = None
deferred: bool = False
Copy link
Contributor

@MichelleArk MichelleArk May 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirming this is a minor change and safe to make here from a runtime serialization/deserialization perspective 👍 Independent unit tests here: #10066

Additionally, we're not using this field for anything functional in dbt-core's implementation; it was previously just accessed in testing to assert that certain behaviour had occurred during deferral.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving minor schema evolution changes

unrendered_config: Dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=lambda: time.time())
config_call_dict: Dict[str, Any] = field(default_factory=dict)
Expand Down
2 changes: 0 additions & 2 deletions core/dbt/artifacts/schemas/manifest/v12/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,4 @@ def __post_serialize__(self, dct):
for unique_id, node in dct["nodes"].items():
if "config_call_dict" in node:
del node["config_call_dict"]
if "defer_relation" in node:
del node["defer_relation"]
return dct
2 changes: 1 addition & 1 deletion core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def wrapper(*args, **kwargs):

runtime_config = ctx.obj["runtime_config"]

# a manifest has already been set on the context, so don't overwrite it
# if a manifest has already been set on the context, don't overwrite it
if ctx.obj.get("manifest") is None:
ctx.obj["manifest"] = parse_manifest(
runtime_config, write_perf_info, write, ctx.obj["flags"].write_json
Expand Down
21 changes: 20 additions & 1 deletion core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def resolve(
self.model.package_name,
)

# Raise an error if the reference target is missing
if target_model is None or isinstance(target_model, Disabled):
raise TargetNotFoundError(
node=self.model,
Expand All @@ -513,6 +514,8 @@ def resolve(
target_version=target_version,
disabled=isinstance(target_model, Disabled),
)

# Raise error if trying to reference a 'private' resource outside its 'group'
elif self.manifest.is_invalid_private_ref(
self.model, target_model, self.config.dependencies
):
Expand All @@ -522,6 +525,7 @@ def resolve(
access=AccessType.Private,
scope=cast_to_str(target_model.group),
)
# Or a 'protected' resource outside its project/package namespace
elif self.manifest.is_invalid_protected_ref(
self.model, target_model, self.config.dependencies
):
Expand All @@ -531,14 +535,29 @@ def resolve(
access=AccessType.Protected,
scope=target_model.package_name,
)

self.validate(target_model, target_name, target_package, target_version)
return self.create_relation(target_model)

def create_relation(self, target_model: ManifestNode) -> RelationProxy:
if target_model.is_ephemeral_model:
self.model.set_cte(target_model.unique_id, None)
return self.Relation.create_ephemeral_from(target_model, limit=self.resolve_limit)
elif (
hasattr(target_model, "defer_relation")
and target_model.defer_relation
and self.config.args.defer
and (
# User has explicitly opted to prefer defer_relation
self.config.args.favor_state
# Or, this node's relation does not exist in the expected target location (cache lookup)
or not get_adapter(self.config).get_relation(
target_model.database, target_model.schema, target_model.identifier
)
)
):
return self.Relation.create_from(
self.config, target_model.defer_relation, limit=self.resolve_limit
)
else:
return self.Relation.create_from(self.config, target_model, limit=self.resolve_limit)

Expand Down
42 changes: 8 additions & 34 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
from collections import defaultdict
from dataclasses import dataclass, field, replace
from itertools import chain, islice
from itertools import chain
from mashumaro.mixins.msgpack import DataClassMessagePackMixin
from multiprocessing.synchronize import Lock
from typing import (
Expand All @@ -18,7 +18,6 @@
TypeVar,
Callable,
Generic,
AbstractSet,
ClassVar,
)
from typing_extensions import Protocol
Expand Down Expand Up @@ -74,7 +73,7 @@
from dbt_common.helper_types import PathSet
from dbt_common.events.functions import fire_event
from dbt_common.events.contextvars import get_node_info
from dbt.events.types import MergedFromState, UnpinnedRefNewVersionAvailable
from dbt.events.types import UnpinnedRefNewVersionAvailable
from dbt.node_types import NodeType, AccessType, REFABLE_NODE_TYPES, VERSIONED_NODE_TYPES
from dbt.mp_context import get_mp_context
import dbt_common.utils
Expand Down Expand Up @@ -1466,50 +1465,25 @@ def is_invalid_protected_ref(
node.package_name != target_model.package_name and restrict_package_access
)

# Called by GraphRunnableTask.defer_to_manifest
def merge_from_artifact(
self,
adapter,
other: "Manifest",
selected: AbstractSet[UniqueID],
favor_state: bool = False,
) -> None:
"""Given the selected unique IDs and a writable manifest, update this
manifest by replacing any unselected nodes with their counterpart.
# Called by requires.manifest after ManifestLoader.get_full_manifest
def merge_from_artifact(self, other: "Manifest") -> None:
"""Update this manifest by adding the 'defer_relation' attribute to all nodes
with a counterpart in the stateful manifest used for deferral.

Only non-ephemeral refable nodes are examined.
"""
refables = set(REFABLE_NODE_TYPES)
merged = set()
for unique_id, node in other.nodes.items():
current = self.nodes.get(unique_id)
if current and (
node.resource_type in refables
and not node.is_ephemeral
and unique_id not in selected
and (
not adapter.get_relation(current.database, current.schema, current.identifier)
or favor_state
)
):
merged.add(unique_id)
self.nodes[unique_id] = replace(node, deferred=True)

# for all other nodes, add 'defer_relation'
elif current and node.resource_type in refables and not node.is_ephemeral:
if current and node.resource_type in refables and not node.is_ephemeral:
defer_relation = DeferRelation(
node.database, node.schema, node.alias, node.relation_name
)
self.nodes[unique_id] = replace(current, defer_relation=defer_relation)

# Rebuild the flat_graph, which powers the 'graph' context variable,
# now that we've deferred some nodes
# Rebuild the flat_graph, which powers the 'graph' context variable
self.build_flat_graph()

# log up to 5 items
sample = list(islice(merged, 5))
fire_event(MergedFromState(num_merged=len(merged), sample=sample))

# Methods that were formerly in ParseResult
def add_macro(self, source_file: SourceFile, macro: Macro):
if macro.unique_id in self.macros:
Expand Down
7 changes: 1 addition & 6 deletions core/dbt/events/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ def message(self) -> str:
return f"Tracking: {self.user_state}"


class MergedFromState(DebugLevel):
def code(self) -> str:
return "A004"

def message(self) -> str:
return f"Merged {self.num_merged} items from state (sample: {self.sample})"
# Removed A004: MergedFromState
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're no longer fully overwriting deferred manifest nodes with the stateful manifest's alternative, I didn't see much use in keeping this event around.



class MissingProfileTarget(InfoLevel):
Expand Down
30 changes: 29 additions & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import field
import datetime
import os
from pathlib import Path
import traceback
from typing import (
Dict,
Expand All @@ -21,6 +22,7 @@

from dbt.context.query_header import generate_query_header_context
from dbt.contracts.graph.semantic_manifest import SemanticManifest
from dbt.contracts.state import PreviousState
from dbt_common.events.base_types import EventLevel
from dbt_common.exceptions.base import DbtValidationError
import dbt_common.utils
Expand Down Expand Up @@ -117,6 +119,7 @@
TargetNotFoundError,
AmbiguousAliasError,
InvalidAccessTypeError,
DbtRuntimeError,
scrub_secrets,
)
from dbt.parser.base import Parser
Expand Down Expand Up @@ -1886,7 +1889,12 @@ def write_manifest(manifest: Manifest, target_path: str, which: Optional[str] =
write_semantic_manifest(manifest=manifest, target_path=target_path)


def parse_manifest(runtime_config, write_perf_info, write, write_json):
def parse_manifest(
runtime_config: RuntimeConfig,
write_perf_info: bool,
write: bool,
write_json: bool,
) -> Manifest:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before 6b63e31, this was where I was calling merge_from_artifact. Now this file diff is just adding type annotations and an inline comment.

register_adapter(runtime_config, get_mp_context())
adapter = get_adapter(runtime_config)
adapter.set_macro_context_generator(generate_runtime_macro_context)
Expand All @@ -1895,6 +1903,26 @@ def parse_manifest(runtime_config, write_perf_info, write, write_json):
write_perf_info=write_perf_info,
)

# If deferral is enabled, add 'defer_relation' attribute to all nodes
flags = get_flags()
if flags.defer:
defer_state_path = flags.defer_state or flags.state
if not defer_state_path:
raise DbtRuntimeError(
"Deferral is enabled and requires a stateful manifest, but none was provided"
)
defer_state = PreviousState(
state_path=defer_state_path,
target_path=Path(runtime_config.target_path),
project_root=Path(runtime_config.project_root),
)
if not defer_state.manifest:
raise DbtRuntimeError(
f'Could not find manifest in deferral state path: "{defer_state_path}"'
)
manifest.merge_from_artifact(defer_state.manifest)

# If we should (over)write the manifest in the target path, do that now
if write and write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = plugins.get_plugin_manager(runtime_config.project_name)
Expand Down
5 changes: 4 additions & 1 deletion core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
NodeType.Seed,
NodeType.Snapshot,
):
input_node = ModelNode(**common_fields)
input_node = ModelNode(
**common_fields,
defer_relation=original_input_node.defer_relation,
)
if (
original_input_node.resource_type == NodeType.Model
and original_input_node.version
Expand Down
10 changes: 1 addition & 9 deletions core/dbt/task/clone.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import threading
from typing import AbstractSet, Any, List, Iterable, Set, Optional
from typing import AbstractSet, Any, List, Iterable, Set

from dbt.adapters.base import BaseRelation
from dbt.clients.jinja import MacroGenerator
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.graph.manifest import Manifest
from dbt.artifacts.schemas.run import RunStatus, RunResult
from dbt_common.dataclass_schema import dbtClassMixin
from dbt_common.exceptions import DbtInternalError, CompilationError
Expand Down Expand Up @@ -94,11 +93,6 @@ class CloneTask(GraphRunnableTask):
def raise_on_first_error(self):
return False

def _get_deferred_manifest(self) -> Optional[Manifest]:
# Unlike other commands, 'clone' always requires a state manifest
# Load previous state, regardless of whether --defer flag has been set
return self._get_previous_state()

def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRelation]:
if self.manifest is None:
raise DbtInternalError("manifest was None in get_model_schemas")
Expand All @@ -122,8 +116,6 @@ def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRe

def before_run(self, adapter, selected_uids: AbstractSet[str]):
with adapter.connection_named("master"):
# unlike in other tasks, we want to add information from the --state manifest *before* caching!
self.defer_to_manifest(adapter, selected_uids)
# only create *our* schemas, but cache *other* schemas in addition
schemas_to_create = super().get_model_schemas(adapter, selected_uids)
self.create_schemas(adapter, schemas_to_create)
Expand Down
1 change: 0 additions & 1 deletion core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,6 @@ def before_run(self, adapter, selected_uids: AbstractSet[str]) -> None:
required_schemas = self.get_model_schemas(adapter, selected_uids)
self.create_schemas(adapter, required_schemas)
self.populate_adapter_cache(adapter, required_schemas)
self.defer_to_manifest(adapter, selected_uids)
self.safe_run_hooks(adapter, RunHookType.Start, {})

def after_run(self, adapter, results) -> None:
Expand Down
44 changes: 2 additions & 42 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ def __init__(self, args: Flags, config: RuntimeConfig, manifest: Manifest) -> No
self.job_queue: Optional[GraphQueue] = None
self.node_results: List[BaseResult] = []
self.num_nodes: int = 0
# TODO: if --defer is enabled, we have already loaded the "previous state" artifacts into memory
# can we check to see, and reuse them if so?
self.previous_state: Optional[PreviousState] = None
self.previous_defer_state: Optional[PreviousState] = None
self.run_count: int = 0
self.started_at: float = 0

Expand All @@ -88,13 +89,6 @@ def __init__(self, args: Flags, config: RuntimeConfig, manifest: Manifest) -> No
project_root=Path(self.config.project_root),
)

if self.args.defer_state:
self.previous_defer_state = PreviousState(
state_path=self.args.defer_state,
target_path=Path(self.config.target_path),
project_root=Path(self.config.project_root),
)

def index_offset(self, value: int) -> int:
return value

Expand Down Expand Up @@ -127,25 +121,6 @@ def get_selection_spec(self) -> SelectionSpec:
def get_node_selector(self) -> NodeSelector:
raise NotImplementedError(f"get_node_selector not implemented for task {type(self)}")

def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]):
deferred_manifest = self._get_deferred_manifest()
if deferred_manifest is None:
return
if self.manifest is None:
raise DbtInternalError(
"Expected to defer to manifest, but there is no runtime manifest to defer from!"
)
self.manifest.merge_from_artifact(
adapter=adapter,
other=deferred_manifest,
selected=selected_uids,
favor_state=bool(self.args.favor_state),
)
# We're rewriting the manifest because it's been mutated during merge_from_artifact.
# This is to reflect which nodes had been deferred to (= replaced with) their counterparts.
if self.args.write_json:
write_manifest(self.manifest, self.config.project_target_path)

def get_graph_queue(self) -> GraphQueue:
selector = self.get_node_selector()
# Following uses self.selection_arg and self.exclusion_arg
Expand Down Expand Up @@ -480,7 +455,6 @@ def populate_adapter_cache(
def before_run(self, adapter, selected_uids: AbstractSet[str]):
with adapter.connection_named("master"):
self.populate_adapter_cache(adapter)
self.defer_to_manifest(adapter, selected_uids)

def after_run(self, adapter, results):
pass
Expand Down Expand Up @@ -666,17 +640,3 @@ def get_result(self, results, elapsed_time, generated_at):

def task_end_messages(self, results):
print_run_end_messages(results)

def _get_previous_state(self) -> Optional[Manifest]:
state = self.previous_defer_state or self.previous_state
if not state:
raise DbtRuntimeError(
"--state or --defer-state are required for deferral, but neither was provided"
)

if not state.manifest:
raise DbtRuntimeError(f'Could not find manifest in --state path: "{state.state_path}"')
return state.manifest

def _get_deferred_manifest(self) -> Optional[Manifest]:
return self._get_previous_state() if self.args.defer else None
Loading
Loading