Skip to content

Commit

Permalink
pass context to MacroQueryStringSetter (#9248)
Browse files Browse the repository at this point in the history
* moving types_pb2.py to common/events

* remove manifest from adapter.execute_macro, replace with MacroResolver + remove lazy loading

* rename to MacroResolverProtocol

* pass MacroResolverProtcol in adapter.calculate_freshness_from_metadata

* changelog entry

* fix adapter.calculate_freshness call

* pass context to MacroQueryStringSetter

* changelog entry

---------

Co-authored-by: Colin <[email protected]>
  • Loading branch information
MichelleArk and colin-rogers-dbt authored Dec 7, 2023
1 parent 4d16524 commit c7b9b1a
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 22 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20231207-224139.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: pass query header context to MacroQueryStringSetter
time: 2023-12-07T22:41:39.498024+09:00
custom:
Author: michelleark
Issue: 9249 9250
5 changes: 2 additions & 3 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
LazyHandle,
AdapterResponse,
)
from dbt.contracts.graph.manifest import Manifest
from dbt.adapters.base.query_headers import (
MacroQueryStringSetter,
)
Expand Down Expand Up @@ -79,8 +78,8 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) ->
self.lock: RLock = mp_context.RLock()
self.query_header: Optional[MacroQueryStringSetter] = None

def set_query_header(self, manifest: Manifest) -> None:
self.query_header = MacroQueryStringSetter(self.profile, manifest)
def set_query_header(self, query_header_context: Dict[str, Any]) -> None:
self.query_header = MacroQueryStringSetter(self.profile, query_header_context)

@staticmethod
def get_thread_identifier() -> Hashable:
Expand Down
11 changes: 5 additions & 6 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
from typing import Optional, Callable, Dict, Any

from dbt.adapters.clients.jinja import QueryStringGenerator

from dbt.context.manifest import generate_query_header_context
from dbt.adapters.contracts.connection import AdapterRequiredConfig, QueryComment
from dbt.contracts.graph.manifest import Manifest
from dbt.common.exceptions import DbtRuntimeError


Expand Down Expand Up @@ -56,9 +53,11 @@ def set(self, comment: Optional[str], append: bool):


class MacroQueryStringSetter:
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None:
self.manifest = manifest
def __init__(
self, config: AdapterRequiredConfig, query_header_context: Dict[str, Any]
) -> None:
self.config = config
self._query_header_context = query_header_context

comment_macro = self._get_comment_macro()
self.generator: QueryStringFunc = lambda name, model: ""
Expand All @@ -81,7 +80,7 @@ def _get_comment_macro(self) -> Optional[str]:
return self.config.query_comment.comment

def _get_context(self) -> Dict[str, Any]:
return generate_query_header_context(self.config, self.manifest)
return self._query_header_context

def add(self, sql: str) -> str:
return self.comment.add(sql)
Expand Down
16 changes: 13 additions & 3 deletions core/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
from dataclasses import dataclass
from typing import Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, Tuple, Any
from typing import (
Type,
Hashable,
Optional,
ContextManager,
List,
Generic,
TypeVar,
Tuple,
Any,
Dict,
)
from typing_extensions import Protocol

import agate
Expand All @@ -8,7 +19,6 @@
from dbt.adapters.contracts.macros import MacroResolverProtocol
from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig
from dbt.common.contracts.config.base import BaseConfig
from dbt.contracts.graph.manifest import Manifest


@dataclass
Expand Down Expand Up @@ -80,7 +90,7 @@ def clear_macro_resolver(self) -> None:
def type(cls) -> str:
pass

def set_query_header(self, manifest: Manifest) -> None:
def set_query_header(self, query_header_context: Dict[str, Any]) -> None:
...

