diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 263049420..d8d47aa40 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -1,5 +1,6 @@ import os from dataclasses import dataclass +from dbt.common.contracts.constraints import ConstraintType from typing import Optional, Set, Any, Dict, Type from collections import namedtuple from dbt.adapters.base import PythonJobHelper @@ -7,7 +8,6 @@ from dbt.adapters.base.meta import available from dbt.adapters.sql import SQLAdapter from dbt.adapters.contracts.connection import AdapterResponse -from dbt.contracts.graph.nodes import ConstraintType from dbt.adapters.events.logging import AdapterLogger diff --git a/dbt/adapters/redshift/relation.py b/dbt/adapters/redshift/relation.py index 30afb2146..db391f0a4 100644 --- a/dbt/adapters/redshift/relation.py +++ b/dbt/adapters/redshift/relation.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from dbt.adapters.contracts.relation import RelationConfig from typing import Optional from dbt.adapters.base.relation import BaseRelation @@ -7,8 +8,6 @@ RelationConfigChangeAction, RelationResults, ) -from dbt.context.providers import RuntimeConfigObject -from dbt.contracts.graph.nodes import ModelNode from dbt.adapters.base import RelationType from dbt.common.exceptions import DbtRuntimeError @@ -60,31 +59,28 @@ def relation_max_name_length(self): return MAX_CHARACTERS_IN_IDENTIFIER @classmethod - def from_runtime_config(cls, runtime_config: RuntimeConfigObject) -> RelationConfigBase: - model_node: ModelNode = runtime_config.model - relation_type: str = model_node.config.materialized + def from_config(cls, config: RelationConfig) -> RelationConfigBase: + relation_type: str = config.config.materialized # type: ignore if relation_config := cls.relation_configs.get(relation_type): - return relation_config.from_model_node(model_node) + return relation_config.from_config(relation_config) # type: ignore raise DbtRuntimeError( - f"from_runtime_config() is not supported for the provided relation type: {relation_type}" + f"from_config() is not supported for the provided relation type: {relation_type}" ) @classmethod def materialized_view_config_changeset( - cls, relation_results: RelationResults, runtime_config: RuntimeConfigObject + cls, relation_results: RelationResults, relation_config: RelationConfig ) -> Optional[RedshiftMaterializedViewConfigChangeset]: config_change_collection = RedshiftMaterializedViewConfigChangeset() existing_materialized_view = RedshiftMaterializedViewConfig.from_relation_results( relation_results ) - new_materialized_view = RedshiftMaterializedViewConfig.from_model_node( - runtime_config.model + new_materialized_view = RedshiftMaterializedViewConfig.from_relation_config( + relation_config ) - assert isinstance(existing_materialized_view, RedshiftMaterializedViewConfig) - assert isinstance(new_materialized_view, RedshiftMaterializedViewConfig) if new_materialized_view.autorefresh != existing_materialized_view.autorefresh: config_change_collection.autorefresh = RedshiftAutoRefreshConfigChange( diff --git a/dbt/adapters/redshift/relation_configs/base.py b/dbt/adapters/redshift/relation_configs/base.py index b64571ac5..c4faab664 100644 --- a/dbt/adapters/redshift/relation_configs/base.py +++ b/dbt/adapters/redshift/relation_configs/base.py @@ -1,14 +1,14 @@ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import agate from dbt.adapters.base.relation import Policy -from dbt.adapters.contracts.relation import ComponentName +from dbt.adapters.contracts.relation import ComponentName, RelationConfig from dbt.adapters.relation_configs import ( RelationConfigBase, RelationResults, ) -from dbt.contracts.graph.nodes import ModelNode +from typing_extensions import Self from dbt.adapters.redshift.relation_configs.policies import ( RedshiftIncludePolicy, @@ -31,25 +31,25 @@ def quote_policy(cls) -> Policy: return RedshiftQuotePolicy() @classmethod - def from_model_node(cls, model_node: ModelNode) -> "RelationConfigBase": - relation_config = cls.parse_model_node(model_node) - relation = cls.from_dict(relation_config) - return relation + def from_relation_config(cls, relation_config: RelationConfig) -> Self: + relation_config_dict = cls.parse_relation_config(relation_config) + relation = cls.from_dict(relation_config_dict) + return relation # type: ignore @classmethod - def parse_model_node(cls, model_node: ModelNode) -> dict: + def parse_relation_config(cls, relation_config: RelationConfig) -> Dict: raise NotImplementedError( - "`parse_model_node()` needs to be implemented on this RelationConfigBase instance" + "`parse_relation_config()` needs to be implemented on this RelationConfigBase instance" ) @classmethod - def from_relation_results(cls, relation_results: RelationResults) -> "RelationConfigBase": + def from_relation_results(cls, relation_results: RelationResults) -> Self: relation_config = cls.parse_relation_results(relation_results) relation = cls.from_dict(relation_config) - return relation + return relation # type: ignore @classmethod - def parse_relation_results(cls, relation_results: RelationResults) -> dict: + def parse_relation_results(cls, relation_results: RelationResults) -> Dict: raise NotImplementedError( "`parse_relation_results()` needs to be implemented on this RelationConfigBase instance" ) diff --git a/dbt/adapters/redshift/relation_configs/dist.py b/dbt/adapters/redshift/relation_configs/dist.py index 65be4cd35..58812ee57 100644 --- a/dbt/adapters/redshift/relation_configs/dist.py +++ b/dbt/adapters/redshift/relation_configs/dist.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import Optional, Set +from dbt.adapters.contracts.relation import RelationConfig +from typing import Optional, Set, Dict import agate from dbt.adapters.relation_configs import ( @@ -8,9 +9,9 @@ RelationConfigValidationMixin, RelationConfigValidationRule, ) -from dbt.contracts.graph.nodes import ModelNode from dbt.common.dataclass_schema import StrEnum from dbt.common.exceptions import DbtRuntimeError +from typing_extensions import Self from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase @@ -65,21 +66,21 @@ def validation_rules(self) -> Set[RelationConfigValidationRule]: } @classmethod - def from_dict(cls, config_dict) -> "RedshiftDistConfig": + def from_dict(cls, config_dict) -> Self: kwargs_dict = { "diststyle": config_dict.get("diststyle"), "distkey": config_dict.get("distkey"), } - dist: "RedshiftDistConfig" = super().from_dict(kwargs_dict) # type: ignore + dist: Self = super().from_dict(kwargs_dict) # type: ignore return dist @classmethod - def parse_model_node(cls, model_node: ModelNode) -> dict: + def parse_relation_config(cls, relation_config: RelationConfig) -> dict: """ Translate ModelNode objects from the user-provided config into a standard dictionary. Args: - model_node: the description of the distkey and diststyle from the user in this format: + relation_config: the description of the distkey and diststyle from the user in this format: { "dist": any("auto", "even", "all") or "" @@ -87,7 +88,7 @@ def parse_model_node(cls, model_node: ModelNode) -> dict: Returns: a standard dictionary describing this `RedshiftDistConfig` instance """ - dist = model_node.config.extra.get("dist", "") + dist = relation_config.config.extra.get("dist", "") # type: ignore diststyle = dist.lower() @@ -107,7 +108,7 @@ def parse_model_node(cls, model_node: ModelNode) -> dict: return config @classmethod - def parse_relation_results(cls, relation_results_entry: agate.Row) -> dict: + def parse_relation_results(cls, relation_results_entry: agate.Row) -> Dict: """ Translate agate objects from the database into a standard dictionary. diff --git a/dbt/adapters/redshift/relation_configs/materialized_view.py b/dbt/adapters/redshift/relation_configs/materialized_view.py index 127af2f63..60b369e0a 100644 --- a/dbt/adapters/redshift/relation_configs/materialized_view.py +++ b/dbt/adapters/redshift/relation_configs/materialized_view.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Set +from typing import Optional, Set, Dict, Any import agate from dbt.adapters.relation_configs import ( @@ -8,9 +8,9 @@ RelationConfigValidationMixin, RelationConfigValidationRule, ) -from dbt.contracts.graph.nodes import ModelNode -from dbt.adapters.contracts.relation import ComponentName +from dbt.adapters.contracts.relation import ComponentName, RelationConfig from dbt.common.exceptions import DbtRuntimeError +from typing_extensions import Self from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase from dbt.adapters.redshift.relation_configs.dist import ( @@ -95,7 +95,7 @@ def validation_rules(self) -> Set[RelationConfigValidationRule]: } @classmethod - def from_dict(cls, config_dict) -> "RedshiftMaterializedViewConfig": + def from_dict(cls, config_dict) -> Self: kwargs_dict = { "mv_name": cls._render_part(ComponentName.Identifier, config_dict.get("mv_name")), "schema_name": cls._render_part(ComponentName.Schema, config_dict.get("schema_name")), @@ -114,39 +114,39 @@ def from_dict(cls, config_dict) -> "RedshiftMaterializedViewConfig": if sort := config_dict.get("sort"): kwargs_dict.update({"sort": RedshiftSortConfig.from_dict(sort)}) - materialized_view: "RedshiftMaterializedViewConfig" = super().from_dict(kwargs_dict) # type: ignore + materialized_view: Self = super().from_dict(kwargs_dict) # type: ignore return materialized_view @classmethod - def parse_model_node(cls, model_node: ModelNode) -> dict: - config_dict = { - "mv_name": model_node.identifier, - "schema_name": model_node.schema, - "database_name": model_node.database, + def parse_relation_config(cls, config: RelationConfig) -> Dict[str, Any]: + config_dict: Dict[str, Any] = { + "mv_name": config.identifier, + "schema_name": config.schema, + "database_name": config.database, } # backup/autorefresh can be bools or strings - backup_value = model_node.config.extra.get("backup") + backup_value = config.config.extra.get("backup") # type: ignore if backup_value is not None: config_dict["backup"] = evaluate_bool(backup_value) - autorefresh_value = model_node.config.extra.get("auto_refresh") + autorefresh_value = config.config.extra.get("auto_refresh") # type: ignore if autorefresh_value is not None: config_dict["autorefresh"] = evaluate_bool(autorefresh_value) - if query := model_node.compiled_code: + if query := config.compiled_code: # type: ignore config_dict.update({"query": query.strip()}) - if model_node.config.get("dist"): - config_dict.update({"dist": RedshiftDistConfig.parse_model_node(model_node)}) + if config.config.get("dist"): + config_dict.update({"dist": RedshiftDistConfig.parse_relation_config(config)}) - if model_node.config.get("sort"): - config_dict.update({"sort": RedshiftSortConfig.parse_model_node(model_node)}) + if config.config.get("sort"): + config_dict.update({"sort": RedshiftSortConfig.parse_relation_config(config)}) return config_dict @classmethod - def parse_relation_results(cls, relation_results: RelationResults) -> dict: + def parse_relation_results(cls, relation_results: RelationResults) -> Dict: """ Translate agate objects from the database into a standard dictionary. diff --git a/dbt/adapters/redshift/relation_configs/sort.py b/dbt/adapters/redshift/relation_configs/sort.py index f683e7201..c97d137bc 100644 --- a/dbt/adapters/redshift/relation_configs/sort.py +++ b/dbt/adapters/redshift/relation_configs/sort.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import Optional, FrozenSet, Set +from dbt.adapters.contracts.relation import RelationConfig +from typing import Optional, FrozenSet, Set, Dict, Any import agate from dbt.adapters.relation_configs import ( @@ -8,9 +9,9 @@ RelationConfigValidationMixin, RelationConfigValidationRule, ) -from dbt.contracts.graph.nodes import ModelNode from dbt.common.dataclass_schema import StrEnum from dbt.common.exceptions import DbtRuntimeError +from typing_extensions import Self from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase @@ -97,21 +98,21 @@ def validation_rules(self) -> Set[RelationConfigValidationRule]: } @classmethod - def from_dict(cls, config_dict) -> "RedshiftSortConfig": + def from_dict(cls, config_dict) -> Self: kwargs_dict = { "sortstyle": config_dict.get("sortstyle"), "sortkey": frozenset(column for column in config_dict.get("sortkey", {})), } - sort: "RedshiftSortConfig" = super().from_dict(kwargs_dict) # type: ignore - return sort + sort: Self = super().from_dict(kwargs_dict) # type: ignore + return sort # type: ignore @classmethod - def parse_model_node(cls, model_node: ModelNode) -> dict: + def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]: """ Translate ModelNode objects from the user-provided config into a standard dictionary. Args: - model_node: the description of the sortkey and sortstyle from the user in this format: + relation_config: the description of the sortkey and sortstyle from the user in this format: { "sort_key": "" or [""] or ["",...] @@ -122,10 +123,10 @@ def parse_model_node(cls, model_node: ModelNode) -> dict: """ config_dict = {} - if sortstyle := model_node.config.extra.get("sort_type"): + if sortstyle := relation_config.config.extra.get("sort_type"): # type: ignore config_dict.update({"sortstyle": sortstyle.lower()}) - if sortkey := model_node.config.extra.get("sort"): + if sortkey := relation_config.config.extra.get("sort"): # type: ignore # we allow users to specify the `sort_key` as a string if it's a single column if isinstance(sortkey, str): sortkey = [sortkey] diff --git a/dbt/include/redshift/macros/materializations/materialized_view.sql b/dbt/include/redshift/macros/materializations/materialized_view.sql index 5cdb26504..9b1ef2d50 100644 --- a/dbt/include/redshift/macros/materializations/materialized_view.sql +++ b/dbt/include/redshift/macros/materializations/materialized_view.sql @@ -1,5 +1,5 @@ {% macro redshift__get_materialized_view_configuration_changes(existing_relation, new_config) %} {% set _existing_materialized_view = redshift__describe_materialized_view(existing_relation) %} - {% set _configuration_changes = existing_relation.materialized_view_config_changeset(_existing_materialized_view, new_config) %} + {% set _configuration_changes = existing_relation.materialized_view_config_changeset(_existing_materialized_view, new_config.model) %} {% do return(_configuration_changes) %} {% endmacro %} diff --git a/dbt/include/redshift/macros/relations/materialized_view/create.sql b/dbt/include/redshift/macros/relations/materialized_view/create.sql index b84680525..06fe2b6b5 100644 --- a/dbt/include/redshift/macros/relations/materialized_view/create.sql +++ b/dbt/include/redshift/macros/relations/materialized_view/create.sql @@ -1,6 +1,6 @@ {% macro redshift__get_create_materialized_view_as_sql(relation, sql) %} - {%- set materialized_view = relation.from_runtime_config(config) -%} + {%- set materialized_view = relation.from_config(config.model) -%} create materialized view {{ materialized_view.path }} backup {% if materialized_view.backup %}yes{% else %}no{% endif %} diff --git a/tests/unit/relation_configs/test_materialized_view.py b/tests/unit/relation_configs/test_materialized_view.py index 42a3223d0..5e454fe5e 100644 --- a/tests/unit/relation_configs/test_materialized_view.py +++ b/tests/unit/relation_configs/test_materialized_view.py @@ -17,7 +17,7 @@ def test_redshift_materialized_view_config_handles_all_valid_bools(bool_value): model_node.config.extra.get = ( lambda x, y=None: bool_value if x in ["auto_refresh", "backup"] else "someDistValue" ) - config_dict = config.parse_model_node(model_node) + config_dict = config.parse_relation_config(model_node) assert isinstance(config_dict["autorefresh"], bool) assert isinstance(config_dict["backup"], bool) @@ -37,7 +37,7 @@ def test_redshift_materialized_view_config_throws_expected_exception_with_invali lambda x, y=None: bool_value if x in ["auto_refresh", "backup"] else "someDistValue" ) with pytest.raises(TypeError): - config.parse_model_node(model_node) + config.parse_relation_config(model_node) def test_redshift_materialized_view_config_throws_expected_exception_with_invalid_str(): @@ -52,4 +52,4 @@ def test_redshift_materialized_view_config_throws_expected_exception_with_invali lambda x, y=None: "notABool" if x in ["auto_refresh", "backup"] else "someDistValue" ) with pytest.raises(ValueError): - config.parse_model_node(model_node) + config.parse_relation_config(model_node) diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py deleted file mode 100644 index f7bdd7485..000000000 --- a/tests/unit/test_context.py +++ /dev/null @@ -1,232 +0,0 @@ -import os -import pytest -import unittest - -from multiprocessing import get_context -from unittest import mock - -from .utils import config_from_parts_or_dicts, inject_adapter, clear_plugin -from .mock_adapter import adapter_factory -import dbt.adapters.exceptions - -from dbt.adapters import ( - redshift, - factory, -) -from dbt.contracts.graph.model_config import ( - NodeConfig, -) -from dbt.contracts.graph.nodes import ModelNode, DependsOn, Macro -from dbt.context import providers -from dbt.node_types import NodeType - - -class TestRuntimeWrapper(unittest.TestCase): - def setUp(self): - self.mock_config = mock.MagicMock() - self.mock_config.quoting = {"database": True, "schema": True, "identifier": True} - adapter_class = adapter_factory() - self.mock_adapter = adapter_class(self.mock_config) - self.namespace = mock.MagicMock() - self.wrapper = providers.RuntimeDatabaseWrapper(self.mock_adapter, self.namespace) - self.responder = self.mock_adapter.responder - - -PROFILE_DATA = { - "target": "test", - "quoting": {}, - "outputs": { - "test": { - "type": "redshift", - "host": "localhost", - "schema": "analytics", - "user": "test", - "pass": "test", - "dbname": "test", - "port": 1, - } - }, -} - - -PROJECT_DATA = { - "name": "root", - "version": "0.1", - "profile": "test", - "project-root": os.getcwd(), - "config-version": 2, -} - - -def model(): - return ModelNode( - alias="model_one", - name="model_one", - database="dbt", - schema="analytics", - resource_type=NodeType.Model, - unique_id="model.root.model_one", - fqn=["root", "model_one"], - package_name="root", - original_file_path="model_one.sql", - root_path="/usr/src/app", - refs=[], - sources=[], - depends_on=DependsOn(), - config=NodeConfig.from_dict( - { - "enabled": True, - "materialized": "view", - "persist_docs": {}, - "post-hook": [], - "pre-hook": [], - "vars": {}, - "quoting": {}, - "column_types": {}, - "tags": [], - } - ), - tags=[], - path="model_one.sql", - raw_sql="", - description="", - columns={}, - ) - - -def mock_macro(name, package_name): - macro = mock.MagicMock( - __class__=Macro, - package_name=package_name, - resource_type="macro", - unique_id=f"macro.{package_name}.{name}", - ) - # Mock(name=...) does not set the `name` attribute, this does. - macro.name = name - return macro - - -def mock_manifest(config): - manifest_macros = {} - for name in ["macro_a", "macro_b"]: - macro = mock_macro(name, config.project_name) - manifest_macros[macro.unique_id] = macro - return mock.MagicMock(macros=manifest_macros) - - -def mock_model(): - return mock.MagicMock( - __class__=ModelNode, - alias="model_one", - name="model_one", - database="dbt", - schema="analytics", - resource_type=NodeType.Model, - unique_id="model.root.model_one", - fqn=["root", "model_one"], - package_name="root", - original_file_path="model_one.sql", - root_path="/usr/src/app", - refs=[], - sources=[], - depends_on=DependsOn(), - config=NodeConfig.from_dict( - { - "enabled": True, - "materialized": "view", - "persist_docs": {}, - "post-hook": [], - "pre-hook": [], - "vars": {}, - "quoting": {}, - "column_types": {}, - "tags": [], - } - ), - tags=[], - path="model_one.sql", - raw_sql="", - description="", - columns={}, - defer_relation=None, - ) - - -@pytest.fixture -def get_adapter(): - with mock.patch.object(providers, "get_adapter") as patch: - yield patch - - -@pytest.fixture -def get_include_paths(): - with mock.patch.object(factory, "get_include_paths") as patch: - patch.return_value = [] - yield patch - - -@pytest.fixture -def config(): - return config_from_parts_or_dicts(PROJECT_DATA, PROFILE_DATA) - - -@pytest.fixture -def manifest_fx(config): - return mock_manifest(config) - - -@pytest.fixture -def manifest_extended(manifest_fx): - dbt_macro = mock_macro("default__some_macro", "dbt") - # same namespace, same name, different pkg! - rs_macro = mock_macro("redshift__some_macro", "dbt_redshift") - # same name, different package - package_default_macro = mock_macro("default__some_macro", "root") - package_rs_macro = mock_macro("redshift__some_macro", "root") - manifest_fx.macros[dbt_macro.unique_id] = dbt_macro - manifest_fx.macros[rs_macro.unique_id] = rs_macro - manifest_fx.macros[package_default_macro.unique_id] = package_default_macro - manifest_fx.macros[package_rs_macro.unique_id] = package_rs_macro - return manifest_fx - - -@pytest.fixture -def redshift_adapter(config, get_adapter): - adapter = redshift.RedshiftAdapter(config, get_context("spawn")) - inject_adapter(adapter, redshift.Plugin) - get_adapter.return_value = adapter - yield adapter - clear_plugin(redshift.Plugin) - - -def test_resolve_specific(config, manifest_extended, redshift_adapter, get_include_paths): - rs_macro = manifest_extended.macros["macro.dbt_redshift.redshift__some_macro"] - package_rs_macro = manifest_extended.macros["macro.root.redshift__some_macro"] - - ctx = providers.generate_runtime_model_context( - model=mock_model(), - config=config, - manifest=manifest_extended, - ) - - ctx["adapter"].config.dispatch - - # macro_a exists, but default__macro_a and redshift__macro_a do not - with pytest.raises(dbt.exceptions.CompilationError): - ctx["adapter"].dispatch("macro_a").macro - - # root namespace is always preferred, unless search order is explicitly defined in 'dispatch' config - assert ctx["adapter"].dispatch("some_macro").macro is package_rs_macro - assert ctx["adapter"].dispatch("some_macro", "dbt").macro is package_rs_macro - assert ctx["adapter"].dispatch("some_macro", "root").macro is package_rs_macro - - # override 'dbt' namespace search order, dispatch to 'root' first - ctx["adapter"].config.dispatch = [{"macro_namespace": "dbt", "search_order": ["root", "dbt"]}] - assert ctx["adapter"].dispatch("some_macro", macro_namespace="dbt").macro is package_rs_macro - - # override 'dbt' namespace search order, dispatch to 'dbt' only - ctx["adapter"].config.dispatch = [{"macro_namespace": "dbt", "search_order": ["dbt"]}] - assert ctx["adapter"].dispatch("some_macro", macro_namespace="dbt").macro is rs_macro - - # override 'root' namespace search order, dispatch to 'dbt' first - ctx["adapter"].config.dispatch = [{"macro_namespace": "root", "search_order": ["dbt", "root"]}]