Skip to content

Commit

Permalink
Move deferral from task to manifest loading + RefResolver
Browse files (browse the repository at this point in the history)
  • Loading branch information
jtcohen6 committed Jan 31, 2024
1 parent 9c8b28a commit 7d962ed
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 107 deletions.
2 changes: 0 additions & 2 deletions core/dbt/artifacts/schemas/manifest/v12/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,4 @@ def __post_serialize__(self, dct):
for unique_id, node in dct["nodes"].items():
if "config_call_dict" in node:
del node["config_call_dict"]
if "defer_relation" in node:
del node["defer_relation"]
return dct
2 changes: 1 addition & 1 deletion core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def wrapper(*args, **kwargs):

runtime_config = ctx.obj["runtime_config"]

# a manifest has already been set on the context, so don't overwrite it
# if a manifest has already been set on the context, don't overwrite it
if ctx.obj.get("manifest") is None:
ctx.obj["manifest"] = parse_manifest(
runtime_config, write_perf_info, write, ctx.obj["flags"].write_json
Expand Down
21 changes: 20 additions & 1 deletion core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ def resolve(
self.model.package_name,
)

# Raise an error if the reference target is missing
if target_model is None or isinstance(target_model, Disabled):
raise TargetNotFoundError(
node=self.model,
Expand All @@ -508,6 +509,8 @@ def resolve(
target_version=target_version,
disabled=isinstance(target_model, Disabled),
)

# Raise error if trying to reference a 'private' resource outside its 'group'
elif self.manifest.is_invalid_private_ref(
self.model, target_model, self.config.dependencies
):
Expand All @@ -517,6 +520,7 @@ def resolve(
access=AccessType.Private,
scope=cast_to_str(target_model.group),
)
# Or a 'protected' resource outside its project/package namespace
elif self.manifest.is_invalid_protected_ref(
self.model, target_model, self.config.dependencies
):
Expand All @@ -526,14 +530,29 @@ def resolve(
access=AccessType.Protected,
scope=target_model.package_name,
)

self.validate(target_model, target_name, target_package, target_version)
return self.create_relation(target_model)

def create_relation(self, target_model: ManifestNode) -> RelationProxy:
    """Build the relation that a reference to ``target_model`` resolves to.

    Three outcomes are possible, checked in this order:

    1. ``target_model`` is ephemeral: it is registered as a CTE on the
       referencing model and an ephemeral relation is returned.
    2. ``target_model`` carries a truthy ``defer_relation`` and
       ``self.config.args.defer`` is set: the deferred relation is returned
       when ``self.config.args.favor_state`` is set, or when the adapter's
       ``get_relation`` finds nothing at the node's own
       database/schema/identifier.
    3. Otherwise the relation is created from ``target_model`` itself.
    """
    if target_model.is_ephemeral_model:
        self.model.set_cte(target_model.unique_id, None)
        return self.Relation.create_ephemeral_from(target_model, limit=self.resolve_limit)

    # Not every node type has a 'defer_relation' attribute, so read it defensively.
    deferred_to = getattr(target_model, "defer_relation", None)
    if deferred_to and self.config.args.defer:
        # The user has explicitly opted to prefer the deferred relation ...
        use_deferred = self.config.args.favor_state
        if not use_deferred:
            # ... or the node's own relation is missing from the expected
            # target location. Only look it up when favor_state did not
            # already settle the question.
            existing = get_adapter(self.config).get_relation(
                target_model.database, target_model.schema, target_model.identifier
            )
            use_deferred = not existing
        if use_deferred:
            return self.Relation.create_from(
                self.config, deferred_to, limit=self.resolve_limit
            )

    return self.Relation.create_from(self.config, target_model, limit=self.resolve_limit)

Expand Down
42 changes: 8 additions & 34 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
from collections import defaultdict
from dataclasses import dataclass, field
from itertools import chain, islice
from itertools import chain
from mashumaro.mixins.msgpack import DataClassMessagePackMixin
from multiprocessing.synchronize import Lock
from typing import (
Expand All @@ -18,7 +18,6 @@
TypeVar,
Callable,
Generic,
AbstractSet,
ClassVar,
)
from typing_extensions import Protocol
Expand Down Expand Up @@ -67,7 +66,7 @@
from dbt_common.helper_types import PathSet
from dbt_common.events.functions import fire_event
from dbt_common.events.contextvars import get_node_info
from dbt.events.types import MergedFromState, UnpinnedRefNewVersionAvailable
from dbt.events.types import UnpinnedRefNewVersionAvailable
from dbt.node_types import NodeType, AccessType, REFABLE_NODE_TYPES, VERSIONED_NODE_TYPES
from dbt.mp_context import get_mp_context
import dbt_common.utils
Expand Down Expand Up @@ -1351,50 +1350,25 @@ def is_invalid_protected_ref(
node.package_name != target_model.package_name and restrict_package_access
)

# Called by GraphRunnableTask.defer_to_manifest
def merge_from_artifact(
self,
adapter,
other: "WritableManifest",
selected: AbstractSet[UniqueID],
favor_state: bool = False,
) -> None:
"""Given the selected unique IDs and a writable manifest, update this
manifest by replacing any unselected nodes with their counterpart.
# Called by requires.manifest after ManifestLoader.get_full_manifest
def merge_from_artifact(self, other: "WritableManifest") -> None:
"""Update this manifest by adding the 'defer_relation' attribute to all nodes
with a counterpart in the stateful manifest used for deferral.
Only non-ephemeral refable nodes are examined.
"""
refables = set(REFABLE_NODE_TYPES)
merged = set()
for unique_id, node in other.nodes.items():
current = self.nodes.get(unique_id)
if current and (
node.resource_type in refables
and not node.is_ephemeral
and unique_id not in selected
and (
not adapter.get_relation(current.database, current.schema, current.identifier)
or favor_state
)
):
merged.add(unique_id)
self.nodes[unique_id] = node.replace(deferred=True)

# for all other nodes, add 'defer_relation'
elif current and node.resource_type in refables and not node.is_ephemeral:
if current and node.resource_type in refables and not node.is_ephemeral:
defer_relation = DeferRelation(
node.database, node.schema, node.alias, node.relation_name
)
self.nodes[unique_id] = current.replace(defer_relation=defer_relation)

# Rebuild the flat_graph, which powers the 'graph' context variable,
# now that we've deferred some nodes
# Rebuild the flat_graph, which powers the 'graph' context variable
self.build_flat_graph()

# log up to 5 items
sample = list(islice(merged, 5))
fire_event(MergedFromState(num_merged=len(merged), sample=sample))

# Methods that were formerly in ParseResult
def add_macro(self, source_file: SourceFile, macro: Macro):
if macro.unique_id in self.macros:
Expand Down
1 change: 0 additions & 1 deletion core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ class ParsedNode(NodeInfoMixin, ParsedNodeMandatory, SerializableType):
docs: Docs = field(default_factory=Docs)
patch_path: Optional[str] = None
build_path: Optional[str] = None
deferred: bool = False
unrendered_config: Dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=lambda: time.time())
config_call_dict: Dict[str, Any] = field(default_factory=dict)
Expand Down
7 changes: 1 addition & 6 deletions core/dbt/events/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ def message(self) -> str:
return f"Tracking: {self.user_state}"


class MergedFromState(DebugLevel):
def code(self) -> str:
return "A004"

def message(self) -> str:
return f"Merged {self.num_merged} items from state (sample: {self.sample})"
# Removed A004: MergedFromState


class MissingProfileTarget(InfoLevel):
Expand Down
30 changes: 29 additions & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import field
import datetime
import os
from pathlib import Path
import traceback
from typing import (
Dict,
Expand All @@ -21,6 +22,7 @@

from dbt.context.manifest import generate_query_header_context
from dbt.contracts.graph.semantic_manifest import SemanticManifest
from dbt.contracts.state import PreviousState
from dbt_common.events.base_types import EventLevel
import dbt_common.utils
import json
Expand Down Expand Up @@ -114,6 +116,7 @@
TargetNotFoundError,
AmbiguousAliasError,
InvalidAccessTypeError,
DbtRuntimeError,
)
from dbt.parser.base import Parser
from dbt.parser.analysis import AnalysisParser
Expand Down Expand Up @@ -1840,7 +1843,12 @@ def write_manifest(manifest: Manifest, target_path: str, which: Optional[str] =
write_semantic_manifest(manifest=manifest, target_path=target_path)


def parse_manifest(runtime_config, write_perf_info, write, write_json):
def parse_manifest(
runtime_config: RuntimeConfig,
write_perf_info: bool,
write: bool,
write_json: bool,
) -> Manifest:
register_adapter(runtime_config, get_mp_context())
adapter = get_adapter(runtime_config)
adapter.set_macro_context_generator(generate_runtime_macro_context)
Expand All @@ -1849,6 +1857,26 @@ def parse_manifest(runtime_config, write_perf_info, write, write_json):
write_perf_info=write_perf_info,
)

# If deferral is enabled, add 'defer_relation' attribute to all nodes
flags = get_flags()
if flags.defer:
defer_state_path = flags.defer_state or flags.state
if not defer_state_path:
raise DbtRuntimeError(
"Deferral is enabled and requires a stateful manifest, but none was provided"
)
defer_state = PreviousState(
state_path=defer_state_path,
target_path=Path(runtime_config.target_path),
project_root=Path(runtime_config.project_root),
)
if not defer_state.manifest:
raise DbtRuntimeError(
f'Could not find manifest in deferral state path: "{defer_state_path}"'
)
manifest.merge_from_artifact(defer_state.manifest)

# If we should (over)write the manifest in the target path, do that now
if write and write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = plugins.get_plugin_manager(runtime_config.project_name)
Expand Down
10 changes: 1 addition & 9 deletions core/dbt/task/clone.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import threading
from typing import AbstractSet, Any, List, Iterable, Set, Optional
from typing import AbstractSet, Any, List, Iterable, Set

from dbt.adapters.base import BaseRelation
from dbt.clients.jinja import MacroGenerator
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.graph.manifest import WritableManifest
from dbt.artifacts.schemas.run import RunStatus, RunResult
from dbt_common.dataclass_schema import dbtClassMixin
from dbt_common.exceptions import DbtInternalError, CompilationError
Expand Down Expand Up @@ -94,11 +93,6 @@ class CloneTask(GraphRunnableTask):
def raise_on_first_error(self):
return False

def _get_deferred_manifest(self) -> Optional[WritableManifest]:
# Unlike other commands, 'clone' always requires a state manifest
# Load previous state, regardless of whether --defer flag has been set
return self._get_previous_state()

def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRelation]:
if self.manifest is None:
raise DbtInternalError("manifest was None in get_model_schemas")
Expand All @@ -122,8 +116,6 @@ def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRe

def before_run(self, adapter, selected_uids: AbstractSet[str]):
with adapter.connection_named("master"):
# unlike in other tasks, we want to add information from the --state manifest *before* caching!
self.defer_to_manifest(adapter, selected_uids)
# only create *our* schemas, but cache *other* schemas in addition
schemas_to_create = super().get_model_schemas(adapter, selected_uids)
self.create_schemas(adapter, schemas_to_create)
Expand Down
1 change: 0 additions & 1 deletion core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,6 @@ def before_run(self, adapter, selected_uids: AbstractSet[str]) -> None:
required_schemas = self.get_model_schemas(adapter, selected_uids)
self.create_schemas(adapter, required_schemas)
self.populate_adapter_cache(adapter, required_schemas)
self.defer_to_manifest(adapter, selected_uids)
self.safe_run_hooks(adapter, RunHookType.Start, {})

def after_run(self, adapter, results) -> None:
Expand Down
45 changes: 2 additions & 43 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import dbt.utils
from dbt.adapters.base import BaseRelation
from dbt.adapters.factory import get_adapter
from dbt.contracts.graph.manifest import WritableManifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.artifacts.schemas.results import NodeStatus, RunningStatus, RunStatus, BaseResult
from dbt.artifacts.schemas.run import RunExecutionResult, RunResult
Expand Down Expand Up @@ -71,8 +70,9 @@ def __init__(self, args, config, manifest) -> None:
self.job_queue: Optional[GraphQueue] = None
self.node_results: List[BaseResult] = []
self.num_nodes: int = 0
# TODO: when --defer is enabled, the "previous state" artifacts have already been
# loaded into memory once (by parse_manifest). Check for that here and reuse them
# instead of loading them a second time.
self.previous_state: Optional[PreviousState] = None
self.previous_defer_state: Optional[PreviousState] = None
self.run_count: int = 0
self.started_at: float = 0

Expand All @@ -83,13 +83,6 @@ def __init__(self, args, config, manifest) -> None:
project_root=Path(self.config.project_root),
)

if self.args.defer_state:
self.previous_defer_state = PreviousState(
state_path=self.args.defer_state,
target_path=Path(self.config.target_path),
project_root=Path(self.config.project_root),
)

def index_offset(self, value: int) -> int:
return value

Expand Down Expand Up @@ -127,25 +120,6 @@ def get_selection_spec(self) -> SelectionSpec:
def get_node_selector(self) -> NodeSelector:
raise NotImplementedError(f"get_node_selector not implemented for task {type(self)}")

def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]):
deferred_manifest = self._get_deferred_manifest()
if deferred_manifest is None:
return
if self.manifest is None:
raise DbtInternalError(
"Expected to defer to manifest, but there is no runtime manifest to defer from!"
)
self.manifest.merge_from_artifact(
adapter=adapter,
other=deferred_manifest,
selected=selected_uids,
favor_state=bool(self.args.favor_state),
)
# We're rewriting the manifest because it's been mutated during merge_from_artifact.
# This is to reflect which nodes had been deferred to (= replaced with) their counterparts.
if self.args.write_json:
write_manifest(self.manifest, self.config.project_target_path)

def get_graph_queue(self) -> GraphQueue:
selector = self.get_node_selector()
# Following uses self.selection_arg and self.exclusion_arg
Expand Down Expand Up @@ -452,7 +426,6 @@ def populate_adapter_cache(
def before_run(self, adapter, selected_uids: AbstractSet[str]):
with adapter.connection_named("master"):
self.populate_adapter_cache(adapter)
self.defer_to_manifest(adapter, selected_uids)

def after_run(self, adapter, results):
pass
Expand Down Expand Up @@ -636,17 +609,3 @@ def get_result(self, results, elapsed_time, generated_at):

def task_end_messages(self, results):
print_run_end_messages(results)

def _get_previous_state(self) -> Optional[WritableManifest]:
state = self.previous_defer_state or self.previous_state
if not state:
raise DbtRuntimeError(
"--state or --defer-state are required for deferral, but neither was provided"
)

if not state.manifest:
raise DbtRuntimeError(f'Could not find manifest in --state path: "{state}"')
return state.manifest

def _get_deferred_manifest(self) -> Optional[WritableManifest]:
return self._get_previous_state() if self.args.defer else None
5 changes: 0 additions & 5 deletions tests/functional/defer_state/test_defer_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import shutil
from copy import deepcopy
Expand Down Expand Up @@ -181,10 +180,6 @@ def test_run_and_defer(self, project, unique_schema, other_schema):
assert other_schema not in results[0].node.compiled_code
assert unique_schema in results[0].node.compiled_code

with open("target/manifest.json") as fp:
data = json.load(fp)
assert data["nodes"]["seed.test.seed"]["deferred"]

assert len(results) == 1


Expand Down
1 change: 0 additions & 1 deletion tests/unit/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def test_event_codes(self):
core_types.MainReportVersion(version=""),
core_types.MainReportArgs(args={}),
core_types.MainTrackingUserState(user_state=""),
core_types.MergedFromState(num_merged=0, sample=[]),
core_types.MissingProfileTarget(profile_name="", target_name=""),
core_types.InvalidOptionYAML(option_name="vars"),
core_types.LogDbtProjectError(),
Expand Down
Loading

0 comments on commit 7d962ed

Please sign in to comment.