diff --git a/.changes/unreleased/Under the Hood-20231207-111554.yaml b/.changes/unreleased/Under the Hood-20231207-111554.yaml new file mode 100644 index 00000000000..8dec8ed18e4 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20231207-111554.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Add MacroResolverProtocol, remove lazy loading of manifest in adapter.execute_macro +time: 2023-12-07T11:15:54.427818+09:00 +custom: + Author: michelleark + Issue: "9244" diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 92977d0b4d6..b47cecc8d30 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -8,7 +8,6 @@ Any, Callable, Dict, - Iterable, Iterator, List, Mapping, @@ -19,6 +18,7 @@ TypedDict, Union, FrozenSet, + Iterable, ) from multiprocessing.context import SpawnContext @@ -28,6 +28,7 @@ ConstraintType, ModelLevelConstraint, ) +from dbt.adapters.contracts.macros import MacroResolverProtocol import agate import pytz @@ -62,7 +63,6 @@ Integer, ) from dbt.common.clients.jinja import CallableMacroGenerator -from dbt.contracts.graph.manifest import Manifest, MacroManifest from dbt.common.events.functions import fire_event, warn_or_error from dbt.adapters.events.types import ( CacheMiss, @@ -257,7 +257,20 @@ def __init__(self, config, mp_context: SpawnContext) -> None: self.config = config self.cache = RelationsCache(log_cache_events=config.log_cache_events) self.connections = self.ConnectionManager(config, mp_context) - self._macro_manifest_lazy: Optional[MacroManifest] = None + self._macro_resolver: Optional[MacroResolverProtocol] = None + + ### + # Methods to set / access a macro resolver + ### + def set_macro_resolver(self, macro_resolver: MacroResolverProtocol) -> None: + self._macro_resolver = macro_resolver + + def get_macro_resolver(self) -> Optional[MacroResolverProtocol]: + return self._macro_resolver + + def clear_macro_resolver(self) -> None: + if self._macro_resolver is not None: + self._macro_resolver = None ### # Methods that pass through to the connection manager @@ -370,39 +383,6 @@ def type(cls) -> str: """ return cls.ConnectionManager.TYPE - @property - def _macro_manifest(self) -> MacroManifest: - if self._macro_manifest_lazy is None: - return self.load_macro_manifest() - return self._macro_manifest_lazy - - def check_macro_manifest(self) -> Optional[MacroManifest]: - """Return the internal manifest (used for executing macros) if it's - been initialized, otherwise return None. - """ - return self._macro_manifest_lazy - - def load_macro_manifest(self, base_macros_only=False) -> MacroManifest: - # base_macros_only is for the test framework - if self._macro_manifest_lazy is None: - # avoid a circular import - from dbt.parser.manifest import ManifestLoader - - manifest = ManifestLoader.load_macros( - self.config, - self.connections.set_query_header, - base_macros_only=base_macros_only, - ) - # TODO CT-211 - self._macro_manifest_lazy = manifest # type: ignore[assignment] - # TODO CT-211 - return self._macro_manifest_lazy # type: ignore[return-value] - - def clear_macro_manifest(self): - if self._macro_manifest_lazy is not None: - self._macro_manifest_lazy = None - - ### # Caching methods ### def _schema_is_cached(self, database: Optional[str], schema: str) -> bool: @@ -1052,11 +1032,10 @@ def convert_agate_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[ def execute_macro( self, macro_name: str, - manifest: Optional[Manifest] = None, + macro_resolver: Optional[MacroResolverProtocol] = None, project: Optional[str] = None, context_override: Optional[Dict[str, Any]] = None, kwargs: Optional[Dict[str, Any]] = None, - text_only_columns: Optional[Iterable[str]] = None, ) -> AttrDict: """Look macro_name up in the manifest and execute its results. @@ -1076,13 +1055,11 @@ def execute_macro( if context_override is None: context_override = {} - if manifest is None: - # TODO CT-211 - manifest = self._macro_manifest # type: ignore[assignment] - # TODO CT-211 - macro = manifest.find_macro_by_name( # type: ignore[union-attr] - macro_name, self.config.project_name, project - ) + resolver = macro_resolver or self._macro_resolver + if resolver is None: + raise DbtInternalError("macro resolver was None when calling execute_macro!") + + macro = resolver.find_macro_by_name(macro_name, self.config.project_name, project) if macro is None: if project is None: package_name = "any package" @@ -1102,7 +1079,7 @@ def execute_macro( # TODO CT-211 macro=macro, config=self.config, - manifest=manifest, # type: ignore[arg-type] + manifest=resolver, # type: ignore[arg-type] package_name=project, ) macro_context.update(context_override) @@ -1135,10 +1112,7 @@ def _get_one_catalog( used_schemas: FrozenSet[Tuple[str, str]], ) -> agate.Table: kwargs = {"information_schema": information_schema, "schemas": schemas} - table = self.execute_macro( - GET_CATALOG_MACRO_NAME, - kwargs=kwargs, - ) + table = self.execute_macro(GET_CATALOG_MACRO_NAME, kwargs=kwargs) results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type] return results @@ -1154,10 +1128,7 @@ def _get_one_catalog_by_relations( "information_schema": information_schema, "relations": relations, } - table = self.execute_macro( - GET_CATALOG_RELATIONS_MACRO_NAME, - kwargs=kwargs, - ) + table = self.execute_macro(GET_CATALOG_RELATIONS_MACRO_NAME, kwargs=kwargs) results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type] return results @@ -1258,7 +1229,7 @@ def calculate_freshness( source: BaseRelation, loaded_at_field: str, filter: Optional[str], - manifest: Optional[Manifest] = None, + macro_resolver: Optional[MacroResolverProtocol] = None, ) -> Tuple[Optional[AdapterResponse], FreshnessResponse]: """Calculate the freshness of sources in dbt, and return it""" kwargs: Dict[str, Any] = { @@ -1274,7 +1245,9 @@ def calculate_freshness( AttrDict, # current: contains AdapterResponse + agate.Table agate.Table, # previous: just table ] - result = self.execute_macro(FRESHNESS_MACRO_NAME, kwargs=kwargs, manifest=manifest) + result = self.execute_macro( + FRESHNESS_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver + ) if isinstance(result, agate.Table): warn_or_error(CollectFreshnessReturnSignature()) adapter_response = None @@ -1304,14 +1277,14 @@ def calculate_freshness( def calculate_freshness_from_metadata( self, source: BaseRelation, - manifest: Optional[Manifest] = None, + macro_resolver: Optional[MacroResolverProtocol] = None, ) -> Tuple[Optional[AdapterResponse], FreshnessResponse]: kwargs: Dict[str, Any] = { "information_schema": source.information_schema_only(), "relations": [source], } result = self.execute_macro( - GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, manifest=manifest + GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver ) adapter_response, table = result.response, result.table # type: ignore[attr-defined] diff --git a/core/dbt/adapters/contracts/macros.py b/core/dbt/adapters/contracts/macros.py new file mode 100644 index 00000000000..151c9c44dde --- /dev/null +++ b/core/dbt/adapters/contracts/macros.py @@ -0,0 +1,11 @@ +from typing import Optional +from typing_extensions import Protocol + +from dbt.common.clients.jinja import MacroProtocol + + +class MacroResolverProtocol(Protocol): + def find_macro_by_name( + self, name: str, root_project_name: str, package: Optional[str] + ) -> Optional[MacroProtocol]: + raise NotImplementedError("find_macro_by_name not implemented") diff --git a/core/dbt/adapters/protocol.py b/core/dbt/adapters/protocol.py index 3555bea5c9e..67a2bc9b998 100644 --- a/core/dbt/adapters/protocol.py +++ b/core/dbt/adapters/protocol.py @@ -5,6 +5,7 @@ import agate from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse +from dbt.adapters.contracts.macros import MacroResolverProtocol from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig from dbt.common.contracts.config.base import BaseConfig from dbt.contracts.graph.manifest import Manifest @@ -66,6 +67,15 @@ class AdapterProtocol( # type: ignore[misc] def __init__(self, config: AdapterRequiredConfig) -> None: ... + def set_macro_resolver(self, macro_resolver: MacroResolverProtocol) -> None: + ... + + def get_macro_resolver(self) -> Optional[MacroResolverProtocol]: + ... + + def clear_macro_resolver(self) -> None: + ... + @classmethod def type(cls) -> str: pass diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index c390d4101d8..1f99665af7e 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -15,6 +15,7 @@ from typing_extensions import Protocol from dbt.adapters.base.column import Column +from dbt.common.clients.jinja import MacroProtocol from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names from dbt.common.clients import agate_helper from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack @@ -1363,7 +1364,7 @@ class MacroContext(ProviderContext): def __init__( self, - model: Macro, + model: MacroProtocol, config: RuntimeConfig, manifest: Manifest, provider: Provider, @@ -1520,7 +1521,7 @@ def generate_runtime_model_context( def generate_runtime_macro_context( - macro: Macro, + macro: MacroProtocol, config: RuntimeConfig, manifest: Manifest, package_name: Optional[str], diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index e973b5f3592..c952db063f4 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -286,7 +286,7 @@ def get_full_manifest( # the config and adapter may be persistent. if reset: config.clear_dependencies() - adapter.clear_macro_manifest() + adapter.clear_macro_resolver() macro_hook = adapter.connections.set_query_header flags = get_flags() @@ -1000,7 +1000,7 @@ def build_manifest_state_check(self): def save_macros_to_adapter(self, adapter): macro_manifest = MacroManifest(self.manifest.macros) - adapter._macro_manifest_lazy = macro_manifest + adapter.set_macro_resolver(macro_manifest) # This executes the callable macro_hook and sets the # query headers self.macro_hook(macro_manifest) diff --git a/core/dbt/task/freshness.py b/core/dbt/task/freshness.py index 3f76d751a91..8cb15756973 100644 --- a/core/dbt/task/freshness.py +++ b/core/dbt/task/freshness.py @@ -111,7 +111,7 @@ def execute(self, compiled_node, manifest): relation, compiled_node.loaded_at_field, compiled_node.freshness.filter, - manifest=manifest, + macro_resolver=manifest, ) status = compiled_node.freshness.status(freshness["age"]) @@ -126,7 +126,7 @@ def execute(self, compiled_node, manifest): adapter_response, freshness = self.adapter.calculate_freshness_from_metadata( relation, - manifest=manifest, + macro_resolver=manifest, ) status = compiled_node.freshness.status(freshness["age"]) diff --git a/core/dbt/task/run_operation.py b/core/dbt/task/run_operation.py index caa1f1c7b7e..379d5ec6ab8 100644 --- a/core/dbt/task/run_operation.py +++ b/core/dbt/task/run_operation.py @@ -41,7 +41,7 @@ def _run_unsafe(self, package_name, macro_name) -> agate.Table: with adapter.connection_named("macro_{}".format(macro_name)): adapter.clear_transaction() res = adapter.execute_macro( - macro_name, project=package_name, kwargs=macro_kwargs, manifest=self.manifest + macro_name, project=package_name, kwargs=macro_kwargs, macro_resolver=self.manifest ) return res diff --git a/core/dbt/task/show.py b/core/dbt/task/show.py index d6d140898a9..961a36c6127 100644 --- a/core/dbt/task/show.py +++ b/core/dbt/task/show.py @@ -27,7 +27,7 @@ def execute(self, compiled_node, manifest): model_context = generate_runtime_model_context(compiled_node, self.config, manifest) compiled_node.compiled_code = self.adapter.execute_macro( macro_name="get_show_sql", - manifest=manifest, + macro_resolver=manifest, context_override=model_context, kwargs={ "compiled_code": model_context["compiled_code"], diff --git a/core/dbt/tests/fixtures/project.py b/core/dbt/tests/fixtures/project.py index 429207e907d..487450dcb45 100644 --- a/core/dbt/tests/fixtures/project.py +++ b/core/dbt/tests/fixtures/project.py @@ -7,6 +7,7 @@ import warnings import yaml +from dbt.parser.manifest import ManifestLoader from dbt.common.exceptions import CompilationError, DbtDatabaseError import dbt.flags as flags from dbt.config.runtime import RuntimeConfig @@ -289,7 +290,13 @@ def adapter( adapter = get_adapter(runtime_config) # We only need the base macros, not macros from dependencies, and don't want # to run 'dbt deps' here. - adapter.load_macro_manifest(base_macros_only=True) + manifest = ManifestLoader.load_macros( + runtime_config, + adapter.connections.set_query_header, + base_macros_only=True, + ) + + adapter.set_macro_resolver(manifest) yield adapter adapter.cleanup_connections() reset_adapters() @@ -450,6 +457,14 @@ def create_test_schema(self, schema_name=None): # Drop the unique test schema, usually called in test cleanup def drop_test_schema(self): + if self.adapter.get_macro_resolver() is None: + manifest = ManifestLoader.load_macros( + self.adapter.config, + self.adapter.connections.set_query_header, + base_macros_only=True, + ) + self.adapter.set_macro_resolver(manifest) + with get_connection(self.adapter): for schema_name in self.created_schemas: relation = self.adapter.Relation.create(database=self.database, schema=schema_name) diff --git a/tests/unit/test_postgres_adapter.py b/tests/unit/test_postgres_adapter.py index 30ccc2a2104..31696f7a9cb 100644 --- a/tests/unit/test_postgres_adapter.py +++ b/tests/unit/test_postgres_adapter.py @@ -428,9 +428,9 @@ def _mock_state_check(self): self.psycopg2.connect.return_value = self.handle self.adapter = PostgresAdapter(self.config, self.mp_context) - self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config) + self.adapter.set_macro_resolver(load_internal_manifest_macros(self.config)) self.adapter.connections.query_header = MacroQueryStringSetter( - self.config, self.adapter._macro_manifest_lazy + self.config, self.adapter.get_macro_resolver() ) self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add")