Skip to content

Commit

Permalink
[Robust Testing] Move tests to tests/unit/context (#9966)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored Apr 18, 2024
1 parent 61727ab commit f5f9591
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 105 deletions.
2 changes: 1 addition & 1 deletion core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig
from dbt.config.runtime import load_project, load_profile, UnsetProfile
from dbt.context.manifest import generate_query_header_context
from dbt.context.query_header import generate_query_header_context

from dbt_common.events.base_types import EventLevel
from dbt_common.events.functions import (
Expand Down
10 changes: 0 additions & 10 deletions core/dbt/context/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,3 @@ def to_dict(self):
@contextproperty()
def context_macro_stack(self):
return self.macro_stack


class QueryHeaderContext(ManifestContext):
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None:
super().__init__(config, manifest, config.project_name)


def generate_query_header_context(config: AdapterRequiredConfig, manifest: Manifest):
ctx = QueryHeaderContext(config, manifest)
return ctx.to_dict()
13 changes: 13 additions & 0 deletions core/dbt/context/query_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from dbt.adapters.contracts.connection import AdapterRequiredConfig
from dbt.context.manifest import ManifestContext
from dbt.contracts.graph.manifest import Manifest


class QueryHeaderContext(ManifestContext):
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None:
super().__init__(config, manifest, config.project_name)


def generate_query_header_context(config: AdapterRequiredConfig, manifest: Manifest):
ctx = QueryHeaderContext(config, manifest)
return ctx.to_dict()
2 changes: 1 addition & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from itertools import chain
import time

from dbt.context.manifest import generate_query_header_context
from dbt.context.query_header import generate_query_header_context
from dbt.contracts.graph.semantic_manifest import SemanticManifest
from dbt_common.events.base_types import EventLevel
from dbt_common.exceptions.base import DbtValidationError
Expand Down
170 changes: 97 additions & 73 deletions tests/unit/test_context.py → tests/unit/context/test_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import unittest
import os
from typing import Set, Dict, Any
from unittest import mock
Expand All @@ -17,26 +16,28 @@
UnitTestOverrides,
)
from dbt.config.project import VarProvider
from dbt.context import base, providers, docs, manifest, macros
from dbt.context import base, providers, docs, macros, query_header
from dbt.contracts.files import FileHash
from dbt_common.events.functions import reset_metadata_vars
from dbt.flags import set_from_args
from dbt.node_types import NodeType
import dbt_common.exceptions
from .utils import (

from tests.unit.utils import (
config_from_parts_or_dicts,
inject_adapter,
clear_plugin,
)
from .mock_adapter import adapter_factory
from dbt.flags import set_from_args
from tests.unit.mock_adapter import adapter_factory
from argparse import Namespace

set_from_args(Namespace(WARN_ERROR=False), None)


class TestVar(unittest.TestCase):
def setUp(self):
self.model = ModelNode(
class TestVar:
@pytest.fixture
def model(self):
return ModelNode(
alias="model_one",
name="model_one",
database="dbt",
Expand Down Expand Up @@ -70,91 +71,114 @@ def setUp(self):
columns={},
checksum=FileHash.from_contents(""),
)
self.context = mock.MagicMock()
self.provider = VarProvider({})
self.config = mock.MagicMock(
config_version=2, vars=self.provider, cli_vars={}, project_name="root"
)

def test_var_default_something(self):
self.config.cli_vars = {"foo": "baz"}
var = providers.RuntimeVar(self.context, self.config, self.model)
self.assertEqual(var("foo"), "baz")
self.assertEqual(var("foo", "bar"), "baz")
@pytest.fixture
def context(self):
return mock.MagicMock()

@pytest.fixture
def provider(self):
return VarProvider({})

@pytest.fixture
def config(self, provider):
return mock.MagicMock(config_version=2, vars=provider, cli_vars={}, project_name="root")

def test_var_default_something(self, model, config, context):
config.cli_vars = {"foo": "baz"}
var = providers.RuntimeVar(context, config, model)

assert var("foo") == "baz"
assert var("foo", "bar") == "baz"

def test_var_default_none(self, model, config, context):
config.cli_vars = {"foo": None}
var = providers.RuntimeVar(context, config, model)

def test_var_default_none(self):
self.config.cli_vars = {"foo": None}
var = providers.RuntimeVar(self.context, self.config, self.model)
self.assertEqual(var("foo"), None)
self.assertEqual(var("foo", "bar"), None)
assert var("foo") is None
assert var("foo", "bar") is None

def test_var_not_defined(self):
var = providers.RuntimeVar(self.context, self.config, self.model)
def test_var_not_defined(self, model, config, context):
var = providers.RuntimeVar(self.context, config, model)

self.assertEqual(var("foo", "bar"), "bar")
with self.assertRaises(dbt_common.exceptions.CompilationError):
assert var("foo", "bar") == "bar"
with pytest.raises(dbt_common.exceptions.CompilationError):
var("foo")

def test_parser_var_default_something(self):
self.config.cli_vars = {"foo": "baz"}
var = providers.ParseVar(self.context, self.config, self.model)
self.assertEqual(var("foo"), "baz")
self.assertEqual(var("foo", "bar"), "baz")
def test_parser_var_default_something(self, model, config, context):
config.cli_vars = {"foo": "baz"}
var = providers.ParseVar(context, config, model)
assert var("foo") == "baz"
assert var("foo", "bar") == "baz"

def test_parser_var_default_none(self):
self.config.cli_vars = {"foo": None}
var = providers.ParseVar(self.context, self.config, self.model)
self.assertEqual(var("foo"), None)
self.assertEqual(var("foo", "bar"), None)
def test_parser_var_default_none(self, model, config, context):
config.cli_vars = {"foo": None}
var = providers.ParseVar(context, config, model)
assert var("foo") is None
assert var("foo", "bar") is None

def test_parser_var_not_defined(self):
def test_parser_var_not_defined(self, model, config, context):
# at parse-time, we should not raise if we encounter a missing var
# that way disabled models don't get parse errors
var = providers.ParseVar(self.context, self.config, self.model)
var = providers.ParseVar(context, config, model)

self.assertEqual(var("foo", "bar"), "bar")
self.assertEqual(var("foo"), None)
assert var("foo", "bar") == "bar"
assert var("foo") is None


class TestParseWrapper(unittest.TestCase):
def setUp(self):
self.mock_config = mock.MagicMock()
self.mock_mp_context = mock.MagicMock()
class TestParseWrapper:
@pytest.fixture
def mock_adapter(self):
mock_config = mock.MagicMock()
mock_mp_context = mock.MagicMock()
adapter_class = adapter_factory()
self.mock_adapter = adapter_class(self.mock_config, self.mock_mp_context)
self.namespace = mock.MagicMock()
self.wrapper = providers.ParseDatabaseWrapper(self.mock_adapter, self.namespace)
self.responder = self.mock_adapter.responder

def test_unwrapped_method(self):
self.assertEqual(self.wrapper.quote("test_value"), '"test_value"')
self.responder.quote.assert_called_once_with("test_value")

def test_wrapped_method(self):
found = self.wrapper.get_relation("database", "schema", "identifier")
self.assertEqual(found, None)
self.responder.get_relation.assert_not_called()


class TestRuntimeWrapper(unittest.TestCase):
def setUp(self):
self.mock_config = mock.MagicMock()
self.mock_mp_context = mock.MagicMock()
self.mock_config.quoting = {
return adapter_class(mock_config, mock_mp_context)

@pytest.fixture
def wrapper(self, mock_adapter):
namespace = mock.MagicMock()
return providers.ParseDatabaseWrapper(mock_adapter, namespace)

@pytest.fixture
def responder(self, mock_adapter):
return mock_adapter.responder

def test_unwrapped_method(self, wrapper, responder):
assert wrapper.quote("test_value") == '"test_value"'
responder.quote.assert_called_once_with("test_value")

def test_wrapped_method(self, wrapper, responder):
found = wrapper.get_relation("database", "schema", "identifier")
assert found is None
responder.get_relation.assert_not_called()


class TestRuntimeWrapper:
@pytest.fixture
def mock_adapter(self):
mock_config = mock.MagicMock()
mock_config.quoting = {
"database": True,
"schema": True,
"identifier": True,
}
mock_mp_context = mock.MagicMock()
adapter_class = adapter_factory()
self.mock_adapter = adapter_class(self.mock_config, self.mock_mp_context)
self.namespace = mock.MagicMock()
self.wrapper = providers.RuntimeDatabaseWrapper(self.mock_adapter, self.namespace)
self.responder = self.mock_adapter.responder
return adapter_class(mock_config, mock_mp_context)

@pytest.fixture
def wrapper(self, mock_adapter):
namespace = mock.MagicMock()
return providers.RuntimeDatabaseWrapper(mock_adapter, namespace)

@pytest.fixture
def responder(self, mock_adapter):
return mock_adapter.responder

def test_unwrapped_method(self):
def test_unwrapped_method(self, wrapper, responder):
# the 'quote' method isn't wrapped, we should get our expected inputs
self.assertEqual(self.wrapper.quote("test_value"), '"test_value"')
self.responder.quote.assert_called_once_with("test_value")
assert wrapper.quote("test_value") == '"test_value"'
responder.quote.assert_called_once_with("test_value")


def assert_has_keys(required_keys: Set[str], maybe_keys: Set[str], ctx: Dict[str, Any]):
Expand Down Expand Up @@ -417,7 +441,7 @@ def postgres_adapter(config_postgres, get_adapter):


def test_query_header_context(config_postgres, manifest_fx):
ctx = manifest.generate_query_header_context(
ctx = query_header.generate_query_header_context(
config=config_postgres,
manifest=manifest_fx,
)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
import re
from unittest import TestCase, mock
from unittest import mock

from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.context.manifest import generate_query_header_context
from dbt.context.query_header import generate_query_header_context

from tests.unit.utils import config_from_parts_or_dicts
from dbt.flags import set_from_args
Expand All @@ -11,9 +12,10 @@
set_from_args(Namespace(WARN_ERROR=False), None)


class TestQueryHeaders(TestCase):
def setUp(self):
self.profile_cfg = {
class TestQueryHeaderContext:
@pytest.fixture
def profile_cfg(self):
return {
"outputs": {
"test": {
"type": "postgres",
Expand All @@ -27,33 +29,40 @@ def setUp(self):
},
"target": "test",
}
self.project_cfg = {

@pytest.fixture
def project_cfg(self):
return {
"name": "query_headers",
"version": "0.1",
"profile": "test",
"config-version": 2,
}
self.query = "SELECT 1;"

def test_comment_should_prepend_query_by_default(self):
config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
@pytest.fixture
def query(self):
return "SELECT 1;"

def test_comment_should_prepend_query_by_default(self, profile_cfg, project_cfg, query):
config = config_from_parts_or_dicts(project_cfg, profile_cfg)

query_header_context = generate_query_header_context(config, mock.MagicMock(macros={}))
query_header = MacroQueryStringSetter(config, query_header_context)
sql = query_header.add(self.query)
self.assertTrue(re.match(f"^\/\*.*\*\/\n{self.query}$", sql)) # noqa: [W605]
sql = query_header.add(query)
assert re.match(f"^\/\*.*\*\/\n{query}$", sql) # noqa: [W605]

def test_append_comment(self):
self.project_cfg.update({"query-comment": {"comment": "executed by dbt", "append": True}})
config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
def test_append_comment(self, profile_cfg, project_cfg, query):
project_cfg.update({"query-comment": {"comment": "executed by dbt", "append": True}})
config = config_from_parts_or_dicts(project_cfg, profile_cfg)

query_header_context = generate_query_header_context(config, mock.MagicMock(macros={}))
query_header = MacroQueryStringSetter(config, query_header_context)
sql = query_header.add(self.query)
self.assertEqual(sql, f"{self.query[:-1]}\n/* executed by dbt */;")
sql = query_header.add(query)

assert sql == f"{query[:-1]}\n/* executed by dbt */;"

def test_disable_query_comment(self):
self.project_cfg.update({"query-comment": ""})
config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
def test_disable_query_comment(self, profile_cfg, project_cfg, query):
project_cfg.update({"query-comment": ""})
config = config_from_parts_or_dicts(project_cfg, profile_cfg)
query_header = MacroQueryStringSetter(config, mock.MagicMock(macros={}))
self.assertEqual(query_header.add(self.query), self.query)
assert query_header.add(query) == query

0 comments on commit f5f9591

Please sign in to comment.