From 09243d1e5d178c7e3051bc63d33816cafaa4c98b Mon Sep 17 00:00:00 2001 From: Quigley Malcolm Date: Tue, 21 May 2024 09:51:13 -0700 Subject: [PATCH] Improved flags fixturing for for repository unit tests (#10190) * Add fixtures for setting and resettign flags for unit tests * Remove unnecessary `set_from_args` in non `unittest.TestCase` based unit tests In the previous commit we added a pytest fixture which sets and tears down the global flags arg via `set_from_args` for every pytest based unit test. Previously we had added a `set_from_args` in tests or test files to reset the global flags from if they were modified by a previous test. This is no longer necessary because of the work done in the previous commit. Note: We did not modify any tests that use the `unittest.TestCase` class because they don't use pytest fixtures. Thus those tests need to continue operating as they currently do until we shift them to pytest unit tests. * Utilize the new `args_for_flags` fixture for setting of flags in `test_contracts_graph_parsed.py` * Convert `test_compilation.py` from `TestCase` tests to pytest tests We did this so in the next commit we can drop the unnecessary `set_from_args` in the next commit. That will be it's own commit because converting these tests is a restructuring that doing separately makes things easier to follow. That is to say, all changes in this commit were just to convert the tests to pytest, no other changes were made. * Drop unnecessary `set_from_args` in `test_compilation.py` * Add return types to all methods in `test_compilation.py` * Reduce imports from `compilation` in `test_compilation.py` * Update `test_logging.py` now that we don't need to worry about global flags * Conditionally import `Generator` type for python 3.8 In python 3.9 `Generator` was moved to `collections.abc` and deprecated in `typing`. We still support 3.8 and thus need to be conditionally importing `Generator`. We should remove this in the future when we drop support for 3.8. --- tests/unit/conftest.py | 1 + tests/unit/context/test_context.py | 4 - tests/unit/context/test_query_header.py | 4 - tests/unit/events/test_logging.py | 9 +- tests/unit/parser/test_manifest.py | 4 - tests/unit/test_compilation.py | 149 +++++++++++----------- tests/unit/test_contracts_graph_parsed.py | 6 +- tests/unit/test_deprecations.py | 4 - tests/unit/test_events.py | 4 - tests/unit/test_graph_selection.py | 5 - tests/unit/test_proto_events.py | 6 - tests/unit/utils/flags.py | 33 +++++ tests/unit/utils/manifest.py | 5 - 13 files changed, 118 insertions(+), 116 deletions(-) create mode 100644 tests/unit/utils/flags.py diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 6bad8e7eb85..f1823fb858f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,6 +7,7 @@ # All manifest related fixtures. from tests.unit.utils.adapter import * # noqa from tests.unit.utils.event_manager import * # noqa +from tests.unit.utils.flags import * # noqa from tests.unit.utils.manifest import * # noqa from tests.unit.utils.project import * # noqa diff --git a/tests/unit/context/test_context.py b/tests/unit/context/test_context.py index 3df0109191a..10e591093ee 100644 --- a/tests/unit/context/test_context.py +++ b/tests/unit/context/test_context.py @@ -1,5 +1,4 @@ import os -from argparse import Namespace from typing import Any, Dict, Set from unittest import mock @@ -19,14 +18,11 @@ UnitTestNode, UnitTestOverrides, ) -from dbt.flags import set_from_args from dbt.node_types import NodeType from dbt_common.events.functions import reset_metadata_vars from tests.unit.mock_adapter import adapter_factory from tests.unit.utils import clear_plugin, config_from_parts_or_dicts, inject_adapter -set_from_args(Namespace(WARN_ERROR=False), None) - class TestVar: @pytest.fixture diff --git a/tests/unit/context/test_query_header.py b/tests/unit/context/test_query_header.py index 40c0f1284d9..f14d28d40c4 100644 --- a/tests/unit/context/test_query_header.py +++ b/tests/unit/context/test_query_header.py @@ -1,16 +1,12 @@ import re -from argparse import Namespace from unittest import mock import pytest from dbt.adapters.base.query_headers import MacroQueryStringSetter from dbt.context.query_header import generate_query_header_context -from dbt.flags import set_from_args from tests.unit.utils import config_from_parts_or_dicts -set_from_args(Namespace(WARN_ERROR=False), None) - class TestQueryHeaderContext: @pytest.fixture diff --git a/tests/unit/events/test_logging.py b/tests/unit/events/test_logging.py index 16441ad6de7..00284ecab78 100644 --- a/tests/unit/events/test_logging.py +++ b/tests/unit/events/test_logging.py @@ -1,5 +1,4 @@ from argparse import Namespace -from copy import deepcopy from pytest_mock import MockerFixture @@ -19,12 +18,10 @@ def test_clears_preexisting_event_manager_state(self) -> None: assert len(manager.loggers) == 1 assert len(manager.callbacks) == 1 - flags = deepcopy(get_flags()) - # setting both of these to none guarantees that no logger will be added - object.__setattr__(flags, "LOG_LEVEL", "none") - object.__setattr__(flags, "LOG_LEVEL_FILE", "none") + args = Namespace(log_level="none", log_level_file="none") + set_from_args(args, {}) - setup_event_logger(flags=flags) + setup_event_logger(get_flags()) assert len(manager.loggers) == 0 assert len(manager.callbacks) == 0 diff --git a/tests/unit/parser/test_manifest.py b/tests/unit/parser/test_manifest.py index b7d470a3552..1f10ee04f25 100644 --- a/tests/unit/parser/test_manifest.py +++ b/tests/unit/parser/test_manifest.py @@ -20,7 +20,6 @@ def test_partial_parse_file_path(self, patched_open, patched_os_exist, patched_s mock_project = MagicMock(RuntimeConfig) mock_project.project_target_path = "mock_target_path" patched_os_exist.return_value = True - set_from_args(Namespace(), {}) ManifestLoader(mock_project, {}) # by default we use the project_target_path patched_open.assert_called_with("mock_target_path/partial_parse.msgpack", "rb") @@ -33,7 +32,6 @@ def test_profile_hash_change(self, mock_project): # This test validate that the profile_hash is updated when the connection keys change profile_hash = "750bc99c1d64ca518536ead26b28465a224be5ffc918bf2a490102faa5a1bcf5" mock_project.credentials.connection_info.return_value = "test" - set_from_args(Namespace(), {}) manifest = ManifestLoader(mock_project, {}) assert manifest.manifest.state_check.profile_hash.checksum == profile_hash mock_project.credentials.connection_info.return_value = "test1" @@ -67,7 +65,6 @@ def test_partial_parse_safe_update_project_parser_files_partially( mock_saved_manifest.files = {} patched_read_manifest_for_partial_parse.return_value = mock_saved_manifest - set_from_args(Namespace(), {}) loader = ManifestLoader(mock_project, {}) loader.safe_update_project_parser_files_partially({}) @@ -150,7 +147,6 @@ def test_partial_parse_file_diff_flag( mock_file_diff = mocker.patch("dbt.parser.read_files.FileDiff.from_dict") mock_file_diff.return_value = FileDiff([], [], []) - set_from_args(Namespace(), {}) ManifestLoader.get_full_manifest(config=mock_project) assert not mock_file_diff.called diff --git a/tests/unit/test_compilation.py b/tests/unit/test_compilation.py index 458efb90901..c18e7fb15d2 100644 --- a/tests/unit/test_compilation.py +++ b/tests/unit/test_compilation.py @@ -1,17 +1,15 @@ import os import tempfile -import unittest -from argparse import Namespace from queue import Empty from unittest import mock -from dbt import compilation -from dbt.flags import set_from_args +import pytest + +from dbt.compilation import Graph, Linker from dbt.graph.cli import parse_difference +from dbt.graph.queue import GraphQueue from dbt.graph.selector import NodeSelector -set_from_args(Namespace(WARN_ERROR=False), None) - def _mock_manifest(nodes): config = mock.MagicMock(enabled=True) @@ -33,41 +31,48 @@ def _mock_manifest(nodes): return manifest -class LinkerTest(unittest.TestCase): - def setUp(self): - self.linker = compilation.Linker() +class TestLinker: + @pytest.fixture + def linker(self) -> Linker: + return Linker() - def test_linker_add_node(self): + def test_linker_add_node(self, linker: Linker) -> None: expected_nodes = ["A", "B", "C"] for node in expected_nodes: - self.linker.add_node(node) + linker.add_node(node) - actual_nodes = self.linker.nodes() + actual_nodes = linker.nodes() for node in expected_nodes: - self.assertIn(node, actual_nodes) + assert node in actual_nodes - self.assertEqual(len(actual_nodes), len(expected_nodes)) + assert len(actual_nodes) == len(expected_nodes) - def test_linker_write_graph(self): + def test_linker_write_graph(self, linker: Linker) -> None: expected_nodes = ["A", "B", "C"] for node in expected_nodes: - self.linker.add_node(node) + linker.add_node(node) manifest = _mock_manifest("ABC") (fd, fname) = tempfile.mkstemp() os.close(fd) try: - self.linker.write_graph(fname, manifest) + linker.write_graph(fname, manifest) assert os.path.exists(fname) finally: os.unlink(fname) - def assert_would_join(self, queue): + def assert_would_join(self, queue: GraphQueue) -> None: """test join() without timeout risk""" - self.assertEqual(queue.inner.unfinished_tasks, 0) - - def _get_graph_queue(self, manifest, include=None, exclude=None): - graph = compilation.Graph(self.linker.graph) + assert queue.inner.unfinished_tasks == 0 + + def _get_graph_queue( + self, + manifest, + linker: Linker, + include=None, + exclude=None, + ) -> GraphQueue: + graph = Graph(linker.graph) selector = NodeSelector(graph, manifest) # TODO: The "eager" string below needs to be replaced with programatic access # to the default value for the indirect selection parameter in @@ -77,114 +82,114 @@ def _get_graph_queue(self, manifest, include=None, exclude=None): spec = parse_difference(include, exclude) return selector.get_graph_queue(spec) - def test_linker_add_dependency(self): + def test_linker_add_dependency(self, linker: Linker) -> None: actual_deps = [("A", "B"), ("A", "C"), ("B", "C")] for (l, r) in actual_deps: - self.linker.dependency(l, r) + linker.dependency(l, r) - queue = self._get_graph_queue(_mock_manifest("ABC")) + queue = self._get_graph_queue(_mock_manifest("ABC"), linker) got = queue.get(block=False) - self.assertEqual(got.unique_id, "C") - with self.assertRaises(Empty): + assert got.unique_id == "C" + with pytest.raises(Empty): queue.get(block=False) - self.assertFalse(queue.empty()) + assert not queue.empty() queue.mark_done("C") - self.assertFalse(queue.empty()) + assert not queue.empty() got = queue.get(block=False) - self.assertEqual(got.unique_id, "B") - with self.assertRaises(Empty): + assert got.unique_id == "B" + with pytest.raises(Empty): queue.get(block=False) - self.assertFalse(queue.empty()) + assert not queue.empty() queue.mark_done("B") - self.assertFalse(queue.empty()) + assert not queue.empty() got = queue.get(block=False) - self.assertEqual(got.unique_id, "A") - with self.assertRaises(Empty): + assert got.unique_id == "A" + with pytest.raises(Empty): queue.get(block=False) - self.assertTrue(queue.empty()) + assert queue.empty() queue.mark_done("A") self.assert_would_join(queue) - self.assertTrue(queue.empty()) + assert queue.empty() - def test_linker_add_disjoint_dependencies(self): + def test_linker_add_disjoint_dependencies(self, linker: Linker) -> None: actual_deps = [("A", "B")] additional_node = "Z" for (l, r) in actual_deps: - self.linker.dependency(l, r) - self.linker.add_node(additional_node) + linker.dependency(l, r) + linker.add_node(additional_node) - queue = self._get_graph_queue(_mock_manifest("ABCZ")) + queue = self._get_graph_queue(_mock_manifest("ABCZ"), linker) # the first one we get must be B, it has the longest dep chain first = queue.get(block=False) - self.assertEqual(first.unique_id, "B") - self.assertFalse(queue.empty()) + assert first.unique_id == "B" + assert not queue.empty() queue.mark_done("B") - self.assertFalse(queue.empty()) + assert not queue.empty() second = queue.get(block=False) - self.assertIn(second.unique_id, {"A", "Z"}) - self.assertFalse(queue.empty()) + assert second.unique_id in {"A", "Z"} + assert not queue.empty() queue.mark_done(second.unique_id) - self.assertFalse(queue.empty()) + assert not queue.empty() third = queue.get(block=False) - self.assertIn(third.unique_id, {"A", "Z"}) - with self.assertRaises(Empty): + assert third.unique_id in {"A", "Z"} + with pytest.raises(Empty): queue.get(block=False) - self.assertNotEqual(second.unique_id, third.unique_id) - self.assertTrue(queue.empty()) + assert second.unique_id != third.unique_id + assert queue.empty() queue.mark_done(third.unique_id) self.assert_would_join(queue) - self.assertTrue(queue.empty()) + assert queue.empty() - def test_linker_dependencies_limited_to_some_nodes(self): + def test_linker_dependencies_limited_to_some_nodes(self, linker: Linker) -> None: actual_deps = [("A", "B"), ("B", "C"), ("C", "D")] for (l, r) in actual_deps: - self.linker.dependency(l, r) + linker.dependency(l, r) - queue = self._get_graph_queue(_mock_manifest("ABCD"), ["B"]) + queue = self._get_graph_queue(_mock_manifest("ABCD"), linker, ["B"]) got = queue.get(block=False) - self.assertEqual(got.unique_id, "B") - self.assertTrue(queue.empty()) + assert got.unique_id == "B" + assert queue.empty() queue.mark_done("B") self.assert_would_join(queue) - queue_2 = queue = self._get_graph_queue(_mock_manifest("ABCD"), ["A", "B"]) + queue_2 = queue = self._get_graph_queue(_mock_manifest("ABCD"), linker, ["A", "B"]) got = queue_2.get(block=False) - self.assertEqual(got.unique_id, "B") - self.assertFalse(queue_2.empty()) - with self.assertRaises(Empty): + assert got.unique_id == "B" + assert not queue_2.empty() + with pytest.raises(Empty): queue_2.get(block=False) queue_2.mark_done("B") - self.assertFalse(queue_2.empty()) + assert not queue_2.empty() got = queue_2.get(block=False) - self.assertEqual(got.unique_id, "A") - self.assertTrue(queue_2.empty()) - with self.assertRaises(Empty): + assert got.unique_id == "A" + assert queue_2.empty() + with pytest.raises(Empty): queue_2.get(block=False) - self.assertTrue(queue_2.empty()) + assert queue_2.empty() queue_2.mark_done("A") self.assert_would_join(queue_2) - def test__find_cycles__cycles(self): + def test__find_cycles__cycles(self, linker: Linker) -> None: actual_deps = [("A", "B"), ("B", "C"), ("C", "A")] for (l, r) in actual_deps: - self.linker.dependency(l, r) + linker.dependency(l, r) - self.assertIsNotNone(self.linker.find_cycles()) + assert linker.find_cycles() is not None - def test__find_cycles__no_cycles(self): + def test__find_cycles__no_cycles(self, linker: Linker) -> None: actual_deps = [("A", "B"), ("B", "C"), ("C", "D")] for (l, r) in actual_deps: - self.linker.dependency(l, r) + linker.dependency(l, r) - self.assertIsNone(self.linker.find_cycles()) + assert linker.find_cycles() is None diff --git a/tests/unit/test_contracts_graph_parsed.py b/tests/unit/test_contracts_graph_parsed.py index 7a62c394b22..b94271fab08 100644 --- a/tests/unit/test_contracts_graph_parsed.py +++ b/tests/unit/test_contracts_graph_parsed.py @@ -7,7 +7,6 @@ from hypothesis import given from hypothesis.strategies import builds, lists -from dbt import flags from dbt.artifacts.resources import ( ColumnInfo, Dimension, @@ -67,7 +66,10 @@ replace_config, ) -flags.set_from_args(Namespace(SEND_ANONYMOUS_USAGE_STATS=False), None) + +@pytest.fixture +def flags_for_args() -> Namespace: + return Namespace(SEND_ANONYMOUS_USAGE_STATS=False) @pytest.fixture diff --git a/tests/unit/test_deprecations.py b/tests/unit/test_deprecations.py index 85d1ea4add5..69d30132ef4 100644 --- a/tests/unit/test_deprecations.py +++ b/tests/unit/test_deprecations.py @@ -1,6 +1,3 @@ -from argparse import Namespace - -from dbt.flags import set_from_args from dbt.internal_deprecations import deprecated @@ -11,6 +8,5 @@ def to_be_decorated(): # simple test that the return value is not modified def test_deprecated_func(): - set_from_args(Namespace(WARN_ERROR=False), None) assert hasattr(to_be_decorated, "__wrapped__") assert to_be_decorated() == 5 diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py index bd9892bcc6c..8a19b0ad39f 100644 --- a/tests/unit/test_events.py +++ b/tests/unit/test_events.py @@ -1,6 +1,5 @@ import logging import re -from argparse import Namespace from typing import TypeVar import pytest @@ -20,7 +19,6 @@ WarnLevel, ) from dbt.events.types import RunResultError -from dbt.flags import set_from_args from dbt.task.printer import print_run_result_error from dbt_common.events import types from dbt_common.events.base_types import msg_from_base_event @@ -29,8 +27,6 @@ from dbt_common.events.functions import msg_to_dict, msg_to_json from dbt_common.events.helpers import get_json_string_utcnow -set_from_args(Namespace(WARN_ERROR=False), None) - # takes in a class and finds any subclasses for it def get_all_subclasses(cls): diff --git a/tests/unit/test_graph_selection.py b/tests/unit/test_graph_selection.py index be283e59926..5d5cbf7469d 100644 --- a/tests/unit/test_graph_selection.py +++ b/tests/unit/test_graph_selection.py @@ -1,5 +1,4 @@ import string -from argparse import Namespace from unittest import mock import networkx as nx @@ -8,12 +7,8 @@ import dbt.graph.cli as graph_cli import dbt.graph.selector as graph_selector import dbt_common.exceptions -from dbt import flags -from dbt.contracts.project import ProjectFlags from dbt.node_types import NodeType -flags.set_from_args(Namespace(), ProjectFlags()) - def _get_graph(): integer_graph = nx.balanced_tree(2, 2, nx.DiGraph()) diff --git a/tests/unit/test_proto_events.py b/tests/unit/test_proto_events.py index 7d369e6b00d..51fdf8a2024 100644 --- a/tests/unit/test_proto_events.py +++ b/tests/unit/test_proto_events.py @@ -1,5 +1,3 @@ -from argparse import Namespace - from google.protobuf.json_format import MessageToDict from dbt.adapters.events.types import PluginLoadError, RollbackFailed @@ -11,7 +9,6 @@ MainReportArgs, MainReportVersion, ) -from dbt.flags import set_from_args from dbt.version import installed from dbt_common.events import types_pb2 from dbt_common.events.base_types import EventLevel, msg_from_base_event @@ -22,9 +19,6 @@ reset_metadata_vars, ) -set_from_args(Namespace(WARN_ERROR=False), None) - - info_keys = { "name", "code", diff --git a/tests/unit/utils/flags.py b/tests/unit/utils/flags.py new file mode 100644 index 00000000000..20bb4a44ea0 --- /dev/null +++ b/tests/unit/utils/flags.py @@ -0,0 +1,33 @@ +import sys +from argparse import Namespace + +if sys.version_info < (3, 9): + from typing import Generator +else: + from collections.abc import Generator + +import pytest + +from dbt.flags import set_from_args + + +@pytest.fixture +def args_for_flags() -> Namespace: + """Defines the namespace args to be used in `set_from_args` of `set_test_flags` fixture. + + This fixture is meant to be overrided by tests that need specific flags to be set. + """ + return Namespace() + + +@pytest.fixture(autouse=True) +def set_test_flags(args_for_flags: Namespace) -> Generator[None, None, None]: + """Sets up and tears down the global flags for every pytest unit test + + Override `args_for_flags` fixture as needed to set any specific flags. + """ + set_from_args(args_for_flags, {}) + # fixtures stop setup upon yield + yield None + # everything after yield is run at test teardown + set_from_args(Namespace(), {}) diff --git a/tests/unit/utils/manifest.py b/tests/unit/utils/manifest.py index c62d0bd0edf..a7c269cdab2 100644 --- a/tests/unit/utils/manifest.py +++ b/tests/unit/utils/manifest.py @@ -1,5 +1,3 @@ -from argparse import Namespace - import pytest from dbt_semantic_interfaces.type_enums import MetricType @@ -36,11 +34,8 @@ UnitTestDefinition, ) from dbt.contracts.graph.unparsed import UnitTestInputFixture, UnitTestOutputFixture -from dbt.flags import set_from_args from dbt.node_types import NodeType -set_from_args(Namespace(WARN_ERROR=False), None) - def make_model( pkg,