Skip to content

Commit

Permalink
Remove manifest from catalog and connection method signatures (#9242)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored Dec 7, 2023
1 parent f1c2f06 commit 26ddaaf
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 57 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20231205-235830.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: remove manifest from adapter.set_relations_cache signature
time: 2023-12-05T23:58:30.920144+09:00
custom:
Author: michelleark
Issue: "9217"
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20231206-000343.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: remove manifest from adapter catalog method signatures
time: 2023-12-06T00:03:43.824252+09:00
custom:
Author: michelleark
Issue: "9218"
95 changes: 48 additions & 47 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from datetime import datetime
from enum import Enum
import time
from itertools import chain
from typing import (
Any,
Callable,
Expand All @@ -19,6 +18,7 @@
Type,
TypedDict,
Union,
FrozenSet,
)
from multiprocessing.context import SpawnContext

Expand Down Expand Up @@ -75,6 +75,7 @@
)
from dbt.common.utils import filter_null_values, executor, cast_to_str, AttrDict

from dbt.adapters.contracts.relation import RelationConfig
from dbt.adapters.base.connections import Connection, AdapterResponse, BaseConnectionManager
from dbt.adapters.base.meta import AdapterMeta, available
from dbt.adapters.base.relation import (
Expand Down Expand Up @@ -109,11 +110,13 @@ def _expect_row_value(key: str, row: agate.Row):
return row[key]


def _catalog_filter_schemas(manifest: Manifest) -> Callable[[agate.Row], bool]:
def _catalog_filter_schemas(
used_schemas: FrozenSet[Tuple[str, str]]
) -> Callable[[agate.Row], bool]:
"""Return a function that takes a row and decides if the row should be
included in the catalog output.
"""
schemas = frozenset((d.lower(), s.lower()) for d, s in manifest.get_used_schemas())
schemas = frozenset((d.lower(), s.lower()) for d, s in used_schemas)

def test(row: agate.Row) -> bool:
table_database = _expect_row_value("table_database", row)
Expand Down Expand Up @@ -417,18 +420,16 @@ def _schema_is_cached(self, database: Optional[str], schema: str) -> bool:
else:
return True

def _get_cache_schemas(self, manifest: Manifest) -> Set[BaseRelation]:
def _get_cache_schemas(self, relation_configs: Iterable[RelationConfig]) -> Set[BaseRelation]:
"""Get the set of schema relations that the cache logic needs to
populate. This means only executable nodes are included.
populate.
"""
# the cache only cares about executable nodes
return {
self.Relation.create_from(self.config, node).without_identifier() # type: ignore[arg-type]
for node in manifest.nodes.values()
if (node.is_relational and not node.is_ephemeral_model and not node.is_external_node)
self.Relation.create_from(quoting=self.config, relation_config=relation_config)
for relation_config in relation_configs
}

def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
def _get_catalog_schemas(self, relation_configs: Iterable[RelationConfig]) -> SchemaSearchMap:
"""Get a mapping of each node's "information_schema" relations to a
set of all schemas expected in that information_schema.
Expand All @@ -438,7 +439,7 @@ def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
lowercase strings.
"""
info_schema_name_map = SchemaSearchMap()
relations = self._get_catalog_relations(manifest)
relations = self._get_catalog_relations(relation_configs)
for relation in relations:
info_schema_name_map.add(relation)
# result is a map whose keys are information_schema Relations without
Expand All @@ -459,28 +460,25 @@ def _get_catalog_relations_by_info_schema(

return relations_by_info_schema

def _get_catalog_relations(self, manifest: Manifest) -> List[BaseRelation]:

nodes = chain(
[
node
for node in manifest.nodes.values()
if (node.is_relational and not node.is_ephemeral_model)
],
manifest.sources.values(),
)

relations = [self.Relation.create_from(self.config, n) for n in nodes] # type: ignore[arg-type]
def _get_catalog_relations(
self, relation_configs: Iterable[RelationConfig]
) -> List[BaseRelation]:
relations = [
self.Relation.create_from(quoting=self.config, relation_config=relation_config)
for relation_config in relation_configs
]
return relations

def _relations_cache_for_schemas(
self, manifest: Manifest, cache_schemas: Optional[Set[BaseRelation]] = None
self,
relation_configs: Iterable[RelationConfig],
cache_schemas: Optional[Set[BaseRelation]] = None,
) -> None:
"""Populate the relations cache for the given schemas. This method
returns None.
"""
if not cache_schemas:
cache_schemas = self._get_cache_schemas(manifest)
cache_schemas = self._get_cache_schemas(relation_configs)
with executor(self.config) as tpe:
futures: List[Future[List[BaseRelation]]] = []
for cache_schema in cache_schemas:
Expand Down Expand Up @@ -509,7 +507,7 @@ def _relations_cache_for_schemas(

def set_relations_cache(
self,
manifest: Manifest,
relation_configs: Iterable[RelationConfig],
clear: bool = False,
required_schemas: Optional[Set[BaseRelation]] = None,
) -> None:
Expand All @@ -519,7 +517,7 @@ def set_relations_cache(
with self.cache.lock:
if clear:
self.cache.clear()
self._relations_cache_for_schemas(manifest, required_schemas)
self._relations_cache_for_schemas(relation_configs, required_schemas)

@available
def cache_added(self, relation: Optional[BaseRelation]) -> str:
Expand Down Expand Up @@ -1116,7 +1114,9 @@ def execute_macro(
return result

@classmethod
def _catalog_filter_table(cls, table: agate.Table, manifest: Manifest) -> agate.Table:
def _catalog_filter_table(
cls, table: agate.Table, used_schemas: FrozenSet[Tuple[str, str]]
) -> agate.Table:
"""Filter the table as appropriate for catalog entries. Subclasses can
override this to change filtering rules on a per-adapter basis.
"""
Expand All @@ -1126,31 +1126,28 @@ def _catalog_filter_table(cls, table: agate.Table, manifest: Manifest) -> agate.
table.column_names,
text_only_columns=["table_database", "table_schema", "table_name"],
)
return table.where(_catalog_filter_schemas(manifest))
return table.where(_catalog_filter_schemas(used_schemas))

def _get_one_catalog(
self,
information_schema: InformationSchema,
schemas: Set[str],
manifest: Manifest,
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:
kwargs = {"information_schema": information_schema, "schemas": schemas}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
# pass in the full manifest so we get any local project
# overrides
manifest=manifest,
)

results = self._catalog_filter_table(table, manifest) # type: ignore[arg-type]
results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type]
return results

def _get_one_catalog_by_relations(
self,
information_schema: InformationSchema,
relations: List[BaseRelation],
manifest: Manifest,
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:

kwargs = {
Expand All @@ -1160,16 +1157,16 @@ def _get_one_catalog_by_relations(
table = self.execute_macro(
GET_CATALOG_RELATIONS_MACRO_NAME,
kwargs=kwargs,
# pass in the full manifest, so we get any local project
# overrides
manifest=manifest,
)

results = self._catalog_filter_table(table, manifest) # type: ignore[arg-type]
results = self._catalog_filter_table(table, used_schemas) # type: ignore[arg-type]
return results

def get_filtered_catalog(
self, manifest: Manifest, relations: Optional[Set[BaseRelation]] = None
self,
relation_configs: Iterable[RelationConfig],
used_schemas: FrozenSet[Tuple[str, str]],
relations: Optional[Set[BaseRelation]] = None,
):
catalogs: agate.Table
if (
Expand All @@ -1178,11 +1175,11 @@ def get_filtered_catalog(
or not self.supports(Capability.SchemaMetadataByRelations)
):
# Do it the traditional way. We get the full catalog.
catalogs, exceptions = self.get_catalog(manifest)
catalogs, exceptions = self.get_catalog(relation_configs, used_schemas)
else:
# Do it the new way. We try to save time by selecting information
# only for the exact set of relations we are interested in.
catalogs, exceptions = self.get_catalog_by_relations(manifest, relations)
catalogs, exceptions = self.get_catalog_by_relations(used_schemas, relations)

if relations and catalogs:
relation_map = {
Expand Down Expand Up @@ -1210,24 +1207,28 @@ def in_map(row: agate.Row):
def row_matches_relation(self, row: agate.Row, relations: Set[BaseRelation]):
pass

def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]:
def get_catalog(
self,
relation_configs: Iterable[RelationConfig],
used_schemas: FrozenSet[Tuple[str, str]],
) -> Tuple[agate.Table, List[Exception]]:
with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = []
schema_map: SchemaSearchMap = self._get_catalog_schemas(manifest)
schema_map: SchemaSearchMap = self._get_catalog_schemas(relation_configs)
for info, schemas in schema_map.items():
if len(schemas) == 0:
continue
name = ".".join([str(info.database), "information_schema"])
fut = tpe.submit_connected(
self, name, self._get_one_catalog, info, schemas, manifest
self, name, self._get_one_catalog, info, schemas, used_schemas
)
futures.append(fut)

catalogs, exceptions = catch_as_completed(futures)
return catalogs, exceptions

def get_catalog_by_relations(
self, manifest: Manifest, relations: Set[BaseRelation]
self, used_schemas: FrozenSet[Tuple[str, str]], relations: Set[BaseRelation]
) -> Tuple[agate.Table, List[Exception]]:
with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = []
Expand All @@ -1241,7 +1242,7 @@ def get_catalog_by_relations(
self._get_one_catalog_by_relations,
info_schema,
relations,
manifest,
used_schemas,
)
futures.append(fut)

Expand Down
13 changes: 12 additions & 1 deletion core/dbt/task/docs/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple, Set, Iterable
import agate
from itertools import chain

import dbt.common.utils.formatting
from dbt.common.dataclass_schema import ValidationError
Expand Down Expand Up @@ -261,7 +262,17 @@ def run(self) -> CatalogArtifact:
}

# This generates the catalog as an agate.Table
catalog_table, exceptions = adapter.get_filtered_catalog(self.manifest, relations)
catalogable_nodes = chain(
[
node
for node in self.manifest.nodes.values()
if (node.is_relational and not node.is_ephemeral_model)
]
)
used_schemas = self.manifest.get_used_schemas()
catalog_table, exceptions = adapter.get_filtered_catalog(
catalogable_nodes, used_schemas, relations
)

catalog_data: List[PrimitiveDict] = [
dict(zip(catalog_table.column_names, map(dbt.utils._coerce_decimal, row)))
Expand Down
14 changes: 12 additions & 2 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,21 @@ def populate_adapter_cache(
if not self.args.populate_cache:
return

if self.manifest is None:
raise DbtInternalError("manifest was None in populate_adapter_cache")

start_populate_cache = time.perf_counter()
# the cache only cares about executable nodes
cachable_nodes = [
node
for node in self.manifest.nodes.values()
if (node.is_relational and not node.is_ephemeral_model and not node.is_external_node)
]

if get_flags().CACHE_SELECTED_ONLY is True:
adapter.set_relations_cache(self.manifest, required_schemas=required_schemas)
adapter.set_relations_cache(cachable_nodes, required_schemas=required_schemas)
else:
adapter.set_relations_cache(self.manifest)
adapter.set_relations_cache(cachable_nodes)
cache_populate_time = time.perf_counter() - start_populate_cache
if dbt.tracking.active_user is not None:
dbt.tracking.track_runnable_timing(
Expand Down
13 changes: 6 additions & 7 deletions tests/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,15 +352,15 @@ def catalog_test(self, mock_get_relations, mock_execute, filtered=False):

mock_get_relations.return_value = relations

mock_manifest = mock.MagicMock()
mock_manifest.get_used_schemas.return_value = {("dbt", "foo"), ("dbt", "quux")}
relation_configs = []
used_schemas = {("dbt", "foo"), ("dbt", "quux")}

if filtered:
catalog, exceptions = self.adapter.get_filtered_catalog(
mock_manifest, set([relations[0], relations[3]])
relation_configs, used_schemas, set([relations[0], relations[3]])
)
else:
catalog, exceptions = self.adapter.get_catalog(mock_manifest)
catalog, exceptions = self.adapter.get_catalog(relation_configs, used_schemas)

tupled_catalog = set(map(tuple, catalog))
if filtered:
Expand Down Expand Up @@ -560,8 +560,7 @@ def test_dbname_verification_is_case_insensitive(self):

class TestPostgresFilterCatalog(unittest.TestCase):
def test__catalog_filter_table(self):
manifest = mock.MagicMock()
manifest.get_used_schemas.return_value = [["a", "B"], ["a", "1234"]]
used_schemas = [["a", "B"], ["a", "1234"]]
column_names = ["table_name", "table_database", "table_schema", "something"]
rows = [
["foo", "a", "b", "1234"], # include
Expand All @@ -571,7 +570,7 @@ def test__catalog_filter_table(self):
]
table = agate.Table(rows, column_names, agate_helper.DEFAULT_TYPE_TESTER)

result = PostgresAdapter._catalog_filter_table(table, manifest)
result = PostgresAdapter._catalog_filter_table(table, used_schemas)
assert len(result) == 3
for row in result.rows:
assert isinstance(row["table_schema"], str)
Expand Down

0 comments on commit 26ddaaf

Please sign in to comment.