Skip to content

Commit

Permalink
Move deferral from task to manifest loading + RefResolver
Browse files Browse the repository at this point in the history
  • Loading branch information
jtcohen6 committed Dec 4, 2023
1 parent ed0c432 commit fa16a8c
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 105 deletions.
30 changes: 28 additions & 2 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile, UnsetProfile
from dbt.contracts.state import PreviousState
from dbt.events.base_types import EventLevel
from dbt.events.functions import fire_event, LOG_VERSION, set_invocation_id, setup_event_logger
from dbt.events.types import (
Expand All @@ -20,7 +21,12 @@
)
from dbt.events.helpers import get_json_string_utcnow
from dbt.events.types import MainEncounteredError, MainStackTrace
from dbt.exceptions import Exception as DbtException, DbtProjectError, FailFastError
from dbt.exceptions import (
Exception as DbtException,
DbtProjectError,
DbtRuntimeError,
FailFastError,
)
from dbt.parser.manifest import ManifestLoader, write_manifest
from dbt.profiler import profiler
from dbt.tracking import active_user, initialize_from_flags, track_run
Expand All @@ -30,6 +36,7 @@
from click import Context
from functools import update_wrapper
import importlib.util
from pathlib import Path
import time
import traceback

Expand Down Expand Up @@ -265,6 +272,7 @@ def wrapper(*args, **kwargs):

runtime_config = ctx.obj["runtime_config"]
register_adapter(runtime_config)
flags: Flags = ctx.obj["flags"]

# a manifest has already been set on the context, so don't overwrite it
if ctx.obj.get("manifest") is None:
Expand All @@ -273,8 +281,26 @@ def wrapper(*args, **kwargs):
write_perf_info=write_perf_info,
)

# If deferral is enabled, add 'defer_relation' attribute to all nodes
if flags.defer:
defer_state = flags.defer_state or flags.state
if not defer_state:
raise DbtRuntimeError(
"Deferral is enabled and requires a stateful manifest, but none was provided"
)
previous_state = PreviousState(
state_path=defer_state,
target_path=Path(runtime_config.target_path),
project_root=Path(runtime_config.project_root),
)
if not previous_state.manifest:
raise DbtRuntimeError(
f'Could not find manifest in deferral state path: "{defer_state}"'
)
manifest.merge_from_artifact(previous_state.manifest)

ctx.obj["manifest"] = manifest
if write and ctx.obj["flags"].write_json:
if write and flags.write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = get_plugin_manager(runtime_config.project_name)
plugin_artifacts = pm.get_manifest_artifacts(manifest)
Expand Down
21 changes: 20 additions & 1 deletion core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def resolve(
self.model.package_name,
)

# Raise an error if the reference target is missing
if target_model is None or isinstance(target_model, Disabled):
raise TargetNotFoundError(
node=self.model,
Expand All @@ -510,6 +511,8 @@ def resolve(
target_version=target_version,
disabled=isinstance(target_model, Disabled),
)

# Raise error if trying to reference a 'private' resource outside its 'group'
elif self.manifest.is_invalid_private_ref(
self.model, target_model, self.config.dependencies
):
Expand All @@ -519,6 +522,7 @@ def resolve(
access=AccessType.Private,
scope=cast_to_str(target_model.group),
)
# Or a 'protected' resource outside its project/package namespace
elif self.manifest.is_invalid_protected_ref(
self.model, target_model, self.config.dependencies
):
Expand All @@ -528,7 +532,6 @@ def resolve(
access=AccessType.Protected,
scope=target_model.package_name,
)

self.validate(target_model, target_name, target_package, target_version)
return self.create_relation(target_model)

Expand All @@ -538,6 +541,22 @@ def create_relation(self, target_model: ManifestNode) -> RelationProxy:
return self.Relation.create_ephemeral_from_node(
self.config, target_model, limit=self.resolve_limit
)
elif (
hasattr(target_model, "defer_relation")
and target_model.defer_relation
and self.config.args.defer
and (
# User has explicitly opted to prefer defer_relation
self.config.args.favor_state
# Or, this node's relation does not exist in the expected target location (cache lookup)
or not get_adapter(self.config).get_relation(
target_model.database, target_model.schema, target_model.identifier
)
)
):
return self.Relation.create_from_node(
self.config, target_model.defer_relation, limit=self.resolve_limit
)
else:
return self.Relation.create_from(self.config, target_model, limit=self.resolve_limit)

Expand Down
44 changes: 8 additions & 36 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
from collections import defaultdict
from dataclasses import dataclass, field
from itertools import chain, islice
from itertools import chain
from mashumaro.mixins.msgpack import DataClassMessagePackMixin
from multiprocessing.synchronize import Lock
from typing import (
Expand All @@ -18,7 +18,6 @@
TypeVar,
Callable,
Generic,
AbstractSet,
ClassVar,
Iterable,
)
Expand Down Expand Up @@ -63,7 +62,7 @@
)
from dbt.helper_types import PathSet
from dbt.events.functions import fire_event
from dbt.events.types import MergedFromState, UnpinnedRefNewVersionAvailable
from dbt.events.types import UnpinnedRefNewVersionAvailable
from dbt.events.contextvars import get_node_info
from dbt.node_types import NodeType, AccessType
from dbt.flags import get_flags, MP_CONTEXT
Expand Down Expand Up @@ -1340,50 +1339,25 @@ def is_invalid_protected_ref(
node.package_name != target_model.package_name and restrict_package_access
)

