From fa16a8c0abf412fba69044281f6a21e0a4a9f1d0 Mon Sep 17 00:00:00 2001 From: Jeremy Cohen Date: Mon, 4 Dec 2023 04:59:09 +0100 Subject: [PATCH] Move deferral from task to manifest loading + RefResolver --- core/dbt/cli/requires.py | 30 ++++++++++++- core/dbt/context/providers.py | 21 ++++++++- core/dbt/contracts/graph/manifest.py | 44 ++++--------------- core/dbt/events/types.py | 7 +-- core/dbt/task/clone.py | 10 +---- core/dbt/task/run.py | 1 - core/dbt/task/runnable.py | 41 ----------------- .../defer_state/test_defer_state.py | 5 --- tests/unit/test_events.py | 1 - tests/unit/test_manifest.py | 5 +-- 10 files changed, 60 insertions(+), 105 deletions(-) diff --git a/core/dbt/cli/requires.py b/core/dbt/cli/requires.py index b8e359b8a17..35e184513fb 100644 --- a/core/dbt/cli/requires.py +++ b/core/dbt/cli/requires.py @@ -9,6 +9,7 @@ from dbt.cli.flags import Flags from dbt.config import RuntimeConfig from dbt.config.runtime import load_project, load_profile, UnsetProfile +from dbt.contracts.state import PreviousState from dbt.events.base_types import EventLevel from dbt.events.functions import fire_event, LOG_VERSION, set_invocation_id, setup_event_logger from dbt.events.types import ( @@ -20,7 +21,12 @@ ) from dbt.events.helpers import get_json_string_utcnow from dbt.events.types import MainEncounteredError, MainStackTrace -from dbt.exceptions import Exception as DbtException, DbtProjectError, FailFastError +from dbt.exceptions import ( + Exception as DbtException, + DbtProjectError, + DbtRuntimeError, + FailFastError, +) from dbt.parser.manifest import ManifestLoader, write_manifest from dbt.profiler import profiler from dbt.tracking import active_user, initialize_from_flags, track_run @@ -30,6 +36,7 @@ from click import Context from functools import update_wrapper import importlib.util +from pathlib import Path import time import traceback @@ -265,6 +272,7 @@ def wrapper(*args, **kwargs): runtime_config = ctx.obj["runtime_config"] register_adapter(runtime_config) + flags: Flags = ctx.obj["flags"] # a manifest has already been set on the context, so don't overwrite it if ctx.obj.get("manifest") is None: @@ -273,8 +281,26 @@ def wrapper(*args, **kwargs): write_perf_info=write_perf_info, ) + # If deferral is enabled, add 'defer_relation' attribute to all nodes + if flags.defer: + defer_state = flags.defer_state or flags.state + if not defer_state: + raise DbtRuntimeError( + "Deferral is enabled and requires a stateful manifest, but none was provided" + ) + previous_state = PreviousState( + state_path=defer_state, + target_path=Path(runtime_config.target_path), + project_root=Path(runtime_config.project_root), + ) + if not previous_state.manifest: + raise DbtRuntimeError( + f'Could not find manifest in deferral state path: "{defer_state}"' + ) + manifest.merge_from_artifact(previous_state.manifest) + ctx.obj["manifest"] = manifest - if write and ctx.obj["flags"].write_json: + if write and flags.write_json: write_manifest(manifest, runtime_config.project_target_path) pm = get_plugin_manager(runtime_config.project_name) plugin_artifacts = pm.get_manifest_artifacts(manifest) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 782e9f9622a..dcfc9fad3e6 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -501,6 +501,7 @@ def resolve( self.model.package_name, ) + # Raise an error if the reference target is missing if target_model is None or isinstance(target_model, Disabled): raise TargetNotFoundError( node=self.model, @@ -510,6 +511,8 @@ def resolve( target_version=target_version, disabled=isinstance(target_model, Disabled), ) + + # Raise error if trying to reference a 'private' resource outside its 'group' elif self.manifest.is_invalid_private_ref( self.model, target_model, self.config.dependencies ): @@ -519,6 +522,7 @@ def resolve( access=AccessType.Private, scope=cast_to_str(target_model.group), ) + # Or a 'protected' resource outside its project/package namespace elif self.manifest.is_invalid_protected_ref( self.model, target_model, self.config.dependencies ): @@ -528,7 +532,6 @@ def resolve( access=AccessType.Protected, scope=target_model.package_name, ) - self.validate(target_model, target_name, target_package, target_version) return self.create_relation(target_model) @@ -538,6 +541,22 @@ def create_relation(self, target_model: ManifestNode) -> RelationProxy: return self.Relation.create_ephemeral_from_node( self.config, target_model, limit=self.resolve_limit ) + elif ( + hasattr(target_model, "defer_relation") + and target_model.defer_relation + and self.config.args.defer + and ( + # User has explicitly opted to prefer defer_relation + self.config.args.favor_state + # Or, this node's relation does not exist in the expected target location (cache lookup) + or not get_adapter(self.config).get_relation( + target_model.database, target_model.schema, target_model.identifier + ) + ) + ): + return self.Relation.create_from_node( + self.config, target_model.defer_relation, limit=self.resolve_limit + ) else: return self.Relation.create_from(self.config, target_model, limit=self.resolve_limit) diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index d0817cb4b36..e6281708281 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -1,7 +1,7 @@ import enum from collections import defaultdict from dataclasses import dataclass, field -from itertools import chain, islice +from itertools import chain from mashumaro.mixins.msgpack import DataClassMessagePackMixin from multiprocessing.synchronize import Lock from typing import ( @@ -18,7 +18,6 @@ TypeVar, Callable, Generic, - AbstractSet, ClassVar, Iterable, ) @@ -63,7 +62,7 @@ ) from dbt.helper_types import PathSet from dbt.events.functions import fire_event -from dbt.events.types import MergedFromState, UnpinnedRefNewVersionAvailable +from dbt.events.types import UnpinnedRefNewVersionAvailable from dbt.events.contextvars import get_node_info from dbt.node_types import NodeType, AccessType from dbt.flags import get_flags, MP_CONTEXT @@ -1340,50 +1339,25 @@ def is_invalid_protected_ref( node.package_name != target_model.package_name and restrict_package_access ) - # Called by GraphRunnableTask.defer_to_manifest - def merge_from_artifact( - self, - adapter, - other: "WritableManifest", - selected: AbstractSet[UniqueID], - favor_state: bool = False, - ) -> None: - """Given the selected unique IDs and a writable manifest, update this - manifest by replacing any unselected nodes with their counterpart. + # Called by requires.manifest after ManifestLoader.get_full_manifest + def merge_from_artifact(self, other: "WritableManifest") -> None: + """Update this manifest by adding the 'defer_relation' attribute to all nodes + with a counterpart in the stateful manifest used for deferral. Only non-ephemeral refable nodes are examined. """ refables = set(NodeType.refable()) - merged = set() for unique_id, node in other.nodes.items(): current = self.nodes.get(unique_id) - if current and ( - node.resource_type in refables - and not node.is_ephemeral - and unique_id not in selected - and ( - not adapter.get_relation(current.database, current.schema, current.identifier) - or favor_state - ) - ): - merged.add(unique_id) - self.nodes[unique_id] = node.replace(deferred=True) - - # for all other nodes, add 'defer_relation' - elif current and node.resource_type in refables and not node.is_ephemeral: + if current and node.resource_type in refables and not node.is_ephemeral: defer_relation = DeferRelation( node.database, node.schema, node.alias, node.relation_name ) self.nodes[unique_id] = current.replace(defer_relation=defer_relation) - # Rebuild the flat_graph, which powers the 'graph' context variable, - # now that we've deferred some nodes + # Rebuild the flat_graph, which powers the 'graph' context variable self.build_flat_graph() - # log up to 5 items - sample = list(islice(merged, 5)) - fire_event(MergedFromState(num_merged=len(merged), sample=sample)) - # Methods that were formerly in ParseResult def add_macro(self, source_file: SourceFile, macro: Macro): @@ -1621,8 +1595,6 @@ def __post_serialize__(self, dct): for unique_id, node in dct["nodes"].items(): if "config_call_dict" in node: del node["config_call_dict"] - if "defer_relation" in node: - del node["defer_relation"] return dct diff --git a/core/dbt/events/types.py b/core/dbt/events/types.py index 32ab0c3429f..e9d71d48caf 100644 --- a/core/dbt/events/types.py +++ b/core/dbt/events/types.py @@ -75,12 +75,7 @@ def message(self) -> str: return f"Tracking: {self.user_state}" -class MergedFromState(DebugLevel): - def code(self) -> str: - return "A004" - - def message(self) -> str: - return f"Merged {self.num_merged} items from state (sample: {self.sample})" +# Removed A004: MergedFromState class MissingProfileTarget(InfoLevel): diff --git a/core/dbt/task/clone.py b/core/dbt/task/clone.py index fbbd8583b67..0553a7250f7 100644 --- a/core/dbt/task/clone.py +++ b/core/dbt/task/clone.py @@ -1,10 +1,9 @@ import threading -from typing import AbstractSet, Any, List, Iterable, Set, Optional +from typing import AbstractSet, Any, List, Iterable, Set from dbt.adapters.base import BaseRelation from dbt.clients.jinja import MacroGenerator from dbt.context.providers import generate_runtime_model_context -from dbt.contracts.graph.manifest import WritableManifest from dbt.contracts.results import RunStatus, RunResult from dbt.dataclass_schema import dbtClassMixin from dbt.exceptions import DbtInternalError, CompilationError @@ -94,11 +93,6 @@ class CloneTask(GraphRunnableTask): def raise_on_first_error(self): return False - def _get_deferred_manifest(self) -> Optional[WritableManifest]: - # Unlike other commands, 'clone' always requires a state manifest - # Load previous state, regardless of whether --defer flag has been set - return self._get_previous_state() - def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRelation]: if self.manifest is None: raise DbtInternalError("manifest was None in get_model_schemas") @@ -122,8 +116,6 @@ def get_model_schemas(self, adapter, selected_uids: Iterable[str]) -> Set[BaseRe def before_run(self, adapter, selected_uids: AbstractSet[str]): with adapter.connection_named("master"): - # unlike in other tasks, we want to add information from the --state manifest *before* caching! - self.defer_to_manifest(adapter, selected_uids) # only create *our* schemas, but cache *other* schemas in addition schemas_to_create = super().get_model_schemas(adapter, selected_uids) self.create_schemas(adapter, schemas_to_create) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index a046c4b22e1..652714f8c79 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -445,7 +445,6 @@ def before_run(self, adapter, selected_uids: AbstractSet[str]) -> None: required_schemas = self.get_model_schemas(adapter, selected_uids) self.create_schemas(adapter, required_schemas) self.populate_adapter_cache(adapter, required_schemas) - self.defer_to_manifest(adapter, selected_uids) self.safe_run_hooks(adapter, RunHookType.Start, {}) def after_run(self, adapter, results) -> None: diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index c0005935a8d..01389c97fc4 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -12,7 +12,6 @@ import dbt.utils from dbt.adapters.base import BaseRelation from dbt.adapters.factory import get_adapter -from dbt.contracts.graph.manifest import WritableManifest from dbt.contracts.graph.nodes import ResultNode from dbt.contracts.results import ( NodeStatus, @@ -76,7 +75,6 @@ def __init__(self, args, config, manifest) -> None: self.node_results: List[BaseResult] = [] self.num_nodes: int = 0 self.previous_state: Optional[PreviousState] = None - self.previous_defer_state: Optional[PreviousState] = None self.run_count: int = 0 self.started_at: float = 0 @@ -87,13 +85,6 @@ def __init__(self, args, config, manifest) -> None: project_root=Path(self.config.project_root), ) - if self.args.defer_state: - self.previous_defer_state = PreviousState( - state_path=self.args.defer_state, - target_path=Path(self.config.target_path), - project_root=Path(self.config.project_root), - ) - def index_offset(self, value: int) -> int: return value @@ -130,23 +121,6 @@ def get_selection_spec(self) -> SelectionSpec: def get_node_selector(self) -> NodeSelector: raise NotImplementedError(f"get_node_selector not implemented for task {type(self)}") - def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]): - deferred_manifest = self._get_deferred_manifest() - if deferred_manifest is None: - return - if self.manifest is None: - raise DbtInternalError( - "Expected to defer to manifest, but there is no runtime manifest to defer from!" - ) - self.manifest.merge_from_artifact( - adapter=adapter, - other=deferred_manifest, - selected=selected_uids, - favor_state=bool(self.args.favor_state), - ) - # TODO: is it wrong to write the manifest here? I think it's right... - write_manifest(self.manifest, self.config.project_target_path) - def get_graph_queue(self) -> GraphQueue: selector = self.get_node_selector() spec = self.get_selection_spec() @@ -432,7 +406,6 @@ def populate_adapter_cache( def before_run(self, adapter, selected_uids: AbstractSet[str]): with adapter.connection_named("master"): self.populate_adapter_cache(adapter) - self.defer_to_manifest(adapter, selected_uids) def after_run(self, adapter, results): pass @@ -616,17 +589,3 @@ def get_result(self, results, elapsed_time, generated_at): def task_end_messages(self, results): print_run_end_messages(results) - - def _get_previous_state(self) -> Optional[WritableManifest]: - state = self.previous_defer_state or self.previous_state - if not state: - raise DbtRuntimeError( - "--state or --defer-state are required for deferral, but neither was provided" - ) - - if not state.manifest: - raise DbtRuntimeError(f'Could not find manifest in --state path: "{state}"') - return state.manifest - - def _get_deferred_manifest(self) -> Optional[WritableManifest]: - return self._get_previous_state() if self.args.defer else None diff --git a/tests/functional/defer_state/test_defer_state.py b/tests/functional/defer_state/test_defer_state.py index 102345fdf6e..49db771c80b 100644 --- a/tests/functional/defer_state/test_defer_state.py +++ b/tests/functional/defer_state/test_defer_state.py @@ -1,4 +1,3 @@ -import json import os import shutil from copy import deepcopy @@ -181,10 +180,6 @@ def test_run_and_defer(self, project, unique_schema, other_schema): assert other_schema not in results[0].node.compiled_code assert unique_schema in results[0].node.compiled_code - with open("target/manifest.json") as fp: - data = json.load(fp) - assert data["nodes"]["seed.test.seed"]["deferred"] - assert len(results) == 1 diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py index 89dc1a15255..f6a18e9a2a7 100644 --- a/tests/unit/test_events.py +++ b/tests/unit/test_events.py @@ -117,7 +117,6 @@ def test_event_codes(self): types.MainReportVersion(version=""), types.MainReportArgs(args={}), types.MainTrackingUserState(user_state=""), - types.MergedFromState(num_merged=0, sample=[]), types.MissingProfileTarget(profile_name="", target_name=""), types.InvalidOptionYAML(option_name="vars"), types.LogDbtProjectError(), diff --git a/tests/unit/test_manifest.py b/tests/unit/test_manifest.py index 6b0be8dfcb7..948c60bf5be 100644 --- a/tests/unit/test_manifest.py +++ b/tests/unit/test_manifest.py @@ -1019,7 +1019,7 @@ def test_build_flat_graph(self): self.assertEqual(frozenset(node), REQUIRED_PARSED_NODE_KEYS) self.assertEqual(compiled_count, 2) - def test_add_from_artifact(self): + def test_merge_from_artifact(self): original_nodes = deepcopy(self.nested_nodes) other_nodes = deepcopy(self.nested_nodes) @@ -1041,8 +1041,7 @@ def test_add_from_artifact(self): original_manifest = Manifest(nodes=original_nodes) other_manifest = Manifest(nodes=other_nodes) - adapter = mock.MagicMock() - original_manifest.merge_from_artifact(adapter, other_manifest.writable_manifest(), {}) + original_manifest.merge_from_artifact(other_manifest.writable_manifest()) # new node added should not be in original manifest assert "model.root.nested2" not in original_manifest.nodes