diff --git a/tests/unit/README.md b/tests/unit/README.md index b9e22124c85..e063e370b48 100644 --- a/tests/unit/README.md +++ b/tests/unit/README.md @@ -1,6 +1,5 @@ # Unit test README -## test_contracts_graph_parsed.py ### The Why We need to ensure that we can go from objects to dictionaries and back without any @@ -16,3 +15,7 @@ versions of the object we're interested in testing, and run the different genera of the object through the test. This gives us confidence that for any allowable configuration of an object, state is not changed when moving back and forth betweeen the python object version and the seralized version. + +### The What + +- We test concrete classes in the codebase and do not test abstract classes as they are implementation details. [reference](https://enterprisecraftsmanship.com/posts/how-to-unit-test-an-abstract-class/) diff --git a/tests/unit/task/test_clone.py b/tests/unit/task/test_clone.py new file mode 100644 index 00000000000..9ef07d5e492 --- /dev/null +++ b/tests/unit/task/test_clone.py @@ -0,0 +1,16 @@ +from unittest.mock import MagicMock, patch + +from dbt.flags import get_flags +from dbt.task.clone import CloneTask + + +def test_clone_task_not_preserve_edges(): + mock_node_selector = MagicMock() + mock_spec = MagicMock() + with patch.object( + CloneTask, "get_node_selector", return_value=mock_node_selector + ), patch.object(CloneTask, "get_selection_spec", return_value=mock_spec): + task = CloneTask(get_flags(), None, None) + task.get_graph_queue() + # when we get the graph queue, preserve_edges is False + mock_node_selector.get_graph_queue.assert_called_with(mock_spec, False) diff --git a/tests/unit/task/test_run.py b/tests/unit/task/test_run.py new file mode 100644 index 00000000000..c689e8f41aa --- /dev/null +++ b/tests/unit/task/test_run.py @@ -0,0 +1,52 @@ +from argparse import Namespace +from unittest.mock import MagicMock, patch + +import pytest + +from dbt.config.runtime import RuntimeConfig +from dbt.flags import get_flags, set_from_args +from dbt.task.run import RunTask +from dbt.tests.util import safe_set_invocation_context + + +@pytest.mark.parametrize( + "exception_to_raise, expected_cancel_connections", + [ + (SystemExit, True), + (KeyboardInterrupt, True), + (Exception, False), + ], +) +def test_run_task_cancel_connections( + exception_to_raise, expected_cancel_connections, runtime_config: RuntimeConfig +): + safe_set_invocation_context() + + def mock_run_queue(*args, **kwargs): + raise exception_to_raise("Test exception") + + with patch.object(RunTask, "run_queue", mock_run_queue), patch.object( + RunTask, "_cancel_connections" + ) as mock_cancel_connections: + + set_from_args(Namespace(write_json=False), None) + task = RunTask( + get_flags(), + runtime_config, + None, + ) + with pytest.raises(exception_to_raise): + task.execute_nodes() + assert mock_cancel_connections.called == expected_cancel_connections + + +def test_run_task_preserve_edges(): + mock_node_selector = MagicMock() + mock_spec = MagicMock() + with patch.object(RunTask, "get_node_selector", return_value=mock_node_selector), patch.object( + RunTask, "get_selection_spec", return_value=mock_spec + ): + task = RunTask(get_flags(), None, None) + task.get_graph_queue() + # when we get the graph queue, preserve_edges is True + mock_node_selector.get_graph_queue.assert_called_with(mock_spec, True) diff --git a/tests/unit/task/test_runnable.py b/tests/unit/task/test_runnable.py deleted file mode 100644 index 17e09830892..00000000000 --- a/tests/unit/task/test_runnable.py +++ /dev/null @@ -1,151 +0,0 @@ -from dataclasses import dataclass -from typing import AbstractSet, Any, Dict, List, Optional, Tuple - -import networkx as nx -import pytest - -from dbt.artifacts.resources.types import NodeType -from dbt.graph import Graph, ResourceTypeSelector -from dbt.task.runnable import GraphRunnableMode, GraphRunnableTask -from dbt.tests.util import safe_set_invocation_context -from tests.unit.utils import MockNode, make_manifest - - -@dataclass -class MockArgs: - """Simple mock args for us in a runnable task""" - - state: Optional[Dict[str, Any]] = None - defer_state: Optional[Dict[str, Any]] = None - write_json: bool = False - selector: Optional[str] = None - select: Tuple[str] = () - exclude: Tuple[str] = () - - -@dataclass -class MockConfig: - """Simple mock config for use in a RunnableTask""" - - threads: int = 1 - target_name: str = "mock_config_target_name" - - def get_default_selector_name(self): - return None - - -class MockRunnableTask(GraphRunnableTask): - def __init__( - self, - exception_class: Exception = Exception, - nodes: Optional[List[MockNode]] = None, - edges: Optional[List[Tuple[str, str]]] = None, - ): - nodes = nodes or [] - edges = edges or [] - - self.forced_exception_class = exception_class - self.did_cancel: bool = False - super().__init__(args=MockArgs(), config=MockConfig(), manifest=None) - self.manifest = make_manifest(nodes=nodes) - digraph = nx.DiGraph() - for edge in edges: - digraph.add_edge(edge[0], edge[1]) - self.graph = Graph(digraph) - - def run_queue(self, pool): - """Override `run_queue` to raise a system exit""" - raise self.forced_exception_class() - - def _cancel_connections(self, pool): - """Override `_cancel_connections` to track whether it was called""" - self.did_cancel = True - - def get_node_selector(self): - """This is an `abstract_method` on `GraphRunnableTask`, thus we must implement it""" - selector = ResourceTypeSelector( - graph=self.graph, - manifest=self.manifest, - previous_state=self.previous_state, - resource_types=[NodeType.Model], - include_empty_nodes=True, - ) - return selector - - def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]): - """This is an `abstract_method` on `GraphRunnableTask`, thus we must implement it""" - return None - - -class MockRunnableTaskIndependent(MockRunnableTask): - def get_run_mode(self) -> GraphRunnableMode: - return GraphRunnableMode.Independent - - -def test_graph_runnable_task_cancels_connection_on_system_exit(): - - safe_set_invocation_context() - - task = MockRunnableTask(exception_class=SystemExit) - - with pytest.raises(SystemExit): - task.execute_nodes() - - # If `did_cancel` is True, that means `_cancel_connections` was called - assert task.did_cancel is True - - -def test_graph_runnable_task_cancels_connection_on_keyboard_interrupt(): - - safe_set_invocation_context() - - task = MockRunnableTask(exception_class=KeyboardInterrupt) - - with pytest.raises(KeyboardInterrupt): - task.execute_nodes() - - # If `did_cancel` is True, that means `_cancel_connections` was called - assert task.did_cancel is True - - -def test_graph_runnable_task_doesnt_cancel_connection_on_generic_exception(): - task = MockRunnableTask(exception_class=Exception) - - with pytest.raises(Exception): - task.execute_nodes() - - # If `did_cancel` is True, that means `_cancel_connections` was called - assert task.did_cancel is False - - -def test_graph_runnable_preserves_edges_by_default(): - task = MockRunnableTask( - nodes=[ - MockNode("test", "upstream_node", fqn="model.test.upstream_node"), - MockNode("test", "downstream_node", fqn="model.test.downstream_node"), - ], - edges=[("model.test.upstream_node", "model.test.downstream_node")], - ) - assert task.get_run_mode() == GraphRunnableMode.Topological - graph_queue = task.get_graph_queue() - - assert graph_queue.queued == {"model.test.upstream_node"} - assert graph_queue.inner.queue == [(0, "model.test.upstream_node")] - - -def test_graph_runnable_preserves_edges_false(): - task = MockRunnableTaskIndependent( - nodes=[ - MockNode("test", "upstream_node", fqn="model.test.upstream_node"), - MockNode("test", "downstream_node", fqn="model.test.downstream_node"), - ], - edges=[("model.test.upstream_node", "model.test.downstream_node")], - ) - assert task.get_run_mode() == GraphRunnableMode.Independent - graph_queue = task.get_graph_queue() - - assert graph_queue.queued == {"model.test.downstream_node", "model.test.upstream_node"} - assert graph_queue.inner.queue == [ - (0, "model.test.downstream_node"), - (0, "model.test.upstream_node"), - ]