From f5f9591d09550c4cfd425c2cee702c9b0ade5e33 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 18 Apr 2024 10:44:28 -0400 Subject: [PATCH] [Robust Testing] Move tests to tests/unit/context (#9966) --- core/dbt/cli/requires.py | 2 +- core/dbt/context/manifest.py | 10 -- core/dbt/context/query_header.py | 13 ++ core/dbt/parser/manifest.py | 2 +- tests/unit/{ => context}/test_context.py | 170 ++++++++++-------- tests/unit/{ => context}/test_providers.py | 0 .../test_query_header.py} | 49 ++--- 7 files changed, 141 insertions(+), 105 deletions(-) create mode 100644 core/dbt/context/query_header.py rename tests/unit/{ => context}/test_context.py (81%) rename tests/unit/{ => context}/test_providers.py (100%) rename tests/unit/{test_query_headers.py => context/test_query_header.py} (50%) diff --git a/core/dbt/cli/requires.py b/core/dbt/cli/requires.py index ccd5ffc7150..75c81ebd7e1 100644 --- a/core/dbt/cli/requires.py +++ b/core/dbt/cli/requires.py @@ -15,7 +15,7 @@ from dbt.cli.flags import Flags from dbt.config import RuntimeConfig from dbt.config.runtime import load_project, load_profile, UnsetProfile -from dbt.context.manifest import generate_query_header_context +from dbt.context.query_header import generate_query_header_context from dbt_common.events.base_types import EventLevel from dbt_common.events.functions import ( diff --git a/core/dbt/context/manifest.py b/core/dbt/context/manifest.py index d55d3ad0f21..0d95fd3b95f 100644 --- a/core/dbt/context/manifest.py +++ b/core/dbt/context/manifest.py @@ -71,13 +71,3 @@ def to_dict(self): @contextproperty() def context_macro_stack(self): return self.macro_stack - - -class QueryHeaderContext(ManifestContext): - def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None: - super().__init__(config, manifest, config.project_name) - - -def generate_query_header_context(config: AdapterRequiredConfig, manifest: Manifest): - ctx = QueryHeaderContext(config, manifest) - return ctx.to_dict() diff --git a/core/dbt/context/query_header.py b/core/dbt/context/query_header.py new file mode 100644 index 00000000000..95c5a0b7a8f --- /dev/null +++ b/core/dbt/context/query_header.py @@ -0,0 +1,13 @@ +from dbt.adapters.contracts.connection import AdapterRequiredConfig +from dbt.context.manifest import ManifestContext +from dbt.contracts.graph.manifest import Manifest + + +class QueryHeaderContext(ManifestContext): + def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None: + super().__init__(config, manifest, config.project_name) + + +def generate_query_header_context(config: AdapterRequiredConfig, manifest: Manifest): + ctx = QueryHeaderContext(config, manifest) + return ctx.to_dict() diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index a0cc49faa20..5e406d81d03 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -19,7 +19,7 @@ from itertools import chain import time -from dbt.context.manifest import generate_query_header_context +from dbt.context.query_header import generate_query_header_context from dbt.contracts.graph.semantic_manifest import SemanticManifest from dbt_common.events.base_types import EventLevel from dbt_common.exceptions.base import DbtValidationError diff --git a/tests/unit/test_context.py b/tests/unit/context/test_context.py similarity index 81% rename from tests/unit/test_context.py rename to tests/unit/context/test_context.py index fd23da53c17..6070c24a1b7 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/context/test_context.py @@ -1,4 +1,3 @@ -import unittest import os from typing import Set, Dict, Any from unittest import mock @@ -17,26 +16,28 @@ UnitTestOverrides, ) from dbt.config.project import VarProvider -from dbt.context import base, providers, docs, manifest, macros +from dbt.context import base, providers, docs, macros, query_header from dbt.contracts.files import FileHash from dbt_common.events.functions import reset_metadata_vars +from dbt.flags import set_from_args from dbt.node_types import NodeType import dbt_common.exceptions -from .utils import ( + +from tests.unit.utils import ( config_from_parts_or_dicts, inject_adapter, clear_plugin, ) -from .mock_adapter import adapter_factory -from dbt.flags import set_from_args +from tests.unit.mock_adapter import adapter_factory from argparse import Namespace set_from_args(Namespace(WARN_ERROR=False), None) -class TestVar(unittest.TestCase): - def setUp(self): - self.model = ModelNode( +class TestVar: + @pytest.fixture + def model(self): + return ModelNode( alias="model_one", name="model_one", database="dbt", @@ -70,91 +71,114 @@ def setUp(self): columns={}, checksum=FileHash.from_contents(""), ) - self.context = mock.MagicMock() - self.provider = VarProvider({}) - self.config = mock.MagicMock( - config_version=2, vars=self.provider, cli_vars={}, project_name="root" - ) - def test_var_default_something(self): - self.config.cli_vars = {"foo": "baz"} - var = providers.RuntimeVar(self.context, self.config, self.model) - self.assertEqual(var("foo"), "baz") - self.assertEqual(var("foo", "bar"), "baz") + @pytest.fixture + def context(self): + return mock.MagicMock() + + @pytest.fixture + def provider(self): + return VarProvider({}) + + @pytest.fixture + def config(self, provider): + return mock.MagicMock(config_version=2, vars=provider, cli_vars={}, project_name="root") + + def test_var_default_something(self, model, config, context): + config.cli_vars = {"foo": "baz"} + var = providers.RuntimeVar(context, config, model) + + assert var("foo") == "baz" + assert var("foo", "bar") == "baz" + + def test_var_default_none(self, model, config, context): + config.cli_vars = {"foo": None} + var = providers.RuntimeVar(context, config, model) - def test_var_default_none(self): - self.config.cli_vars = {"foo": None} - var = providers.RuntimeVar(self.context, self.config, self.model) - self.assertEqual(var("foo"), None) - self.assertEqual(var("foo", "bar"), None) + assert var("foo") is None + assert var("foo", "bar") is None - def test_var_not_defined(self): - var = providers.RuntimeVar(self.context, self.config, self.model) + def test_var_not_defined(self, model, config, context): + var = providers.RuntimeVar(self.context, config, model) - self.assertEqual(var("foo", "bar"), "bar") - with self.assertRaises(dbt_common.exceptions.CompilationError): + assert var("foo", "bar") == "bar" + with pytest.raises(dbt_common.exceptions.CompilationError): var("foo") - def test_parser_var_default_something(self): - self.config.cli_vars = {"foo": "baz"} - var = providers.ParseVar(self.context, self.config, self.model) - self.assertEqual(var("foo"), "baz") - self.assertEqual(var("foo", "bar"), "baz") + def test_parser_var_default_something(self, model, config, context): + config.cli_vars = {"foo": "baz"} + var = providers.ParseVar(context, config, model) + assert var("foo") == "baz" + assert var("foo", "bar") == "baz" - def test_parser_var_default_none(self): - self.config.cli_vars = {"foo": None} - var = providers.ParseVar(self.context, self.config, self.model) - self.assertEqual(var("foo"), None) - self.assertEqual(var("foo", "bar"), None) + def test_parser_var_default_none(self, model, config, context): + config.cli_vars = {"foo": None} + var = providers.ParseVar(context, config, model) + assert var("foo") is None + assert var("foo", "bar") is None - def test_parser_var_not_defined(self): + def test_parser_var_not_defined(self, model, config, context): # at parse-time, we should not raise if we encounter a missing var # that way disabled models don't get parse errors - var = providers.ParseVar(self.context, self.config, self.model) + var = providers.ParseVar(context, config, model) - self.assertEqual(var("foo", "bar"), "bar") - self.assertEqual(var("foo"), None) + assert var("foo", "bar") == "bar" + assert var("foo") is None -class TestParseWrapper(unittest.TestCase): - def setUp(self): - self.mock_config = mock.MagicMock() - self.mock_mp_context = mock.MagicMock() +class TestParseWrapper: + @pytest.fixture + def mock_adapter(self): + mock_config = mock.MagicMock() + mock_mp_context = mock.MagicMock() adapter_class = adapter_factory() - self.mock_adapter = adapter_class(self.mock_config, self.mock_mp_context) - self.namespace = mock.MagicMock() - self.wrapper = providers.ParseDatabaseWrapper(self.mock_adapter, self.namespace) - self.responder = self.mock_adapter.responder - - def test_unwrapped_method(self): - self.assertEqual(self.wrapper.quote("test_value"), '"test_value"') - self.responder.quote.assert_called_once_with("test_value") - - def test_wrapped_method(self): - found = self.wrapper.get_relation("database", "schema", "identifier") - self.assertEqual(found, None) - self.responder.get_relation.assert_not_called() - - -class TestRuntimeWrapper(unittest.TestCase): - def setUp(self): - self.mock_config = mock.MagicMock() - self.mock_mp_context = mock.MagicMock() - self.mock_config.quoting = { + return adapter_class(mock_config, mock_mp_context) + + @pytest.fixture + def wrapper(self, mock_adapter): + namespace = mock.MagicMock() + return providers.ParseDatabaseWrapper(mock_adapter, namespace) + + @pytest.fixture + def responder(self, mock_adapter): + return mock_adapter.responder + + def test_unwrapped_method(self, wrapper, responder): + assert wrapper.quote("test_value") == '"test_value"' + responder.quote.assert_called_once_with("test_value") + + def test_wrapped_method(self, wrapper, responder): + found = wrapper.get_relation("database", "schema", "identifier") + assert found is None + responder.get_relation.assert_not_called() + + +class TestRuntimeWrapper: + @pytest.fixture + def mock_adapter(self): + mock_config = mock.MagicMock() + mock_config.quoting = { "database": True, "schema": True, "identifier": True, } + mock_mp_context = mock.MagicMock() adapter_class = adapter_factory() - self.mock_adapter = adapter_class(self.mock_config, self.mock_mp_context) - self.namespace = mock.MagicMock() - self.wrapper = providers.RuntimeDatabaseWrapper(self.mock_adapter, self.namespace) - self.responder = self.mock_adapter.responder + return adapter_class(mock_config, mock_mp_context) + + @pytest.fixture + def wrapper(self, mock_adapter): + namespace = mock.MagicMock() + return providers.RuntimeDatabaseWrapper(mock_adapter, namespace) + + @pytest.fixture + def responder(self, mock_adapter): + return mock_adapter.responder - def test_unwrapped_method(self): + def test_unwrapped_method(self, wrapper, responder): # the 'quote' method isn't wrapped, we should get our expected inputs - self.assertEqual(self.wrapper.quote("test_value"), '"test_value"') - self.responder.quote.assert_called_once_with("test_value") + assert wrapper.quote("test_value") == '"test_value"' + responder.quote.assert_called_once_with("test_value") def assert_has_keys(required_keys: Set[str], maybe_keys: Set[str], ctx: Dict[str, Any]): @@ -417,7 +441,7 @@ def postgres_adapter(config_postgres, get_adapter): def test_query_header_context(config_postgres, manifest_fx): - ctx = manifest.generate_query_header_context( + ctx = query_header.generate_query_header_context( config=config_postgres, manifest=manifest_fx, ) diff --git a/tests/unit/test_providers.py b/tests/unit/context/test_providers.py similarity index 100% rename from tests/unit/test_providers.py rename to tests/unit/context/test_providers.py diff --git a/tests/unit/test_query_headers.py b/tests/unit/context/test_query_header.py similarity index 50% rename from tests/unit/test_query_headers.py rename to tests/unit/context/test_query_header.py index 2be9b59bd4d..aa9e99821a2 100644 --- a/tests/unit/test_query_headers.py +++ b/tests/unit/context/test_query_header.py @@ -1,8 +1,9 @@ +import pytest import re -from unittest import TestCase, mock +from unittest import mock from dbt.adapters.base.query_headers import MacroQueryStringSetter -from dbt.context.manifest import generate_query_header_context +from dbt.context.query_header import generate_query_header_context from tests.unit.utils import config_from_parts_or_dicts from dbt.flags import set_from_args @@ -11,9 +12,10 @@ set_from_args(Namespace(WARN_ERROR=False), None) -class TestQueryHeaders(TestCase): - def setUp(self): - self.profile_cfg = { +class TestQueryHeaderContext: + @pytest.fixture + def profile_cfg(self): + return { "outputs": { "test": { "type": "postgres", @@ -27,33 +29,40 @@ def setUp(self): }, "target": "test", } - self.project_cfg = { + + @pytest.fixture + def project_cfg(self): + return { "name": "query_headers", "version": "0.1", "profile": "test", "config-version": 2, } - self.query = "SELECT 1;" - def test_comment_should_prepend_query_by_default(self): - config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg) + @pytest.fixture + def query(self): + return "SELECT 1;" + + def test_comment_should_prepend_query_by_default(self, profile_cfg, project_cfg, query): + config = config_from_parts_or_dicts(project_cfg, profile_cfg) query_header_context = generate_query_header_context(config, mock.MagicMock(macros={})) query_header = MacroQueryStringSetter(config, query_header_context) - sql = query_header.add(self.query) - self.assertTrue(re.match(f"^\/\*.*\*\/\n{self.query}$", sql)) # noqa: [W605] + sql = query_header.add(query) + assert re.match(f"^\/\*.*\*\/\n{query}$", sql) # noqa: [W605] - def test_append_comment(self): - self.project_cfg.update({"query-comment": {"comment": "executed by dbt", "append": True}}) - config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg) + def test_append_comment(self, profile_cfg, project_cfg, query): + project_cfg.update({"query-comment": {"comment": "executed by dbt", "append": True}}) + config = config_from_parts_or_dicts(project_cfg, profile_cfg) query_header_context = generate_query_header_context(config, mock.MagicMock(macros={})) query_header = MacroQueryStringSetter(config, query_header_context) - sql = query_header.add(self.query) - self.assertEqual(sql, f"{self.query[:-1]}\n/* executed by dbt */;") + sql = query_header.add(query) + + assert sql == f"{query[:-1]}\n/* executed by dbt */;" - def test_disable_query_comment(self): - self.project_cfg.update({"query-comment": ""}) - config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg) + def test_disable_query_comment(self, profile_cfg, project_cfg, query): + project_cfg.update({"query-comment": ""}) + config = config_from_parts_or_dicts(project_cfg, profile_cfg) query_header = MacroQueryStringSetter(config, mock.MagicMock(macros={})) - self.assertEqual(query_header.add(self.query), self.query) + assert query_header.add(query) == query