diff --git a/.changes/unreleased/Under the Hood-20240226-184258.yaml b/.changes/unreleased/Under the Hood-20240226-184258.yaml new file mode 100644 index 00000000000..06c0f5e029a --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240226-184258.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Use Manifest instead of WritableManifest in PreviousState and _get_deferred_manifest +time: 2024-02-26T18:42:58.740808-05:00 +custom: + Author: michelleark + Issue: "9567" diff --git a/core/dbt/artifacts/resources/v1/components.py b/core/dbt/artifacts/resources/v1/components.py index 6a131ef761d..27509fb6072 100644 --- a/core/dbt/artifacts/resources/v1/components.py +++ b/core/dbt/artifacts/resources/v1/components.py @@ -206,3 +206,17 @@ class CompiledResource(ParsedResource): extra_ctes: List[InjectedCTE] = field(default_factory=list) _pre_injected_sql: Optional[str] = None contract: Contract = field(default_factory=Contract) + + def __post_serialize__(self, dct): + dct = super().__post_serialize__(dct) + if "_pre_injected_sql" in dct: + del dct["_pre_injected_sql"] + # Remove compiled attributes + if "compiled" in dct and dct["compiled"] is False: + del dct["compiled"] + del dct["extra_ctes_injected"] + del dct["extra_ctes"] + # "omit_none" means these might not be in the dictionary + if "compiled_code" in dct: + del dct["compiled_code"] + return dct diff --git a/core/dbt/artifacts/schemas/manifest/v12/manifest.py b/core/dbt/artifacts/schemas/manifest/v12/manifest.py index 66b5bb7b9e3..2ac3f3d761c 100644 --- a/core/dbt/artifacts/schemas/manifest/v12/manifest.py +++ b/core/dbt/artifacts/schemas/manifest/v12/manifest.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Mapping, Iterable, Tuple, Optional, Dict, List, Any +from typing import Mapping, Iterable, Tuple, Optional, Dict, List, Any, Union from uuid import UUID from dbt.artifacts.schemas.base import ( @@ -19,17 +19,38 @@ SemanticModel, SourceDefinition, UnitTestDefinition, -) - -# TODO: remove usage of dbt modules other than dbt.artifacts -from dbt.contracts.graph.nodes import ( - GraphMemberNode, - ManifestNode, + Seed, + Analysis, + SingularTest, + HookNode, + Model, + SqlOperation, + GenericTest, + Snapshot, ) NodeEdgeMap = Dict[str, List[str]] UniqueID = str +ManifestResource = Union[ + Seed, + Analysis, + SingularTest, + HookNode, + Model, + SqlOperation, + GenericTest, + Snapshot, +] +DisabledManifestResource = Union[ + ManifestResource, + SourceDefinition, + Exposure, + Metric, + SavedQuery, + SemanticModel, + UnitTestDefinition, +] @dataclass @@ -78,7 +99,7 @@ def default(cls): @dataclass @schema_version("manifest", 12) class WritableManifest(ArtifactMixin): - nodes: Mapping[UniqueID, ManifestNode] = field( + nodes: Mapping[UniqueID, ManifestResource] = field( metadata=dict(description=("The nodes defined in the dbt project and its dependencies")) ) sources: Mapping[UniqueID, SourceDefinition] = field( @@ -104,7 +125,7 @@ class WritableManifest(ArtifactMixin): selectors: Mapping[UniqueID, Any] = field( metadata=dict(description=("The selectors defined in selectors.yml")) ) - disabled: Optional[Mapping[UniqueID, List[GraphMemberNode]]] = field( + disabled: Optional[Mapping[UniqueID, List[DisabledManifestResource]]] = field( metadata=dict(description="A mapping of the disabled nodes in the target") ) parent_map: Optional[NodeEdgeMap] = field( diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index a952fef1e3f..852f4dce724 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -42,12 +42,17 @@ UnpatchedSourceDefinition, UnitTestDefinition, UnitTestFileFixture, + RESOURCE_CLASS_TO_NODE_CLASS, ) from dbt.contracts.graph.unparsed import SourcePatch, UnparsedVersion from dbt.flags import get_flags # to preserve import paths -from dbt.artifacts.resources import NodeVersion, DeferRelation +from dbt.artifacts.resources import ( + NodeVersion, + DeferRelation, + BaseResource, +) from dbt.artifacts.schemas.manifest import WritableManifest, ManifestMetadata, UniqueID from dbt.contracts.files import ( SourceFile, @@ -774,6 +779,7 @@ class ManifestStateCheck(dbtClassMixin): NodeClassT = TypeVar("NodeClassT", bound="BaseNode") +ResourceClassT = TypeVar("ResourceClassT", bound="BaseResource") @dataclass @@ -1029,16 +1035,66 @@ def fill_tracking_metadata(self): self.metadata.send_anonymous_usage_stats = get_flags().SEND_ANONYMOUS_USAGE_STATS @classmethod + def from_writable_manifest(cls, writable_manifest: WritableManifest) -> "Manifest": + manifest = Manifest( + nodes=cls._map_resources_to_map_nodes(writable_manifest.nodes), + disabled=cls._map_list_resources_to_map_list_nodes(writable_manifest.disabled), + unit_tests=cls._map_resources_to_map_nodes(writable_manifest.unit_tests), + sources=cls._map_resources_to_map_nodes(writable_manifest.sources), + macros=cls._map_resources_to_map_nodes(writable_manifest.macros), + docs=cls._map_resources_to_map_nodes(writable_manifest.docs), + exposures=cls._map_resources_to_map_nodes(writable_manifest.exposures), + metrics=cls._map_resources_to_map_nodes(writable_manifest.metrics), + groups=cls._map_resources_to_map_nodes(writable_manifest.groups), + semantic_models=cls._map_resources_to_map_nodes(writable_manifest.semantic_models), + selectors={ + selector_id: selector + for selector_id, selector in writable_manifest.selectors.items() + }, + ) + + return manifest + def _map_nodes_to_map_resources(cls, nodes_map: MutableMapping[str, NodeClassT]): return {node_id: node.to_resource() for node_id, node in nodes_map.items()} + def _map_list_nodes_to_map_list_resources( + cls, nodes_map: MutableMapping[str, List[NodeClassT]] + ): + return { + node_id: [node.to_resource() for node in node_list] + for node_id, node_list in nodes_map.items() + } + + @classmethod + def _map_resources_to_map_nodes(cls, resources_map: Mapping[str, ResourceClassT]): + return { + node_id: RESOURCE_CLASS_TO_NODE_CLASS[type(resource)].from_resource(resource) + for node_id, resource in resources_map.items() + } + + @classmethod + def _map_list_resources_to_map_list_nodes( + cls, resources_map: Optional[Mapping[str, List[ResourceClassT]]] + ): + if resources_map is None: + return {} + + return { + node_id: [ + RESOURCE_CLASS_TO_NODE_CLASS[type(resource)].from_resource(resource) + for resource in resource_list + ] + for node_id, resource_list in resources_map.items() + } + def writable_manifest(self) -> "WritableManifest": self.build_parent_and_child_maps() self.build_group_map() self.fill_tracking_metadata() return WritableManifest( - nodes=self.nodes, + nodes=self._map_nodes_to_map_resources(self.nodes), sources=self._map_nodes_to_map_resources(self.sources), macros=self._map_nodes_to_map_resources(self.macros), docs=self._map_nodes_to_map_resources(self.docs), @@ -1047,7 +1103,7 @@ def writable_manifest(self) -> "WritableManifest": groups=self._map_nodes_to_map_resources(self.groups), selectors=self.selectors, metadata=self.metadata, - disabled=self.disabled, + disabled=self._map_list_nodes_to_map_list_resources(self.disabled), child_map=self.child_map, parent_map=self.parent_map, group_map=self.group_map, @@ -1369,7 +1425,7 @@ def is_invalid_protected_ref( def merge_from_artifact( self, adapter, - other: "WritableManifest", + other: "Manifest", selected: AbstractSet[UniqueID], favor_state: bool = False, ) -> None: diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index e963eb14a60..e6f951d43fd 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -15,6 +15,7 @@ Type, Iterator, Literal, + get_args, ) from dbt import deprecations @@ -396,20 +397,6 @@ def set_cte(self, cte_id: str, sql: str): else: self.extra_ctes.append(InjectedCTE(id=cte_id, sql=sql)) - def __post_serialize__(self, dct): - dct = super().__post_serialize__(dct) - if "_pre_injected_sql" in dct: - del dct["_pre_injected_sql"] - # Remove compiled attributes - if "compiled" in dct and dct["compiled"] is False: - del dct["compiled"] - del dct["extra_ctes_injected"] - del dct["extra_ctes"] - # "omit_none" means these might not be in the dictionary - if "compiled_code" in dct: - del dct["compiled_code"] - return dct - @property def depends_on_nodes(self): return self.depends_on.nodes @@ -426,16 +413,24 @@ def depends_on_macros(self): @dataclass class AnalysisNode(AnalysisResource, CompiledNode): - pass + @classmethod + def resource_class(cls) -> Type[AnalysisResource]: + return AnalysisResource @dataclass class HookNode(HookNodeResource, CompiledNode): - pass + @classmethod + def resource_class(cls) -> Type[HookNodeResource]: + return HookNodeResource @dataclass class ModelNode(ModelResource, CompiledNode): + @classmethod + def resource_class(cls) -> Type[ModelResource]: + return ModelResource + @classmethod def from_args(cls, args: ModelNodeArgs) -> "ModelNode": unique_id = args.unique_id @@ -768,7 +763,9 @@ def same_contract(self, old, adapter_type=None) -> bool: @dataclass class SqlNode(SqlOperationResource, CompiledNode): - pass + @classmethod + def resource_class(cls) -> Type[SqlOperationResource]: + return SqlOperationResource # ==================================== @@ -778,6 +775,10 @@ class SqlNode(SqlOperationResource, CompiledNode): @dataclass class SeedNode(SeedResource, ParsedNode): # No SQLDefaults! + @classmethod + def resource_class(cls) -> Type[SeedResource]: + return SeedResource + def same_seeds(self, other: "SeedNode") -> bool: # for seeds, we check the hashes. If the hashes are different types, # no match. If the hashes are both the same 'path', log a warning and @@ -896,6 +897,10 @@ def is_relational(self): @dataclass class SingularTestNode(SingularTestResource, TestShouldStoreFailures, CompiledNode): + @classmethod + def resource_class(cls) -> Type[SingularTestResource]: + return SingularTestResource + @property def test_node_type(self): return "singular" @@ -908,6 +913,10 @@ def test_node_type(self): @dataclass class GenericTestNode(GenericTestResource, TestShouldStoreFailures, CompiledNode): + @classmethod + def resource_class(cls) -> Type[GenericTestResource]: + return GenericTestResource + def same_contents(self, other, adapter_type: Optional[str]) -> bool: if other is None: return False @@ -1014,7 +1023,9 @@ class IntermediateSnapshotNode(CompiledNode): @dataclass class SnapshotNode(SnapshotResource, CompiledNode): - pass + @classmethod + def resource_class(cls) -> Type[SnapshotResource]: + return SnapshotResource # ==================================== @@ -1626,3 +1637,10 @@ class ParsedMacroPatch(ParsedPatch): ] TestNode = Union[SingularTestNode, GenericTestNode] + + +RESOURCE_CLASS_TO_NODE_CLASS: Dict[Type[BaseResource], Type[BaseNode]] = { + node_class.resource_class(): node_class + for node_class in get_args(Resource) + if node_class is not UnitTestNode +} diff --git a/core/dbt/contracts/state.py b/core/dbt/contracts/state.py index 9111e2dfb46..16683f8d899 100644 --- a/core/dbt/contracts/state.py +++ b/core/dbt/contracts/state.py @@ -1,7 +1,8 @@ from pathlib import Path from typing import Optional -from dbt.contracts.graph.manifest import WritableManifest +from dbt.contracts.graph.manifest import Manifest +from dbt.artifacts.schemas.manifest import WritableManifest from dbt.artifacts.schemas.freshness import FreshnessExecutionResultArtifact from dbt.artifacts.schemas.run import RunResultsArtifact from dbt_common.events.functions import fire_event @@ -24,7 +25,7 @@ def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> N self.state_path: Path = state_path self.target_path: Path = target_path self.project_root: Path = project_root - self.manifest: Optional[WritableManifest] = None + self.manifest: Optional[Manifest] = None self.results: Optional[RunResultsArtifact] = None self.sources: Optional[FreshnessExecutionResultArtifact] = None self.sources_current: Optional[FreshnessExecutionResultArtifact] = None @@ -36,7 +37,8 @@ def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> N manifest_path = self.project_root / self.state_path / "manifest.json" if manifest_path.exists() and manifest_path.is_file(): try: - self.manifest = WritableManifest.read_and_check_versions(str(manifest_path)) + writable_manifest = WritableManifest.read_and_check_versions(str(manifest_path)) + self.manifest = Manifest.from_writable_manifest(writable_manifest) except IncompatibleSchemaError as exc: exc.add_filename(str(manifest_path)) raise diff --git a/core/dbt/graph/selector_methods.py b/core/dbt/graph/selector_methods.py index 0a4a99231d0..a9cc6eabbbb 100644 --- a/core/dbt/graph/selector_methods.py +++ b/core/dbt/graph/selector_methods.py @@ -8,7 +8,7 @@ from .graph import UniqueId -from dbt.contracts.graph.manifest import Manifest, WritableManifest +from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import ( SingularTestNode, Exposure, @@ -725,7 +725,7 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu f'Got an invalid selector "{selector}", expected one of ' f'"{list(state_checks)}"' ) - manifest: WritableManifest = self.previous_state.manifest + manifest: Manifest = self.previous_state.manifest for unique_id, node in self.all_nodes(included_nodes): previous_node: Optional[SelectorTarget] = None diff --git a/core/dbt/task/clone.py b/core/dbt/task/clone.py index 49f7d857a30..53c322211cb 100644 --- a/core/dbt/task/clone.py +++ b/core/dbt/task/clone.py @@ -4,7 +4,7 @@ from dbt.adapters.base import BaseRelation from dbt.clients.jinja import MacroGenerator from dbt.context.providers import generate_runtime_model_context -from dbt.contracts.graph.manifest import WritableManifest +from dbt.contracts.graph.manifest import Manifest from dbt.artifacts.schemas.run import RunStatus, RunResult from dbt_common.dataclass_schema import dbtClassMixin from dbt_common.exceptions import DbtInternalError, CompilationError @@ -94,7 +94,7 @@ class CloneTask(GraphRunnableTask): def raise_on_first_error(self): return False - def _get_deferred_manifest(self) -> Optional[WritableManifest]: + def _get_deferred_manifest(self) -> Optional[Manifest]: # Unlike other commands, 'clone' always requires a state manifest # Load previous state, regardless of whether --defer flag has been set return self._get_previous_state() diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 8b361988517..d44f88cfc0b 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -15,7 +15,7 @@ import dbt.utils from dbt.adapters.base import BaseRelation from dbt.adapters.factory import get_adapter -from dbt.contracts.graph.manifest import WritableManifest +from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import ResultNode from dbt.artifacts.schemas.results import NodeStatus, RunningStatus, RunStatus, BaseResult from dbt.artifacts.schemas.run import RunExecutionResult, RunResult @@ -643,7 +643,7 @@ def get_result(self, results, elapsed_time, generated_at): def task_end_messages(self, results): print_run_end_messages(results) - def _get_previous_state(self) -> Optional[WritableManifest]: + def _get_previous_state(self) -> Optional[Manifest]: state = self.previous_defer_state or self.previous_state if not state: raise DbtRuntimeError( @@ -654,5 +654,5 @@ def _get_previous_state(self) -> Optional[WritableManifest]: raise DbtRuntimeError(f'Could not find manifest in --state path: "{state}"') return state.manifest - def _get_deferred_manifest(self) -> Optional[WritableManifest]: + def _get_deferred_manifest(self) -> Optional[Manifest]: return self._get_previous_state() if self.args.defer else None diff --git a/tests/unit/test_graph_selector_methods.py b/tests/unit/test_graph_selector_methods.py index af1dc6fde3e..ee2c2c4d968 100644 --- a/tests/unit/test_graph_selector_methods.py +++ b/tests/unit/test_graph_selector_methods.py @@ -1,3 +1,4 @@ +from argparse import Namespace import copy from dataclasses import replace import pytest @@ -35,6 +36,7 @@ MacroDependsOn, TestConfig, TestMetadata, + RefArgs, ) from dbt.contracts.graph.unparsed import ( UnitTestInputFixture, @@ -65,6 +67,9 @@ import dbt_common.exceptions from dbt_semantic_interfaces.type_enums import MetricType from .utils import replace_config +from dbt.flags import set_from_args + +set_from_args(Namespace(WARN_ERROR=False), None) def make_model( @@ -109,7 +114,8 @@ def make_model( source_values = [] ref_values = [] for ref in refs: - ref_values.append([ref.name]) + ref_version = ref.version if hasattr(ref, "version") else None + ref_values.append(RefArgs(name=ref.name, package=ref.package_name, version=ref_version)) depends_on_nodes.append(ref.unique_id) for src in sources: source_values.append([src.source_name, src.name]) @@ -261,7 +267,11 @@ def make_generic_test( source_values.append([test_model.source_name, test_model.name]) else: kwargs["model"] = "{{ ref('" + test_model.name + "')}}" - ref_values.append([test_model.name]) + ref_values.append( + RefArgs( + name=test_model.name, package=test_model.package_name, version=test_model.version + ) + ) if column_name is not None: kwargs["column_name"] = column_name @@ -296,7 +306,8 @@ def make_generic_test( depends_on_nodes = [] for ref in refs: - ref_values.append([ref.name]) + ref_version = ref.version if hasattr(ref, "version") else None + ref_values.append(RefArgs(name=ref.name, package=ref.package_name, version=ref_version)) depends_on_nodes.append(ref.unique_id) for source in sources: @@ -378,7 +389,8 @@ def make_singular_test( source_values = [] ref_values = [] for ref in refs: - ref_values.append([ref.name]) + ref_version = ref.version if hasattr(ref, "version") else None + ref_values.append(RefArgs(name=ref.name, package=ref.package_name, version=ref_version)) depends_on_nodes.append(ref.unique_id) for src in sources: source_values.append([src.source_name, src.name]) @@ -903,7 +915,7 @@ def manifest( files={}, exposures={}, metrics={}, - disabled=[], + disabled={}, selectors={}, groups={}, metadata=ManifestMetadata(adapter_type="postgres"), @@ -1434,7 +1446,7 @@ def previous_state(manifest): target_path=Path("/path/does/not/exist"), project_root=Path("/path/does/not/exist"), ) - state.manifest = writable + state.manifest = Manifest.from_writable_manifest(writable) return state diff --git a/tests/unit/test_manifest.py b/tests/unit/test_manifest.py index 55d655d0ba4..ea443d1147f 100644 --- a/tests/unit/test_manifest.py +++ b/tests/unit/test_manifest.py @@ -1042,7 +1042,7 @@ def test_merge_from_artifact(self): original_manifest = Manifest(nodes=original_nodes) other_manifest = Manifest(nodes=other_nodes) adapter = mock.MagicMock() - original_manifest.merge_from_artifact(adapter, other_manifest.writable_manifest(), {}) + original_manifest.merge_from_artifact(adapter, other_manifest, {}) # new node added should not be in original manifest assert "model.root.nested2" not in original_manifest.nodes