From fb10bb4aea548bbe11fe78e1b8452b3b6eaa626d Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 23 May 2024 11:06:04 -0400 Subject: [PATCH] Maximally parallelize dbt clone (#10129) --- .../unreleased/Features-20240522-000309.yaml | 7 ++ core/dbt/graph/queue.py | 11 ++- core/dbt/graph/selector.py | 4 +- core/dbt/task/clone.py | 5 +- core/dbt/task/runnable.py | 16 +++- tests/unit/contracts/graph/test_manifest.py | 15 +--- tests/unit/graph/test_queue.py | 47 ++++++++++++ tests/unit/task/test_runnable.py | 76 ++++++++++++++++++- tests/unit/utils/__init__.py | 15 ++++ 9 files changed, 172 insertions(+), 24 deletions(-) create mode 100644 .changes/unreleased/Features-20240522-000309.yaml create mode 100644 tests/unit/graph/test_queue.py diff --git a/.changes/unreleased/Features-20240522-000309.yaml b/.changes/unreleased/Features-20240522-000309.yaml new file mode 100644 index 00000000000..d02d3be3170 --- /dev/null +++ b/.changes/unreleased/Features-20240522-000309.yaml @@ -0,0 +1,7 @@ +kind: Features +body: Maximally parallelize dbt clone + in clone command" +time: 2024-05-22T00:03:09.765977-04:00 +custom: + Author: michelleark + Issue: "7914" diff --git a/core/dbt/graph/queue.py b/core/dbt/graph/queue.py index cda16cd3126..18ea15ac773 100644 --- a/core/dbt/graph/queue.py +++ b/core/dbt/graph/queue.py @@ -25,8 +25,15 @@ class GraphQueue: the same time, as there is an unlocked race! """ - def __init__(self, graph: nx.DiGraph, manifest: Manifest, selected: Set[UniqueId]) -> None: - self.graph = graph + def __init__( + self, + graph: nx.DiGraph, + manifest: Manifest, + selected: Set[UniqueId], + preserve_edges: bool = True, + ) -> None: + # 'create_empty_copy' returns a copy of the graph G with all of the edges removed, and leaves nodes intact. + self.graph = graph if preserve_edges else nx.classes.function.create_empty_copy(graph) self.manifest = manifest self._selected = selected # store the queue as a priority queue. diff --git a/core/dbt/graph/selector.py b/core/dbt/graph/selector.py index 6df9f352729..0ca3842f926 100644 --- a/core/dbt/graph/selector.py +++ b/core/dbt/graph/selector.py @@ -319,7 +319,7 @@ def get_selected(self, spec: SelectionSpec) -> Set[UniqueId]: return filtered_nodes - def get_graph_queue(self, spec: SelectionSpec) -> GraphQueue: + def get_graph_queue(self, spec: SelectionSpec, preserve_edges: bool = True) -> GraphQueue: """Returns a queue over nodes in the graph that tracks progress of dependecies. """ @@ -330,7 +330,7 @@ def get_graph_queue(self, spec: SelectionSpec) -> GraphQueue: # Construct a new graph using the selected_nodes new_graph = self.full_graph.get_subset_graph(selected_nodes) # should we give a way here for consumers to mutate the graph? - return GraphQueue(new_graph.graph, self.manifest, selected_nodes) + return GraphQueue(new_graph.graph, self.manifest, selected_nodes, preserve_edges) class ResourceTypeSelector(NodeSelector): diff --git a/core/dbt/task/clone.py b/core/dbt/task/clone.py index 8edfdd02068..98ac7153653 100644 --- a/core/dbt/task/clone.py +++ b/core/dbt/task/clone.py @@ -10,7 +10,7 @@ from dbt.node_types import REFABLE_NODE_TYPES from dbt.task.base import BaseRunner, resource_types_from_args from dbt.task.run import _validate_materialization_relations_dict -from dbt.task.runnable import GraphRunnableTask +from dbt.task.runnable import GraphRunnableMode, GraphRunnableTask from dbt_common.dataclass_schema import dbtClassMixin from dbt_common.exceptions import CompilationError, DbtInternalError @@ -94,6 +94,9 @@ class CloneTask(GraphRunnableTask): def raise_on_first_error(self): return False + def get_run_mode(self) -> GraphRunnableMode: + return GraphRunnableMode.Independent + def _get_deferred_manifest(self) -> Optional[Manifest]: # Unlike other commands, 'clone' always requires a state manifest # Load previous state, regardless of whether --defer flag has been set diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 639e02ec404..a01e7a06c22 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -48,6 +48,7 @@ from dbt.parser.manifest import write_manifest from dbt.task.base import BaseRunner, ConfiguredTask from dbt_common.context import _INVOCATION_CONTEXT_VAR, get_invocation_context +from dbt_common.dataclass_schema import StrEnum from dbt_common.events.contextvars import log_contextvars, task_contextvars from dbt_common.events.functions import fire_event, warn_or_error from dbt_common.events.types import Formatting @@ -58,6 +59,11 @@ RESULT_FILE_NAME = "run_results.json" +class GraphRunnableMode(StrEnum): + Topological = "topological" + Independent = "independent" + + class GraphRunnableTask(ConfiguredTask): MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error] @@ -135,7 +141,15 @@ def get_graph_queue(self) -> GraphQueue: selector = self.get_node_selector() # Following uses self.selection_arg and self.exclusion_arg spec = self.get_selection_spec() - return selector.get_graph_queue(spec) + + preserve_edges = True + if self.get_run_mode() == GraphRunnableMode.Independent: + preserve_edges = False + + return selector.get_graph_queue(spec, preserve_edges) + + def get_run_mode(self) -> GraphRunnableMode: + return GraphRunnableMode.Topological def _runtime_initialize(self): self.compile_manifest() diff --git a/tests/unit/contracts/graph/test_manifest.py b/tests/unit/contracts/graph/test_manifest.py index 4d19a269f71..a945e7b5d3e 100644 --- a/tests/unit/contracts/graph/test_manifest.py +++ b/tests/unit/contracts/graph/test_manifest.py @@ -50,6 +50,7 @@ MockNode, MockSource, inject_plugin, + make_manifest, ) REQUIRED_PARSED_NODE_KEYS = frozenset( @@ -1091,20 +1092,6 @@ def setUp(self): ) -def make_manifest(nodes=[], sources=[], macros=[], docs=[]): - return Manifest( - nodes={n.unique_id: n for n in nodes}, - macros={m.unique_id: m for m in macros}, - sources={s.unique_id: s for s in sources}, - docs={d.unique_id: d for d in docs}, - disabled={}, - files={}, - exposures={}, - metrics={}, - selectors={}, - ) - - FindMacroSpec = namedtuple("FindMacroSpec", "macros,expected") macro_parameter_sets = [ diff --git a/tests/unit/graph/test_queue.py b/tests/unit/graph/test_queue.py new file mode 100644 index 00000000000..50671d03fb2 --- /dev/null +++ b/tests/unit/graph/test_queue.py @@ -0,0 +1,47 @@ +import networkx as nx +import pytest + +from dbt.contracts.graph.manifest import Manifest +from dbt.graph.queue import GraphQueue +from tests.unit.utils import MockNode, make_manifest + + +class TestGraphQueue: + @pytest.fixture(scope="class") + def manifest(self) -> Manifest: + return make_manifest( + nodes=[ + MockNode(package="test_package", name="upstream_model"), + MockNode(package="test_package", name="downstream_model"), + ] + ) + + @pytest.fixture(scope="class") + def graph(self) -> nx.DiGraph: + graph = nx.DiGraph() + graph.add_edge("model.test_package.upstream_model", "model.test_package.downstream_model") + return graph + + def test_init_graph_queue(self, manifest, graph): + graph_queue = GraphQueue(graph=graph, manifest=manifest, selected={}) + + assert graph_queue.manifest == manifest + assert graph_queue.graph == graph + assert graph_queue.inner.queue == [(0, "model.test_package.upstream_model")] + assert graph_queue.in_progress == set() + assert graph_queue.queued == {"model.test_package.upstream_model"} + assert graph_queue.lock + + def test_init_graph_queue_preserve_edges_false(self, manifest, graph): + graph_queue = GraphQueue(graph=graph, manifest=manifest, selected={}, preserve_edges=False) + + # when preserve_edges is set to false, dependencies between nodes are no longer tracked in the priority queue + assert list(graph_queue.graph.edges) == [] + assert graph_queue.inner.queue == [ + (0, "model.test_package.downstream_model"), + (0, "model.test_package.upstream_model"), + ] + assert graph_queue.queued == { + "model.test_package.upstream_model", + "model.test_package.downstream_model", + } diff --git a/tests/unit/task/test_runnable.py b/tests/unit/task/test_runnable.py index a16e627a1bb..17e09830892 100644 --- a/tests/unit/task/test_runnable.py +++ b/tests/unit/task/test_runnable.py @@ -1,10 +1,14 @@ from dataclasses import dataclass -from typing import AbstractSet, Any, Dict, Optional +from typing import AbstractSet, Any, Dict, List, Optional, Tuple +import networkx as nx import pytest -from dbt.task.runnable import GraphRunnableTask +from dbt.artifacts.resources.types import NodeType +from dbt.graph import Graph, ResourceTypeSelector +from dbt.task.runnable import GraphRunnableMode, GraphRunnableTask from dbt.tests.util import safe_set_invocation_context +from tests.unit.utils import MockNode, make_manifest @dataclass @@ -14,6 +18,9 @@ class MockArgs: state: Optional[Dict[str, Any]] = None defer_state: Optional[Dict[str, Any]] = None write_json: bool = False + selector: Optional[str] = None + select: Tuple[str] = () + exclude: Tuple[str] = () @dataclass @@ -23,12 +30,28 @@ class MockConfig: threads: int = 1 target_name: str = "mock_config_target_name" + def get_default_selector_name(self): + return None + class MockRunnableTask(GraphRunnableTask): - def __init__(self, exception_class: Exception = Exception): + def __init__( + self, + exception_class: Exception = Exception, + nodes: Optional[List[MockNode]] = None, + edges: Optional[List[Tuple[str, str]]] = None, + ): + nodes = nodes or [] + edges = edges or [] + self.forced_exception_class = exception_class self.did_cancel: bool = False super().__init__(args=MockArgs(), config=MockConfig(), manifest=None) + self.manifest = make_manifest(nodes=nodes) + digraph = nx.DiGraph() + for edge in edges: + digraph.add_edge(edge[0], edge[1]) + self.graph = Graph(digraph) def run_queue(self, pool): """Override `run_queue` to raise a system exit""" @@ -40,13 +63,25 @@ def _cancel_connections(self, pool): def get_node_selector(self): """This is an `abstract_method` on `GraphRunnableTask`, thus we must implement it""" - return None + selector = ResourceTypeSelector( + graph=self.graph, + manifest=self.manifest, + previous_state=self.previous_state, + resource_types=[NodeType.Model], + include_empty_nodes=True, + ) + return selector def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]): """This is an `abstract_method` on `GraphRunnableTask`, thus we must implement it""" return None +class MockRunnableTaskIndependent(MockRunnableTask): + def get_run_mode(self) -> GraphRunnableMode: + return GraphRunnableMode.Independent + + def test_graph_runnable_task_cancels_connection_on_system_exit(): safe_set_invocation_context() @@ -81,3 +116,36 @@ def test_graph_runnable_task_doesnt_cancel_connection_on_generic_exception(): # If `did_cancel` is True, that means `_cancel_connections` was called assert task.did_cancel is False + + +def test_graph_runnable_preserves_edges_by_default(): + task = MockRunnableTask( + nodes=[ + MockNode("test", "upstream_node", fqn="model.test.upstream_node"), + MockNode("test", "downstream_node", fqn="model.test.downstream_node"), + ], + edges=[("model.test.upstream_node", "model.test.downstream_node")], + ) + assert task.get_run_mode() == GraphRunnableMode.Topological + graph_queue = task.get_graph_queue() + + assert graph_queue.queued == {"model.test.upstream_node"} + assert graph_queue.inner.queue == [(0, "model.test.upstream_node")] + + +def test_graph_runnable_preserves_edges_false(): + task = MockRunnableTaskIndependent( + nodes=[ + MockNode("test", "upstream_node", fqn="model.test.upstream_node"), + MockNode("test", "downstream_node", fqn="model.test.downstream_node"), + ], + edges=[("model.test.upstream_node", "model.test.downstream_node")], + ) + assert task.get_run_mode() == GraphRunnableMode.Independent + graph_queue = task.get_graph_queue() + + assert graph_queue.queued == {"model.test.downstream_node", "model.test.upstream_node"} + assert graph_queue.inner.queue == [ + (0, "model.test.downstream_node"), + (0, "model.test.upstream_node"), + ] diff --git a/tests/unit/utils/__init__.py b/tests/unit/utils/__init__.py index 7df68c949cc..411ad6ae756 100644 --- a/tests/unit/utils/__init__.py +++ b/tests/unit/utils/__init__.py @@ -11,6 +11,7 @@ import pytest from dbt.config.project import PartialProject +from dbt.contracts.graph.manifest import Manifest from dbt_common.dataclass_schema import ValidationError @@ -387,3 +388,17 @@ def replace_config(n, **kwargs): config=n.config.replace(**kwargs), unrendered_config=dict_replace(n.unrendered_config, **kwargs), ) + + +def make_manifest(nodes=[], sources=[], macros=[], docs=[]) -> Manifest: + return Manifest( + nodes={n.unique_id: n for n in nodes}, + macros={m.unique_id: m for m in macros}, + sources={s.unique_id: s for s in sources}, + docs={d.unique_id: d for d in docs}, + disabled={}, + files={}, + exposures={}, + metrics={}, + selectors={}, + )