Skip to content

Commit

Permalink
Add MacroResolverProtocol, remove lazy loading of manifest in adapter…
Browse files Browse the repository at this point in the history
….execute_macro (#9243)

* remove manifest from adapter.execute_macro, replace with MacroResolver + remove lazy loading

* rename to MacroResolverProtocol

* pass MacroResolverProtocol in adapter.calculate_freshness_from_metadata

* changelog entry

* fix adapter.calculate_freshness call
  • Loading branch information
MichelleArk authored Dec 7, 2023
1 parent 26ddaaf commit 4d16524
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 69 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20231207-111554.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add MacroResolverProtocol, remove lazy loading of manifest in adapter.execute_macro
time: 2023-12-07T11:15:54.427818+09:00
custom:
Author: michelleark
Issue: "9244"
89 changes: 31 additions & 58 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Mapping,
Expand All @@ -19,6 +18,7 @@
TypedDict,
Union,
FrozenSet,
Iterable,
)
from multiprocessing.context import SpawnContext

Expand All @@ -28,6 +28,7 @@
ConstraintType,
ModelLevelConstraint,
)
from dbt.adapters.contracts.macros import MacroResolverProtocol

import agate
import pytz
Expand Down Expand Up @@ -62,7 +63,6 @@
Integer,
)
from dbt.common.clients.jinja import CallableMacroGenerator
from dbt.contracts.graph.manifest import Manifest, MacroManifest
from dbt.common.events.functions import fire_event, warn_or_error
from dbt.adapters.events.types import (
CacheMiss,
Expand Down Expand Up @@ -257,7 +257,20 @@ def __init__(self, config, mp_context: SpawnContext) -> None:
self.config = config
self.cache = RelationsCache(log_cache_events=config.log_cache_events)
self.connections = self.ConnectionManager(config, mp_context)
self._macro_manifest_lazy: Optional[MacroManifest] = None
self._macro_resolver: Optional[MacroResolverProtocol] = None

###
# Methods to set / access a macro resolver
###
def set_macro_resolver(self, macro_resolver: MacroResolverProtocol) -> None:
    """Store ``macro_resolver`` on this adapter, replacing any resolver held before."""
    self._macro_resolver = macro_resolver

def get_macro_resolver(self) -> Optional[MacroResolverProtocol]:
    """Return the macro resolver currently held by this adapter (``None`` if none is set)."""
    current = self._macro_resolver
    return current

def clear_macro_resolver(self) -> None:
    """Drop the stored macro resolver so the adapter no longer holds one.

    The slot is reset unconditionally: assigning ``None`` over an existing
    ``None`` is a no-op, so no "is it set?" check is needed first.
    """
    self._macro_resolver = None

Check warning on line 273 in core/dbt/adapters/base/impl.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/base/impl.py#L272-L273

Added lines #L272 - L273 were not covered by tests

###
# Methods that pass through to the connection manager
Expand Down Expand Up @@ -370,39 +383,6 @@ def type(cls) -> str:
"""
return cls.ConnectionManager.TYPE

@property
def _macro_manifest(self) -> MacroManifest:
if self._macro_manifest_lazy is None:
return self.load_macro_manifest()
return self._macro_manifest_lazy

def check_macro_manifest(self) -> Optional[MacroManifest]:
"""Return the internal manifest (used for executing macros) if it's
been initialized, otherwise return None.
"""
return self._macro_manifest_lazy

def load_macro_manifest(self, base_macros_only=False) -> MacroManifest:
# base_macros_only is for the test framework
if self._macro_manifest_lazy is None:
# avoid a circular import
from dbt.parser.manifest import ManifestLoader

manifest = ManifestLoader.load_macros(
self.config,
self.connections.set_query_header,
base_macros_only=base_macros_only,
)
# TODO CT-211
self._macro_manifest_lazy = manifest # type: ignore[assignment]
# TODO CT-211
return self._macro_manifest_lazy # type: ignore[return-value]

def clear_macro_manifest(self):
if self._macro_manifest_lazy is not None:
self._macro_manifest_lazy = None

###
# Caching methods
###
def _schema_is_cached(self, database: Optional[str], schema: str) -> bool:
Expand Down Expand Up @@ -1052,11 +1032,10 @@ def convert_agate_type(cls, agate_table: agate.Table, col_idx: int) -> Optional[
def execute_macro(
self,
macro_name: str,
manifest: Optional[Manifest] = None,
macro_resolver: Optional[MacroResolverProtocol] = None,
project: Optional[str] = None,
context_override: Optional[Dict[str, Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
text_only_columns: Optional[Iterable[str]] = None,
) -> AttrDict:
"""Look macro_name up in the manifest and execute its results.
Expand All @@ -1076,13 +1055,11 @@ def execute_macro(
if context_override is None:
context_override = {}

if manifest is None:
# TODO CT-211
manifest = self._macro_manifest # type: ignore[assignment]
# TODO CT-211
macro = manifest.find_macro_by_name( # type: ignore[union-attr]
macro_name, self.config.project_name, project
)
resolver = macro_resolver or self._macro_resolver
if resolver is None:
raise DbtInternalError("macro resolver was None when calling execute_macro!")

Check warning on line 1060 in core/dbt/adapters/base/impl.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/base/impl.py#L1060

Added line #L1060 was not covered by tests

macro = resolver.find_macro_by_name(macro_name, self.config.project_name, project)
if macro is None:
if project is None:
package_name = "any package"
Expand All @@ -1102,7 +1079,7 @@ def execute_macro(
# TODO CT-211
macro=macro,
config=self.config,
manifest=manifest, # type: ignore[arg-type]
manifest=resolver, # type: ignore[arg-type]
package_name=project,
)
macro_context.update(context_override)
Expand Down Expand Up @@ -1135,10 +1112,7 @@ def _get_one_catalog(
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:
kwargs = {"information_schema": information_schema, "schemas": schemas}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
)
table = self.execute_macro(GET_CATALOG_MACRO_NAME, kwargs=kwargs)

results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type]
return results
Expand All @@ -1154,10 +1128,7 @@ def _get_one_catalog_by_relations(
"information_schema": information_schema,
"relations": relations,
}
table = self.execute_macro(
GET_CATALOG_RELATIONS_MACRO_NAME,
kwargs=kwargs,
)
table = self.execute_macro(GET_CATALOG_RELATIONS_MACRO_NAME, kwargs=kwargs)

results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type]
return results
Expand Down Expand Up @@ -1258,7 +1229,7 @@ def calculate_freshness(
source: BaseRelation,
loaded_at_field: str,
filter: Optional[str],
manifest: Optional[Manifest] = None,
macro_resolver: Optional[MacroResolverProtocol] = None,
) -> Tuple[Optional[AdapterResponse], FreshnessResponse]:
"""Calculate the freshness of sources in dbt, and return it"""
kwargs: Dict[str, Any] = {
Expand All @@ -1274,7 +1245,9 @@ def calculate_freshness(
AttrDict, # current: contains AdapterResponse + agate.Table
agate.Table, # previous: just table
]
result = self.execute_macro(FRESHNESS_MACRO_NAME, kwargs=kwargs, manifest=manifest)
result = self.execute_macro(
FRESHNESS_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver
)
if isinstance(result, agate.Table):
warn_or_error(CollectFreshnessReturnSignature())
adapter_response = None
Expand Down Expand Up @@ -1304,14 +1277,14 @@ def calculate_freshness(
def calculate_freshness_from_metadata(
self,
source: BaseRelation,
manifest: Optional[Manifest] = None,
macro_resolver: Optional[MacroResolverProtocol] = None,
) -> Tuple[Optional[AdapterResponse], FreshnessResponse]:
kwargs: Dict[str, Any] = {
"information_schema": source.information_schema_only(),
"relations": [source],
}
result = self.execute_macro(
GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, manifest=manifest
GET_RELATION_LAST_MODIFIED_MACRO_NAME, kwargs=kwargs, macro_resolver=macro_resolver
)
adapter_response, table = result.response, result.table # type: ignore[attr-defined]

Expand Down
11 changes: 11 additions & 0 deletions core/dbt/adapters/contracts/macros.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Optional
from typing_extensions import Protocol

from dbt.common.clients.jinja import MacroProtocol


class MacroResolverProtocol(Protocol):
    """Structural interface for any object that can look up a macro by name."""

    def find_macro_by_name(
        self,
        name: str,
        root_project_name: str,
        package: Optional[str],
    ) -> Optional[MacroProtocol]:
        """Look up the macro called ``name``.

        Implementations return the matching macro, or ``None`` when there is
        no match. ``root_project_name`` and ``package`` presumably scope the
        search to a project/package -- confirm against the implementing class.

        Raises:
            NotImplementedError: always, in this default body; implementers
                are expected to supply their own lookup.
        """
        raise NotImplementedError("find_macro_by_name not implemented")

Check warning on line 11 in core/dbt/adapters/contracts/macros.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/contracts/macros.py#L11

Added line #L11 was not covered by tests
10 changes: 10 additions & 0 deletions core/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import agate

from dbt.adapters.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse
from dbt.adapters.contracts.macros import MacroResolverProtocol
from dbt.adapters.contracts.relation import Policy, HasQuoting, RelationConfig
from dbt.common.contracts.config.base import BaseConfig
from dbt.contracts.graph.manifest import Manifest
Expand Down Expand Up @@ -66,6 +67,15 @@ class AdapterProtocol( # type: ignore[misc]
def __init__(self, config: AdapterRequiredConfig) -> None:
...

def set_macro_resolver(self, macro_resolver: MacroResolverProtocol) -> None:
    """Interface stub: an adapter accepts a macro resolver to hold."""
    ...

Check warning on line 71 in core/dbt/adapters/protocol.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/protocol.py#L71

Added line #L71 was not covered by tests

def get_macro_resolver(self) -> Optional[MacroResolverProtocol]:
    """Interface stub: an adapter exposes its macro resolver, or ``None``."""
    ...

Check warning on line 74 in core/dbt/adapters/protocol.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/protocol.py#L74

Added line #L74 was not covered by tests

def clear_macro_resolver(self) -> None:
    """Interface stub: an adapter can discard the macro resolver it holds."""
    ...

Check warning on line 77 in core/dbt/adapters/protocol.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/adapters/protocol.py#L77

Added line #L77 was not covered by tests

@classmethod
def type(cls) -> str:
pass
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing_extensions import Protocol

from dbt.adapters.base.column import Column
from dbt.common.clients.jinja import MacroProtocol
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt.common.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
Expand Down Expand Up @@ -1363,7 +1364,7 @@ class MacroContext(ProviderContext):

def __init__(
self,
model: Macro,
model: MacroProtocol,
config: RuntimeConfig,
manifest: Manifest,
provider: Provider,
Expand Down Expand Up @@ -1520,7 +1521,7 @@ def generate_runtime_model_context(


def generate_runtime_macro_context(
macro: Macro,
macro: MacroProtocol,
config: RuntimeConfig,
manifest: Manifest,
package_name: Optional[str],
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def get_full_manifest(
# the config and adapter may be persistent.
if reset:
config.clear_dependencies()
adapter.clear_macro_manifest()
adapter.clear_macro_resolver()

Check warning on line 289 in core/dbt/parser/manifest.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/manifest.py#L289

Added line #L289 was not covered by tests
macro_hook = adapter.connections.set_query_header

flags = get_flags()
Expand Down Expand Up @@ -1000,7 +1000,7 @@ def build_manifest_state_check(self):

def save_macros_to_adapter(self, adapter):
macro_manifest = MacroManifest(self.manifest.macros)
adapter._macro_manifest_lazy = macro_manifest
adapter.set_macro_resolver(macro_manifest)
# This executes the callable macro_hook and sets the
# query headers
self.macro_hook(macro_manifest)
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/task/freshness.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def execute(self, compiled_node, manifest):
relation,
compiled_node.loaded_at_field,
compiled_node.freshness.filter,
manifest=manifest,
macro_resolver=manifest,
)

status = compiled_node.freshness.status(freshness["age"])
Expand All @@ -126,7 +126,7 @@ def execute(self, compiled_node, manifest):

adapter_response, freshness = self.adapter.calculate_freshness_from_metadata(
relation,
manifest=manifest,
macro_resolver=manifest,
)

status = compiled_node.freshness.status(freshness["age"])
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/run_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _run_unsafe(self, package_name, macro_name) -> agate.Table:
with adapter.connection_named("macro_{}".format(macro_name)):
adapter.clear_transaction()
res = adapter.execute_macro(
macro_name, project=package_name, kwargs=macro_kwargs, manifest=self.manifest
macro_name, project=package_name, kwargs=macro_kwargs, macro_resolver=self.manifest
)

return res
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def execute(self, compiled_node, manifest):
model_context = generate_runtime_model_context(compiled_node, self.config, manifest)
compiled_node.compiled_code = self.adapter.execute_macro(
macro_name="get_show_sql",
manifest=manifest,
macro_resolver=manifest,
context_override=model_context,
kwargs={
"compiled_code": model_context["compiled_code"],
Expand Down
17 changes: 16 additions & 1 deletion core/dbt/tests/fixtures/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
import yaml

from dbt.parser.manifest import ManifestLoader
from dbt.common.exceptions import CompilationError, DbtDatabaseError
import dbt.flags as flags
from dbt.config.runtime import RuntimeConfig
Expand Down Expand Up @@ -289,7 +290,13 @@ def adapter(
adapter = get_adapter(runtime_config)
# We only need the base macros, not macros from dependencies, and don't want
# to run 'dbt deps' here.
adapter.load_macro_manifest(base_macros_only=True)
manifest = ManifestLoader.load_macros(
runtime_config,
adapter.connections.set_query_header,
base_macros_only=True,
)

adapter.set_macro_resolver(manifest)
yield adapter
adapter.cleanup_connections()
reset_adapters()
Expand Down Expand Up @@ -450,6 +457,14 @@ def create_test_schema(self, schema_name=None):

# Drop the unique test schema, usually called in test cleanup
def drop_test_schema(self):
if self.adapter.get_macro_resolver() is None:
manifest = ManifestLoader.load_macros(
self.adapter.config,
self.adapter.connections.set_query_header,
base_macros_only=True,
)
self.adapter.set_macro_resolver(manifest)

with get_connection(self.adapter):
for schema_name in self.created_schemas:
relation = self.adapter.Relation.create(database=self.database, schema=schema_name)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ def _mock_state_check(self):

self.psycopg2.connect.return_value = self.handle
self.adapter = PostgresAdapter(self.config, self.mp_context)
self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config)
self.adapter.set_macro_resolver(load_internal_manifest_macros(self.config))
self.adapter.connections.query_header = MacroQueryStringSetter(
self.config, self.adapter._macro_manifest_lazy
self.config, self.adapter.get_macro_resolver()
)

self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add")
Expand Down

0 comments on commit 4d16524

Please sign in to comment.