@staticmethod
Expand Down
15 changes: 8 additions & 7 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
)
from dbt.config import Project, RuntimeConfig
from dbt.context.docs import generate_runtime_docs_context
from dbt.context.manifest import generate_query_header_context
from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace
from dbt.context.configured import generate_macro_context
from dbt.context.providers import ParseProvider
Expand Down Expand Up @@ -236,7 +237,7 @@ def __init__(
self,
root_project: RuntimeConfig,
all_projects: Mapping[str, Project],
macro_hook: Optional[Callable[[Manifest], Any]] = None,
macro_hook: Optional[Callable[[Dict[str, Any]], Any]] = None,
file_diff: Optional[FileDiff] = None,
) -> None:
self.root_project: RuntimeConfig = root_project
Expand All @@ -250,9 +251,9 @@ def __init__(
# This is a MacroQueryStringSetter callable, which is called
# later after we set the MacroManifest in the adapter. It sets
# up the query headers.
self.macro_hook: Callable[[Manifest], Any]
self.macro_hook: Callable[[Dict[str, Any]], Any]
if macro_hook is None:
self.macro_hook = lambda m: None
self.macro_hook = lambda c: None
else:
self.macro_hook = macro_hook

Expand Down Expand Up @@ -1001,9 +1002,9 @@ def build_manifest_state_check(self):
def save_macros_to_adapter(self, adapter):
macro_manifest = MacroManifest(self.manifest.macros)
adapter.set_macro_resolver(macro_manifest)
# This executes the callable macro_hook and sets the
# query headers
self.macro_hook(macro_manifest)
# This executes the callable macro_hook and sets the query headers
query_header_context = generate_query_header_context(adapter.config, macro_manifest)
self.macro_hook(query_header_context)

# This creates a MacroManifest which contains the macros in
# the adapter. Only called by the load_macros call from the
Expand Down Expand Up @@ -1032,7 +1033,7 @@ def create_macro_manifest(self):
def load_macros(
cls,
root_config: RuntimeConfig,
macro_hook: Callable[[Manifest], Any],
macro_hook: Callable[[Dict[str, Any]], Any],
base_macros_only=False,
) -> Manifest:
with PARSING_STATE:
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dbt.adapters.postgres import Plugin as PostgresPlugin
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import ManifestStateCheck
from dbt.context.manifest import generate_query_header_context
from dbt.common.clients import agate_helper
from dbt.exceptions import DbtConfigError
from dbt.common.exceptions import DbtValidationError
Expand Down Expand Up @@ -429,8 +430,12 @@ def _mock_state_check(self):
self.psycopg2.connect.return_value = self.handle
self.adapter = PostgresAdapter(self.config, self.mp_context)
self.adapter.set_macro_resolver(load_internal_manifest_macros(self.config))

query_header_context = generate_query_header_context(
self.adapter.config, self.adapter.get_macro_resolver()
)
self.adapter.connections.query_header = MacroQueryStringSetter(
self.config, self.adapter.get_macro_resolver()
self.config, query_header_context
)

self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add")
Expand Down
9 changes: 7 additions & 2 deletions tests/unit/test_query_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest import TestCase, mock

from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.context.manifest import generate_query_header_context

from tests.unit.utils import config_from_parts_or_dicts
from dbt.flags import set_from_args
Expand Down Expand Up @@ -36,14 +37,18 @@ def setUp(self):

def test_comment_should_prepend_query_by_default(self):
config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
query_header = MacroQueryStringSetter(config, mock.MagicMock(macros={}))

query_header_context = generate_query_header_context(config, mock.MagicMock(macros={}))
query_header = MacroQueryStringSetter(config, query_header_context)
sql = query_header.add(self.query)
self.assertTrue(re.match(f"^\/\*.*\*\/\n{self.query}$", sql)) # noqa: [W605]

def test_append_comment(self):
self.project_cfg.update({"query-comment": {"comment": "executed by dbt", "append": True}})
config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
query_header = MacroQueryStringSetter(config, mock.MagicMock(macros={}))

query_header_context = generate_query_header_context(config, mock.MagicMock(macros={}))
query_header = MacroQueryStringSetter(config, query_header_context)
sql = query_header.add(self.query)
self.assertEqual(sql, f"{self.query[:-1]}\n/* executed by dbt */;")

Expand Down

0 comments on commit c7b9b1a

Please sign in to comment.