From c7b9b1a2091555d3a4a15b9e00c02dbac40b36d0 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 8 Dec 2023 03:21:32 +0900 Subject: [PATCH] pass context to MacroQueryStringSetter (#9248) * moving types_pb2.py to common/events * remove manifest from adapter.execute_macro, replace with MacroResolver + remove lazy loading * rename to MacroResolverProtocol * pass MacroResolverProtcol in adapter.calculate_freshness_from_metadata * changelog entry * fix adapter.calculate_freshness call * pass context to MacroQueryStringSetter * changelog entry --------- Co-authored-by: Colin --- .../Under the Hood-20231207-224139.yaml | 6 ++++++ core/dbt/adapters/base/connections.py | 5 ++--- core/dbt/adapters/base/query_headers.py | 11 +++++------ core/dbt/adapters/protocol.py | 16 +++++++++++++--- core/dbt/parser/manifest.py | 15 ++++++++------- tests/unit/test_postgres_adapter.py | 7 ++++++- tests/unit/test_query_headers.py | 9 +++++++-- 7 files changed, 47 insertions(+), 22 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20231207-224139.yaml diff --git a/.changes/unreleased/Under the Hood-20231207-224139.yaml b/.changes/unreleased/Under the Hood-20231207-224139.yaml new file mode 100644 index 00000000000..8c4f4fd3c1f --- /dev/null +++ b/.changes/unreleased/Under the Hood-20231207-224139.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: pass query header context to MacroQueryStringSetter +time: 2023-12-07T22:41:39.498024+09:00 +custom: + Author: michelleark + Issue: 9249 9250 diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index 80ebf322523..f347876f62e 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -34,7 +34,6 @@ LazyHandle, AdapterResponse, ) -from dbt.contracts.graph.manifest import Manifest from dbt.adapters.base.query_headers import ( MacroQueryStringSetter, ) @@ -79,8 +78,8 @@ def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> self.lock: RLock = mp_context.RLock() self.query_header: Optional[MacroQueryStringSetter] = None - def set_query_header(self, manifest: Manifest) -> None: - self.query_header = MacroQueryStringSetter(self.profile, manifest) + def set_query_header(self, query_header_context: Dict[str, Any]) -> None: + self.query_header = MacroQueryStringSetter(self.profile, query_header_context) @staticmethod def get_thread_identifier() -> Hashable: diff --git a/core/dbt/adapters/base/query_headers.py b/core/dbt/adapters/base/query_headers.py index 6fa591d45c8..b5f64d6214c 100644 --- a/core/dbt/adapters/base/query_headers.py +++ b/core/dbt/adapters/base/query_headers.py @@ -2,10 +2,7 @@ from typing import Optional, Callable, Dict, Any from dbt.adapters.clients.jinja import QueryStringGenerator - -from dbt.context.manifest import generate_query_header_context from dbt.adapters.contracts.connection import AdapterRequiredConfig, QueryComment -from dbt.contracts.graph.manifest import Manifest from dbt.common.exceptions import DbtRuntimeError @@ -56,9 +53,11 @@ def set(self, comment: Optional[str], append: bool): class MacroQueryStringSetter: - def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None: - self.manifest = manifest + def __init__( + self, config: AdapterRequiredConfig, query_header_context: Dict[str, Any] + ) -> None: self.config = config + self._query_header_context = query_header_context comment_macro = self._get_comment_macro() self.generator: QueryStringFunc = lambda name, model: "" @@ -81,7 +80,7 @@ def _get_comment_macro(self) -> Optional[str]: return self.config.query_comment.comment def _get_context(self) -> Dict[str, Any]: - return generate_query_header_context(self.config, self.manifest) + return self._query_header_context def add(self, sql: str) -> str: return self.comment.add(sql) diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py index 67a2bc9b998..201073201cf 100644 --- a/core/dbt/adapters/protocol.py +++ b/core/dbt/adapters/protocol.py @@ -1,5 +1,16 @@ from dataclasses import dataclass -from typing import Type, Hashable, Optional, ContextManager, List, Generic, TypeVar, Tuple, Any +from typing import ( + Type, + Hashable, + Optional, + ContextManager, + List, + Generic, + TypeVar, + Tuple, + Any, + Dict, +) from typing_extensions import Protocol import agate @@ -8,7 +19,6 @@ from dbt.adapters.contracts.macros import MacroResolverProtocol from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig from dbt.common.contracts.config.base import BaseConfig -from dbt.contracts.graph.manifest import Manifest @dataclass @@ -80,7 +90,7 @@ def clear_macro_resolver(self) -> None: def type(cls) -> str: pass - def set_query_header(self, manifest: Manifest) -> None: + def set_query_header(self, query_header_context: Dict[str, Any]) -> None: ... @staticmethod diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index c952db063f4..5cf3a0360da 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -72,6 +72,7 @@ ) from dbt.config import Project, RuntimeConfig from dbt.context.docs import generate_runtime_docs_context +from dbt.context.manifest import generate_query_header_context from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace from dbt.context.configured import generate_macro_context from dbt.context.providers import ParseProvider @@ -236,7 +237,7 @@ def __init__( self, root_project: RuntimeConfig, all_projects: Mapping[str, Project], - macro_hook: Optional[Callable[[Manifest], Any]] = None, + macro_hook: Optional[Callable[[Dict[str, Any]], Any]] = None, file_diff: Optional[FileDiff] = None, ) -> None: self.root_project: RuntimeConfig = root_project @@ -250,9 +251,9 @@ def __init__( # This is a MacroQueryStringSetter callable, which is called # later after we set the MacroManifest in the adapter. It sets # up the query headers. - self.macro_hook: Callable[[Manifest], Any] + self.macro_hook: Callable[[Dict[str, Any]], Any] if macro_hook is None: - self.macro_hook = lambda m: None + self.macro_hook = lambda c: None else: self.macro_hook = macro_hook @@ -1001,9 +1002,9 @@ def build_manifest_state_check(self): def save_macros_to_adapter(self, adapter): macro_manifest = MacroManifest(self.manifest.macros) adapter.set_macro_resolver(macro_manifest) - # This executes the callable macro_hook and sets the - # query headers - self.macro_hook(macro_manifest) + # This executes the callable macro_hook and sets the query headers + query_header_context = generate_query_header_context(adapter.config, macro_manifest) + self.macro_hook(query_header_context) # This creates a MacroManifest which contains the macros in # the adapter. Only called by the load_macros call from the @@ -1032,7 +1033,7 @@ def create_macro_manifest(self): def load_macros( cls, root_config: RuntimeConfig, - macro_hook: Callable[[Manifest], Any], + macro_hook: Callable[[Dict[str, Any]], Any], base_macros_only=False, ) -> Manifest: with PARSING_STATE: diff --git a/tests/unit/test_postgres_adapter.py b/tests/unit/test_postgres_adapter.py index 31696f7a9cb..02c3b7a8335 100644 --- a/tests/unit/test_postgres_adapter.py +++ b/tests/unit/test_postgres_adapter.py @@ -15,6 +15,7 @@ from dbt.adapters.postgres import Plugin as PostgresPlugin from dbt.contracts.files import FileHash from dbt.contracts.graph.manifest import ManifestStateCheck +from dbt.context.manifest import generate_query_header_context from dbt.common.clients import agate_helper from dbt.exceptions import DbtConfigError from dbt.common.exceptions import DbtValidationError @@ -429,8 +430,12 @@ def _mock_state_check(self): self.psycopg2.connect.return_value = self.handle self.adapter = PostgresAdapter(self.config, self.mp_context) self.adapter.set_macro_resolver(load_internal_manifest_macros(self.config)) + + query_header_context = generate_query_header_context( + self.adapter.config, self.adapter.get_macro_resolver() + ) self.adapter.connections.query_header = MacroQueryStringSetter( - self.config, self.adapter.get_macro_resolver() + self.config, query_header_context ) self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add") diff --git a/tests/unit/test_query_headers.py b/tests/unit/test_query_headers.py index dd78fc0c838..2be9b59bd4d 100644 --- a/tests/unit/test_query_headers.py +++ b/tests/unit/test_query_headers.py @@ -2,6 +2,7 @@ from unittest import TestCase, mock from dbt.adapters.base.query_headers import MacroQueryStringSetter +from dbt.context.manifest import generate_query_header_context from tests.unit.utils import config_from_parts_or_dicts from dbt.flags import set_from_args @@ -36,14 +37,18 @@ def setUp(self): def test_comment_should_prepend_query_by_default(self): config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg) - query_header = MacroQueryStringSetter(config, mock.MagicMock(macros={})) + + query_header_context = generate_query_header_context(config, mock.MagicMock(macros={})) + query_header = MacroQueryStringSetter(config, query_header_context) sql = query_header.add(self.query) self.assertTrue(re.match(f"^\/\*.*\*\/\n{self.query}$", sql)) # noqa: [W605] def test_append_comment(self): self.project_cfg.update({"query-comment": {"comment": "executed by dbt", "append": True}}) config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg) - query_header = MacroQueryStringSetter(config, mock.MagicMock(macros={})) + + query_header_context = generate_query_header_context(config, mock.MagicMock(macros={})) + query_header = MacroQueryStringSetter(config, query_header_context) sql = query_header.add(self.query) self.assertEqual(sql, f"{self.query[:-1]}\n/* executed by dbt */;")