Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #8031: Call materialization macro from adapter dispatch #8355

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230810-183216.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Call materialization macro from adapter dispatch
time: 2023-08-10T18:32:16.226142+01:00
custom:
Author: aranke
Issue: "8031"
5 changes: 3 additions & 2 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,11 @@ def depth(self) -> int:
def push(self, name):
self.call_stack.append(name)

def pop(self, name):
def pop(self, name: Optional[str] = None):
got = self.call_stack.pop()
if got != name:
if name and got != name:
raise DbtInternalError(f"popped {got}, expected {name}")
return got


class MacroGenerator(BaseMacroGenerator):
Expand Down
65 changes: 37 additions & 28 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@
Type,
Iterable,
Mapping,
Tuple,
)

import agate
from typing_extensions import Protocol

from dbt import selected_resources
from dbt.adapters.base.column import Column
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
from dbt.config import IsFQNResource
from dbt.config import RuntimeConfig, Project
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
from dbt.context.base import contextmember, contextproperty, Var
Expand All @@ -29,6 +34,7 @@
from dbt.context.manifest import ManifestContext
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.manifest import Manifest, Disabled
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.nodes import (
Macro,
Exposure,
Expand All @@ -40,7 +46,6 @@
AccessType,
SemanticModel,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.unparsed import NodeVersion
from dbt.events.functions import get_metadata_vars
from dbt.exceptions import (
Expand Down Expand Up @@ -69,16 +74,9 @@
DbtValidationError,
DbtReferenceError,
)
from dbt.config import IsFQNResource
from dbt.node_types import NodeType, ModelLanguage

from dbt.utils import merge, AttrDict, MultiDict, args_to_dict, cast_to_str

from dbt import selected_resources

import agate


_MISSING = object()


Expand Down Expand Up @@ -156,6 +154,7 @@ def dispatch(
self,
macro_name: str,
macro_namespace: Optional[str] = None,
stack: Optional[MacroStack] = None,
packages: Optional[List[str]] = None, # eventually remove since it's fully deprecated
) -> MacroGenerator:
search_packages: List[Optional[str]]
Expand All @@ -174,30 +173,40 @@ def dispatch(
raise MacroDispatchArgError(macro_name)

search_packages = self._get_search_packages(macro_namespace)

attempts = []
macro = None
potential_macros: List[Tuple[Optional[str], str]] = []
failed_macros: List[Tuple[Optional[str], str]] = []

for package_name in search_packages:
for prefix in self._get_adapter_macro_prefixes():
search_name = f"{prefix}__{macro_name}"
try:
# this uses the namespace from the context
macro = self._namespace.get_from_package(package_name, search_name)
except CompilationError:
# Only raise CompilationError if macro is not found in
# any package
macro = None

if package_name is None:
attempts.append(search_name)
else:
attempts.append(f"{package_name}.{search_name}")

if macro is not None:
if macro_name.startswith("materialization_"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we made a fix recently for scenarios where macros that started with "materialization_" were incorrectly being picked up as materializations. For example, this was being flagged as a materialization instead of a macro:

{% macro materialization_setup() %}
    ...
{% endmacro %}

Have you verified this doesn't reintroduce that issue?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have tests for that case now, so this PR shouldn't reintroduce the issue (or if it does, it'll be obvious before merging).

#8181

potential_macros.append((package_name, macro_name))
potential_macros.append(("dbt", macro_name))
else:
for prefix in self._get_adapter_macro_prefixes():
potential_macros.append((package_name, f"{prefix}__{macro_name}"))
potential_macros.append(
(package_name, f"materialization_{macro_name}_{prefix}")
)

for package_name, search_name in potential_macros:
try:
macro = self._namespace.get_from_package(package_name, search_name)
if macro:
macro.stack = stack
except CompilationError:
# Only raise CompilationError if macro is not found in
# any package
pass
finally:
if macro:
return macro
else:
failed_macros.append((package_name, search_name))

searched = ", ".join(repr(a) for a in attempts)
msg = f"In dispatch: No macro named '{macro_name}' found within namespace: '{macro_namespace}'\n Searched for: {searched}"
msg = (
f"In dispatch: No macro named '{macro_name}' found within namespace: '{macro_namespace}'\n"
f"Searched for: {failed_macros}"
)
raise CompilationError(msg)


Expand Down
5 changes: 3 additions & 2 deletions core/dbt/task/clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import AbstractSet, Any, List, Iterable, Set

from dbt.adapters.base import BaseRelation
from dbt.clients.jinja import MacroGenerator
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.results import RunStatus, RunResult
from dbt.dataclass_schema import dbtClassMixin
Expand Down Expand Up @@ -80,7 +79,9 @@ def execute(self, model, manifest):

hook_ctx = self.adapter.pre_model_hook(context_config)
try:
result = MacroGenerator(materialization_macro, context)()
result = context["adapter"].dispatch(
materialization_macro.name, stack=context["context_macro_stack"]
)()
finally:
self.adapter.post_model_hook(context_config, hook_ctx)

Expand Down
42 changes: 19 additions & 23 deletions core/dbt/task/run.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,18 @@
import functools
import threading
import time
from datetime import datetime
from typing import List, Dict, Any, Iterable, Set, Tuple, Optional, AbstractSet

from dbt.dataclass_schema import dbtClassMixin

from .compile import CompileRunner, CompileTask

from .printer import (
print_run_end_messages,
get_counts,
)
from datetime import datetime
from dbt import tracking
from dbt import utils
from dbt.adapters.base import BaseRelation
from dbt.clients.jinja import MacroGenerator
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.graph.model_config import Hook
from dbt.contracts.graph.nodes import HookNode, ResultNode
from dbt.contracts.results import NodeStatus, RunResult, RunStatus, RunningStatus, BaseResult
from dbt.exceptions import (
CompilationError,
DbtInternalError,
MissingMaterializationError,
DbtRuntimeError,
DbtValidationError,
)
from dbt.dataclass_schema import dbtClassMixin
from dbt.events.base_types import EventLevel
from dbt.events.functions import fire_event, get_invocation_id
from dbt.events.types import (
DatabaseErrorRunningHook,
Expand All @@ -38,17 +24,28 @@
LogHookEndLine,
LogHookStartLine,
)
from dbt.events.base_types import EventLevel
from dbt.exceptions import (
CompilationError,
DbtInternalError,
MissingMaterializationError,
DbtRuntimeError,
DbtValidationError,
)
from dbt.graph import ResourceTypeSelector
from dbt.hooks import get_hook_dict
from dbt.logger import (
TextOnly,
HookMetadata,
UniqueID,
TimestampNamed,
DbtModelState,
)
from dbt.graph import ResourceTypeSelector
from dbt.hooks import get_hook_dict
from dbt.node_types import NodeType, RunHookType
from dbt.task.compile import CompileRunner, CompileTask
from dbt.task.printer import (
print_run_end_messages,
get_counts,
)


class Timer:
Expand Down Expand Up @@ -288,8 +285,8 @@ def execute(self, model, manifest):

hook_ctx = self.adapter.pre_model_hook(context_config)
try:
result = MacroGenerator(
materialization_macro, context, stack=context["context_macro_stack"]
result = context["adapter"].dispatch(
materialization_macro.name, stack=context["context_macro_stack"]
)()
finally:
self.adapter.post_model_hook(context_config, hook_ctx)
Expand Down Expand Up @@ -327,7 +324,6 @@ def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]:
return package_name, hook.index

def get_hooks_by_type(self, hook_type: RunHookType) -> List[HookNode]:

if self.manifest is None:
raise DbtInternalError("self.manifest was None in get_hooks_by_type")

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/materializations/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@
"""

override_view_default_dep__macros__default_view_sql = """
{%- materialization view, default -%}
{%- materialization view, adapter = 'postgres' -%}
{{ exceptions.raise_compiler_error('intentionally raising an error in the default view materialization') }}
{%- endmaterialization -%}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

from dbt.tests.util import run_dbt

parent_materialization = """
{% materialization parent, default %}
{%- set target_relation = this.incorporate(type='table') %}
{% call statement('main') -%}
set session time zone 'Asia/Kolkata';
{%- endcall %}
{{ return({'relations': [target_relation]}) }}
{% endmaterialization %}
"""

child_materialization = """
{% materialization child, default %}
{%- set relations = adapter.dispatch('parent')() %}
{{ return({'relations': relations['relations'] }) }}
{% endmaterialization %}
"""

my_model_sql = """
{{ config(materialized='child') }}
select current_setting('timezone') as current_tz
"""


class TestMaterializationOverride:
@pytest.fixture(scope="class")
def macros(self):
return {
"parent.sql": parent_materialization,
"child.sql": child_materialization,
}

@pytest.fixture(scope="class")
def models(self):
return {
"model.sql": my_model_sql,
}

def test_foo(self, project):
res = run_dbt(["run"])
print(res)
17 changes: 17 additions & 0 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,23 @@ def test_macro_namespace_duplicates(config_postgres, manifest_fx):
mn.add_macros(mock_macro("macro_a", "dbt"), {})


def test_macro_stack(config_postgres, manifest_fx):
stack = MacroStack()
stack.push("foo")
stack.push("bar")
mn = macros.MacroNamespaceBuilder("root", "search", stack, ["dbt_postgres", "dbt"])
mn.add_macros(manifest_fx.macros.values(), {})

stack = mn.thread_ctx
assert stack.depth == 2
assert stack.pop() == "bar"

with pytest.raises(dbt.exceptions.DbtInternalError):
stack.pop("bar")

assert stack.depth == 0


def test_macro_namespace(config_postgres, manifest_fx):
mn = macros.MacroNamespaceBuilder("root", "search", MacroStack(), ["dbt_postgres", "dbt"])

Expand Down