# Called by GraphRunnableTask.defer_to_manifest
def merge_from_artifact(
self,
adapter,
other: "WritableManifest",
selected: AbstractSet[UniqueID],
favor_state: bool = False,
) -> None:
"""Given the selected unique IDs and a writable manifest, update this
manifest by replacing any unselected nodes with their counterpart.
# Called by requires.manifest after ManifestLoader.get_full_manifest
def merge_from_artifact(self, other: "WritableManifest") -> None:
"""Update this manifest by adding the 'defer_relation' attribute to all nodes
with a counterpart in the stateful manifest used for deferral.
Only non-ephemeral refable nodes are examined.
"""
refables = set(NodeType.refable())
merged = set()
for unique_id, node in other.nodes.items():
current = self.nodes.get(unique_id)
if current and (
node.resource_type in refables
and not node.is_ephemeral
and unique_id not in selected
and (
not adapter.get_relation(current.database, current.schema, current.identifier)
or favor_state
)
):
merged.add(unique_id)
self.nodes[unique_id] = node.replace(deferred=True)

# for all other nodes, add 'defer_relation'
elif current and node.resource_type in refables and not node.is_ephemeral:
if current and node.resource_type in refables and not node.is_ephemeral:
defer_relation = DeferRelation(
node.database, node.schema, node.alias, node.relation_name
)
self.nodes[unique_id] = current.replace(defer_relation=defer_relation)

# Rebuild the flat_graph, which powers the 'graph' context variable,
# now that we've deferred some nodes
# Rebuild the flat_graph, which powers the 'graph' context variable
self.build_flat_graph()

# log up to 5 items
sample = list(islice(merged, 5))
fire_event(MergedFromState(num_merged=len(merged), sample=sample))

# Methods that were formerly in ParseResult

def add_macro(self, source_file: SourceFile, macro: Macro):
Expand Down Expand Up @@ -1621,8 +1595,6 @@ def __post_serialize__(self, dct):
for unique_id, node in dct["nodes"].items():
if "config_call_dict" in node:
del node["config_call_dict"]
if "defer_relation" in node:
del node["defer_relation"]
return dct


Expand Down
7 changes: 1 addition & 6 deletions core/dbt/events/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,7 @@ def message(self) -> str:
return f"Tracking: {self.user_state}"


class MergedFromState(DebugLevel):
def code(self) -> str:
return "A004"

def message(self) -> str:
return f"Merged {self.num_merged} items from state (sample: {self.sample})"
# Removed A004: MergedFromState


class MissingProfileTarget(InfoLevel):
Expand Down
10 changes: 1 addition & 9 deletions core/dbt/task/clone.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import threading
from typing import AbstractSet, Any, List, Iterable, Set, Optional
from typing import AbstractSet, Any, List, Iterable, Set

from dbt.adapters.base import BaseRelation
from dbt.clients.jinja import MacroGenerator
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.graph.manifest import WritableManifest
from dbt.contracts.results import RunStatus, RunResult
from dbt.dataclass_schema import dbtClassMixin
from dbt.exceptions import DbtInternalError, CompilationError
Expand Down Expand Up @@ -94,11 +93,6 @@ class CloneTask(GraphRunnableTask):
def raise_on_first_error(self):
return False

def _get_deferred_manifest(self) -> Optional[WritableManifest]:
# Unlike other commands, 'clone' always requires a state manifest
# Load previous state, regardless of whether --defer flag has been set
return self._get_previous_state()

def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRelation]:
if self.manifest is None:
raise DbtInternalError("manifest was None in get_model_schemas")
Expand All @@ -122,8 +116,6 @@ def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRe

def before_run(self, adapter, selected_uids: AbstractSet[str]):
with adapter.connection_named("master"):
# unlike in other tasks, we want to add information from the --state manifest *before* caching!
self.defer_to_manifest(adapter, selected_uids)
# only create *our* schemas, but cache *other* schemas in addition
schemas_to_create = super().get_model_schemas(adapter, selected_uids)
self.create_schemas(adapter, schemas_to_create)
Expand Down
1 change: 0 additions & 1 deletion core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,6 @@ def before_run(self, adapter, selected_uids: AbstractSet[str]) -> None:
required_schemas = self.get_model_schemas(adapter, selected_uids)
self.create_schemas(adapter, required_schemas)
self.populate_adapter_cache(adapter, required_schemas)
self.defer_to_manifest(adapter, selected_uids)
self.safe_run_hooks(adapter, RunHookType.Start, {})

def after_run(self, adapter, results) -> None:
Expand Down
41 changes: 0 additions & 41 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import dbt.utils
from dbt.adapters.base import BaseRelation
from dbt.adapters.factory import get_adapter
from dbt.contracts.graph.manifest import WritableManifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.results import (
NodeStatus,
Expand Down Expand Up @@ -76,7 +75,6 @@ def __init__(self, args, config, manifest) -> None:
self.node_results: List[BaseResult] = []
self.num_nodes: int = 0
self.previous_state: Optional[PreviousState] = None
self.previous_defer_state: Optional[PreviousState] = None
self.run_count: int = 0
self.started_at: float = 0

Expand All @@ -87,13 +85,6 @@ def __init__(self, args, config, manifest) -> None:
project_root=Path(self.config.project_root),
)

if self.args.defer_state:
self.previous_defer_state = PreviousState(
state_path=self.args.defer_state,
target_path=Path(self.config.target_path),
project_root=Path(self.config.project_root),
)

def index_offset(self, value: int) -> int:
return value

Expand Down Expand Up @@ -130,23 +121,6 @@ def get_selection_spec(self) -> SelectionSpec:
def get_node_selector(self) -> NodeSelector:
raise NotImplementedError(f"get_node_selector not implemented for task {type(self)}")

def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]):
deferred_manifest = self._get_deferred_manifest()
if deferred_manifest is None:
return
if self.manifest is None:
raise DbtInternalError(
"Expected to defer to manifest, but there is no runtime manifest to defer from!"
)
self.manifest.merge_from_artifact(
adapter=adapter,
other=deferred_manifest,
selected=selected_uids,
favor_state=bool(self.args.favor_state),
)
# TODO: is it wrong to write the manifest here? I think it's right...
write_manifest(self.manifest, self.config.project_target_path)

def get_graph_queue(self) -> GraphQueue:
selector = self.get_node_selector()
spec = self.get_selection_spec()
Expand Down Expand Up @@ -432,7 +406,6 @@ def populate_adapter_cache(
def before_run(self, adapter, selected_uids: AbstractSet[str]):
with adapter.connection_named("master"):
self.populate_adapter_cache(adapter)
self.defer_to_manifest(adapter, selected_uids)

def after_run(self, adapter, results):
pass
Expand Down Expand Up @@ -616,17 +589,3 @@ def get_result(self, results, elapsed_time, generated_at):

def task_end_messages(self, results):
print_run_end_messages(results)

def _get_previous_state(self) -> Optional[WritableManifest]:
state = self.previous_defer_state or self.previous_state
if not state:
raise DbtRuntimeError(
"--state or --defer-state are required for deferral, but neither was provided"
)

if not state.manifest:
raise DbtRuntimeError(f'Could not find manifest in --state path: "{state}"')
return state.manifest

def _get_deferred_manifest(self) -> Optional[WritableManifest]:
return self._get_previous_state() if self.args.defer else None
5 changes: 0 additions & 5 deletions tests/functional/defer_state/test_defer_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import shutil
from copy import deepcopy
Expand Down Expand Up @@ -181,10 +180,6 @@ def test_run_and_defer(self, project, unique_schema, other_schema):
assert other_schema not in results[0].node.compiled_code
assert unique_schema in results[0].node.compiled_code

with open("target/manifest.json") as fp:
data = json.load(fp)
assert data["nodes"]["seed.test.seed"]["deferred"]

assert len(results) == 1


Expand Down
1 change: 0 additions & 1 deletion tests/unit/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def test_event_codes(self):
types.MainReportVersion(version=""),
types.MainReportArgs(args={}),
types.MainTrackingUserState(user_state=""),
types.MergedFromState(num_merged=0, sample=[]),
types.MissingProfileTarget(profile_name="", target_name=""),
types.InvalidOptionYAML(option_name="vars"),
types.LogDbtProjectError(),
Expand Down
5 changes: 2 additions & 3 deletions tests/unit/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ def test_build_flat_graph(self):
self.assertEqual(frozenset(node), REQUIRED_PARSED_NODE_KEYS)
self.assertEqual(compiled_count, 2)

def test_add_from_artifact(self):
def test_merge_from_artifact(self):
original_nodes = deepcopy(self.nested_nodes)
other_nodes = deepcopy(self.nested_nodes)

Expand All @@ -1041,8 +1041,7 @@ def test_add_from_artifact(self):

original_manifest = Manifest(nodes=original_nodes)
other_manifest = Manifest(nodes=other_nodes)
adapter = mock.MagicMock()
original_manifest.merge_from_artifact(adapter, other_manifest.writable_manifest(), {})
original_manifest.merge_from_artifact(other_manifest.writable_manifest())

# new node added should not be in original manifest
assert "model.root.nested2" not in original_manifest.nodes
Expand Down

0 comments on commit fa16a8c

Please sign in to comment.