From 0daea9fb5a5210d52fecf11fb830eaa34fbd9cb2 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Wed, 12 Apr 2023 11:45:32 -0400 Subject: [PATCH] Create publication.py, various Publication classes, Dependency class --- core/dbt/compilation.py | 195 +++++++++--------- core/dbt/contracts/graph/manifest.py | 3 + core/dbt/contracts/publication.py | 46 +++++ core/dbt/parser/manifest.py | 93 ++++++++- tests/functional/docs/test_generate.py | 17 +- .../test_all_experimental_parser.py | 13 +- tests/functional/groups/test_access.py | 18 -- tests/functional/groups/test_publication.py | 78 +++++++ 8 files changed, 343 insertions(+), 120 deletions(-) create mode 100644 core/dbt/contracts/publication.py create mode 100644 tests/functional/groups/test_publication.py diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index e5503af462e..3ee706e249f 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -161,6 +161,94 @@ def write_graph(self, outfile: str, manifest: Manifest): with open(outfile, "wb") as outfh: pickle.dump(out_graph, outfh, protocol=pickle.HIGHEST_PROTOCOL) + def link_node(self, node: GraphMemberNode, manifest: Manifest): + self.add_node(node.unique_id) + + for dependency in node.depends_on_nodes: + if dependency in manifest.nodes: + self.dependency(node.unique_id, (manifest.nodes[dependency].unique_id)) + elif dependency in manifest.sources: + self.dependency(node.unique_id, (manifest.sources[dependency].unique_id)) + elif dependency in manifest.metrics: + self.dependency(node.unique_id, (manifest.metrics[dependency].unique_id)) + else: + raise GraphDependencyNotFoundError(node, dependency) + + def link_graph(self, manifest: Manifest, add_test_edges: bool = False): + for source in manifest.sources.values(): + self.add_node(source.unique_id) + for node in manifest.nodes.values(): + self.link_node(node, manifest) + for exposure in manifest.exposures.values(): + self.link_node(exposure, manifest) + for metric in manifest.metrics.values(): + self.link_node(metric, manifest) + + cycle = self.find_cycles() + + if cycle: + raise RuntimeError("Found a cycle: {}".format(cycle)) + + if add_test_edges: + manifest.build_parent_and_child_maps() + self.add_test_edges(manifest) + + def add_test_edges(self, manifest: Manifest) -> None: + """This method adds additional edges to the DAG. For a given non-test + executable node, add an edge from an upstream test to the given node if + the set of nodes the test depends on is a subset of the upstream nodes + for the given node.""" + + # Given a graph: + # model1 --> model2 --> model3 + # | | + # | \/ + # \/ test 2 + # test1 + # + # Produce the following graph: + # model1 --> model2 --> model3 + # | /\ | /\ /\ + # | | \/ | | + # \/ | test2 ----| | + # test1 ----|---------------| + + for node_id in self.graph: + # If node is executable (in manifest.nodes) and does _not_ + # represent a test, continue. + if ( + node_id in manifest.nodes + and manifest.nodes[node_id].resource_type != NodeType.Test + ): + # Get *everything* upstream of the node + all_upstream_nodes = nx.traversal.bfs_tree(self.graph, node_id, reverse=True) + # Get the set of upstream nodes not including the current node. + upstream_nodes = set([n for n in all_upstream_nodes if n != node_id]) + + # Get all tests that depend on any upstream nodes. + upstream_tests = [] + for upstream_node in upstream_nodes: + upstream_tests += _get_tests_for_node(manifest, upstream_node) + + for upstream_test in upstream_tests: + # Get the set of all nodes that the test depends on + # including the upstream_node itself. This is necessary + # because tests can depend on multiple nodes (ex: + # relationship tests). Test nodes do not distinguish + # between what node the test is "testing" and what + # node(s) it depends on. + test_depends_on = set(manifest.nodes[upstream_test].depends_on_nodes) + + # If the set of nodes that an upstream test depends on + # is a subset of all upstream nodes of the current node, + # add an edge from the upstream test to the current node. + if test_depends_on.issubset(upstream_nodes): + self.graph.add_edge(upstream_test, node_id) + + def get_graph(self, manifest: Manifest) -> Graph: + self.link_graph(manifest) + return Graph(self.graph) + class Compiler: def __init__(self, config): @@ -385,104 +473,13 @@ def _compile_code( return node - def write_graph_file(self, linker: Linker, manifest: Manifest): - filename = graph_file_name - graph_path = os.path.join(self.config.target_path, filename) - flags = get_flags() - if flags.WRITE_JSON: - linker.write_graph(graph_path, manifest) - - def link_node(self, linker: Linker, node: GraphMemberNode, manifest: Manifest): - linker.add_node(node.unique_id) - - for dependency in node.depends_on_nodes: - if dependency in manifest.nodes: - linker.dependency(node.unique_id, (manifest.nodes[dependency].unique_id)) - elif dependency in manifest.sources: - linker.dependency(node.unique_id, (manifest.sources[dependency].unique_id)) - elif dependency in manifest.metrics: - linker.dependency(node.unique_id, (manifest.metrics[dependency].unique_id)) - else: - raise GraphDependencyNotFoundError(node, dependency) - - def link_graph(self, linker: Linker, manifest: Manifest, add_test_edges: bool = False): - for source in manifest.sources.values(): - linker.add_node(source.unique_id) - for node in manifest.nodes.values(): - self.link_node(linker, node, manifest) - for exposure in manifest.exposures.values(): - self.link_node(linker, exposure, manifest) - for metric in manifest.metrics.values(): - self.link_node(linker, metric, manifest) - - cycle = linker.find_cycles() - - if cycle: - raise RuntimeError("Found a cycle: {}".format(cycle)) - - if add_test_edges: - manifest.build_parent_and_child_maps() - self.add_test_edges(linker, manifest) - - def add_test_edges(self, linker: Linker, manifest: Manifest) -> None: - """This method adds additional edges to the DAG. For a given non-test - executable node, add an edge from an upstream test to the given node if - the set of nodes the test depends on is a subset of the upstream nodes - for the given node.""" - - # Given a graph: - # model1 --> model2 --> model3 - # | | - # | \/ - # \/ test 2 - # test1 - # - # Produce the following graph: - # model1 --> model2 --> model3 - # | /\ | /\ /\ - # | | \/ | | - # \/ | test2 ----| | - # test1 ----|---------------| - - for node_id in linker.graph: - # If node is executable (in manifest.nodes) and does _not_ - # represent a test, continue. - if ( - node_id in manifest.nodes - and manifest.nodes[node_id].resource_type != NodeType.Test - ): - # Get *everything* upstream of the node - all_upstream_nodes = nx.traversal.bfs_tree(linker.graph, node_id, reverse=True) - # Get the set of upstream nodes not including the current node. - upstream_nodes = set([n for n in all_upstream_nodes if n != node_id]) - - # Get all tests that depend on any upstream nodes. - upstream_tests = [] - for upstream_node in upstream_nodes: - upstream_tests += _get_tests_for_node(manifest, upstream_node) - - for upstream_test in upstream_tests: - # Get the set of all nodes that the test depends on - # including the upstream_node itself. This is necessary - # because tests can depend on multiple nodes (ex: - # relationship tests). Test nodes do not distinguish - # between what node the test is "testing" and what - # node(s) it depends on. - test_depends_on = set(manifest.nodes[upstream_test].depends_on_nodes) - - # If the set of nodes that an upstream test depends on - # is a subset of all upstream nodes of the current node, - # add an edge from the upstream test to the current node. - if test_depends_on.issubset(upstream_nodes): - linker.graph.add_edge(upstream_test, node_id) - + # This method doesn't actually "compile" any of the nodes. That is done by the + # "compile_node" method. This creates a Linker and builds the networkx graph, + # writes out the graph.gpickle file, and prints the stats, returning a Graph object. def compile(self, manifest: Manifest, write=True, add_test_edges=False) -> Graph: self.initialize() linker = Linker() - - self.link_graph(linker, manifest, add_test_edges) - - stats = _generate_stats(manifest) + linker.link_graph(manifest, add_test_edges) if write: self.write_graph_file(linker, manifest) @@ -492,10 +489,18 @@ def compile(self, manifest: Manifest, write=True, add_test_edges=False) -> Graph self.config.args.__class__ == argparse.Namespace and self.config.args.cls == list_task.ListTask ): + stats = _generate_stats(manifest) print_compile_stats(stats) return Graph(linker.graph) + def write_graph_file(self, linker: Linker, manifest: Manifest): + filename = graph_file_name + graph_path = os.path.join(self.config.target_path, filename) + flags = get_flags() + if flags.WRITE_JSON: + linker.write_graph(graph_path, manifest) + # writes the "compiled_code" into the target/compiled directory def _write_node(self, node: ManifestSQLNode) -> ManifestSQLNode: if not node.extra_ctes_injected or node.resource_type in ( diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 4eb80090adb..8118f5073e6 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -22,6 +22,8 @@ from typing_extensions import Protocol from uuid import UUID +from dbt.contracts.publication import Dependencies + from dbt.contracts.graph.nodes import ( Macro, Documentation, @@ -633,6 +635,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin): source_patches: MutableMapping[SourceKey, SourcePatch] = field(default_factory=dict) disabled: MutableMapping[str, List[GraphMemberNode]] = field(default_factory=dict) env_vars: MutableMapping[str, str] = field(default_factory=dict) + dependencies: Optional[Dependencies] = None _doc_lookup: Optional[DocLookup] = field( default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None} diff --git a/core/dbt/contracts/publication.py b/core/dbt/contracts/publication.py new file mode 100644 index 00000000000..6f72ffb564d --- /dev/null +++ b/core/dbt/contracts/publication.py @@ -0,0 +1,46 @@ +from typing import Optional, List, Dict, Any +from dbt.dataclass_schema import dbtClassMixin + +from dataclasses import dataclass, field + +from dbt.contracts.util import BaseArtifactMetadata, ArtifactMixin, schema_version + + +@dataclass +class DependentProjects(dbtClassMixin): + name: str + environment: str + + +@dataclass +class Dependencies(dbtClassMixin): + projects: list[DependentProjects] = field(default_factory=list) + + +@dataclass +class PublicationMetadata(BaseArtifactMetadata): + dbt_schema_version: str = field(default_factory=lambda: str(Publication.dbt_schema_version)) + adapter_type: Optional[str] = None + quoting: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class PublicModel(dbtClassMixin): + relation_name: str + latest: bool = False # not implemented yet + # list of model unique_ids + public_dependencies: List[str] = field(default_factory=list) + + +@dataclass +class PublicationMandatory: + project_name: str + + +@dataclass +@schema_version("publication", 1) +class Publication(ArtifactMixin, PublicationMandatory): + public_models: Dict[str, PublicModel] = field(default_factory=dict) + metadata: PublicationMetadata = field(default_factory=PublicationMetadata) + # list of project name strings + dependencies: List[str] = field(default_factory=list) diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 227d434bff9..1aa82a52d90 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -22,6 +22,7 @@ get_adapter_package_names, ) from dbt.helper_types import PathSet +from dbt.clients.yaml_helper import load_yaml_text from dbt.events.functions import fire_event, get_invocation_id, warn_or_error from dbt.events.types import ( PartialParsingErrorProcessingFile, @@ -40,7 +41,14 @@ from dbt.node_types import NodeType, AccessType from dbt.clients.jinja import get_rendered, MacroStack from dbt.clients.jinja_static import statically_extract_macro_calls -from dbt.clients.system import make_directory, path_exists, read_json, write_file +from dbt.clients.system import ( + make_directory, + path_exists, + read_json, + write_file, + resolve_path_from_base, + load_file_contents, +) from dbt.config import Project, RuntimeConfig from dbt.context.docs import generate_runtime_docs_context from dbt.context.macro_resolver import MacroResolver, TestMacroNamespace @@ -72,6 +80,7 @@ ) from dbt.contracts.graph.unparsed import NodeVersion from dbt.contracts.util import Writable +from dbt.contracts.publication import Publication, PublicationMetadata, PublicModel, Dependencies from dbt.exceptions import TargetNotFoundError, AmbiguousAliasError from dbt.parser.base import Parser from dbt.parser.analysis import AnalysisParser @@ -91,6 +100,7 @@ from dbt.dataclass_schema import StrEnum, dbtClassMixin MANIFEST_FILE_NAME = "manifest.json" +DEPENDENCIES_FILE_NAME = "dependencies.yml" PARTIAL_PARSE_FILE_NAME = "partial_parse.msgpack" PARSING_STATE = DbtProcessState("parsing") PERF_INFO_FILE_NAME = "perf_info.json" @@ -444,6 +454,7 @@ def load(self): # write out the fully parsed manifest self.write_manifest_for_partial_parse() + self.write_artifacts() return self.manifest @@ -599,6 +610,86 @@ def write_manifest_for_partial_parse(self): except Exception: raise + def write_artifacts(self): + # write out manifest.json + # TODO: check for overlap with parse command writing manifest + write_manifest(self.manifest, self.root_project.target_path) + + # Note: this is not the right place to do this. Just doing here for ease + # of implementation. To be moved later. + self.build_dependencies() + + # build publication metadata + metadata = PublicationMetadata( + adapter_type=self.root_project.credentials.type, + quoting=self.root_project.quoting, + ) + + # get a list of public model ids first so it can be used in constructing dependencies + public_model_ids = [] + for node in self.manifest.nodes.values(): + if node.resource_type == NodeType.Model and node.access == AccessType.Public: + public_model_ids.append(node.unique_id) + + set_of_public_unique_ids = set(public_model_ids) + + # Get the Graph object from the Linker + from dbt.compilation import Linker + + linker = Linker() + graph = linker.get_graph(self.manifest) + + public_models = {} + for unique_id in public_model_ids: + model = self.manifest.nodes[unique_id] + # public_dependencies is the intersection of all parent nodes plus public nodes + public_dependencies = [] + # parents is a set + parents = graph.select_parents({unique_id}) + public_dependencies = parents.intersection(set_of_public_unique_ids) + + public_model = PublicModel( + relation_name=model.relation_name, + latest=False, # not a node field yet + public_dependencies=list(public_dependencies), + ) + public_models[unique_id] = public_model + + # TODO: get dependencies from dependencies.yml. When is it loaded? here? + dependencies = [] + publication = Publication( + metadata=metadata, + project_name=self.root_project.project_name, + public_models=public_models, + dependencies=dependencies, + ) + # write out publication artifact _publication.json + publication_file_name = f"{self.root_project.project_name}_publication.json" + path = os.path.join(self.root_project.target_path, publication_file_name) + publication.write(path) + + def build_dependencies(self): + dependencies_filepath = resolve_path_from_base( + "dependencies.yml", self.root_project.project_root + ) + if path_exists(dependencies_filepath): + contents = load_file_contents(dependencies_filepath) + dependencies_dict = load_yaml_text(contents) + dependencies = Dependencies.from_dict(dependencies_dict) + self.manifest.dependencies = dependencies + else: + self.manifest.dependencies = None + if self.manifest.dependencies: + for project in self.manifest.dependencies.projects: + # look for a _publication.json file for every project in the 'publications' dir + publication_file_name = f"{project.name}_publication.json" + # TODO: eventually we'll implement publications_dir config + path = os.path.join("publications", publication_file_name) + if os.path.exists(path): + print(f"--- found a publication_file matching {project.name}") + else: + print(f"--- did not find a publication_file matching {project.name}") + def is_partial_parsable(self, manifest: Manifest) -> Tuple[bool, Optional[str]]: """Compare the global hashes of the read-in parse results' values to the known ones, and return if it is ok to re-use the results. diff --git a/tests/functional/docs/test_generate.py b/tests/functional/docs/test_generate.py index 566fb1a9912..2fd8c14b4b3 100644 --- a/tests/functional/docs/test_generate.py +++ b/tests/functional/docs/test_generate.py @@ -1,9 +1,18 @@ -import os +import pytest -from dbt.tests.util import run_dbt +from dbt.tests.util import run_dbt, get_manifest class TestGenerate: - def test_generate_no_manifest_on_no_compile(self, project): + @pytest.fixture(scope="class") + def models(self): + return {"my_model.sql": "select 1 as fun"} + + def test_manifest_not_compiled(self, project): run_dbt(["docs", "generate", "--no-compile"]) - assert not os.path.exists("./target/manifest.json") + # manifest.json is written out in parsing now, but it + # shouldn't be compiled because of the --no-compile flag + manifest = get_manifest(project.project_root) + model_id = "model.test.my_model" + assert model_id in manifest.nodes + assert manifest.nodes[model_id].compiled is False diff --git a/tests/functional/experimental_parser/test_all_experimental_parser.py b/tests/functional/experimental_parser/test_all_experimental_parser.py index 6a5c86b8f17..4f099de3456 100644 --- a/tests/functional/experimental_parser/test_all_experimental_parser.py +++ b/tests/functional/experimental_parser/test_all_experimental_parser.py @@ -38,12 +38,16 @@ def get_manifest(): {{ config(tags='hello', x=False) }} {{ config(tags='world', x=True) }} -select * from {{ ref('model_a') }} +select * from {{ ref('model_b') }} cross join {{ source('my_src', 'my_tbl') }} where false as boop """ +basic__model_b_sql = """ +select 1 as fun +""" + ref_macro__schema_yml = """ version: 2 @@ -93,6 +97,7 @@ class BasicExperimentalParser: def models(self): return { "model_a.sql": basic__model_a_sql, + "model_b.sql": basic__model_b_sql, "schema.yml": basic__schema_yml, } @@ -154,7 +159,7 @@ def test_experimental_parser_basic( run_dbt(["--use-experimental-parser", "parse"]) manifest = get_manifest() node = manifest.nodes["model.test.model_a"] - assert node.refs == [RefArgs(name="model_a")] + assert node.refs == [["model_b"]] assert node.sources == [["my_src", "my_tbl"]] assert node.config._extra == {"x": True} assert node.config.tags == ["hello", "world"] @@ -179,7 +184,11 @@ def test_static_parser_basic(self, project): manifest = get_manifest() node = manifest.nodes["model.test.model_a"] +<<<<<<< HEAD + assert node.refs == [["model_b"]] +======= assert node.refs == [RefArgs(name="model_a")] +>>>>>>> main assert node.sources == [["my_src", "my_tbl"]] assert node.config._extra == {"x": True} assert node.config.tags == ["hello", "world"] diff --git a/tests/functional/groups/test_access.py b/tests/functional/groups/test_access.py index 2a8902b3027..6814d2f2452 100644 --- a/tests/functional/groups/test_access.py +++ b/tests/functional/groups/test_access.py @@ -10,8 +10,6 @@ yet_another_model_sql = "select 999 as weird" schema_yml = """ -version: 2 - models: - name: my_model description: "my model" @@ -21,8 +19,6 @@ """ v2_schema_yml = """ -version: 2 - models: - name: my_model description: "my model" @@ -39,8 +35,6 @@ """ groups_yml = """ -version: 2 - groups: - name: analytics owner: @@ -52,8 +46,6 @@ v3_schema_yml = """ -version: 2 - models: - name: my_model description: "my model" @@ -67,8 +59,6 @@ """ v4_schema_yml = """ -version: 2 - models: - name: my_model description: "my model" @@ -82,8 +72,6 @@ """ simple_exposure_yml = """ -version: 2 - exposures: - name: simple_exposure label: simple exposure label @@ -95,8 +83,6 @@ """ v5_schema_yml = """ -version: 2 - models: - name: my_model description: "my model" @@ -125,8 +111,6 @@ """ people_metric_yml = """ -version: 2 - metrics: - name: number_of_people @@ -147,8 +131,6 @@ """ v2_people_metric_yml = """ -version: 2 - metrics: - name: number_of_people diff --git a/tests/functional/groups/test_publication.py b/tests/functional/groups/test_publication.py new file mode 100644 index 00000000000..0c0df331e55 --- /dev/null +++ b/tests/functional/groups/test_publication.py @@ -0,0 +1,78 @@ +import pytest + +from dbt.tests.util import run_dbt, get_artifact, write_file +from dbt.contracts.publication import Publication + + +model_one_sql = """ +select 1 as fun +""" + +model_two_sql = """ +select fun from {{ ref('model_one') }} +""" + +model_three_sql = """ +select fun from {{ ref('model_two') }} +""" + +models_yml = """ +models: + - name: model_one + description: model one + access: public + - name: model_two + description: non-public model + - name: model_three + description: model three + access: public +""" + + +class TestPublicationArtifact: + @pytest.fixture(scope="class") + def models(self): + return { + "model_one.sql": model_one_sql, + "model_two.sql": model_two_sql, + "model_three.sql": model_three_sql, + "models.yml": models_yml, + } + + def test_publication_artifact(self, project): + results = run_dbt(["run"]) + assert len(results) == 3 + + publication_dict = get_artifact(project.project_root, "target", "test_publication.json") + publication = Publication.from_dict(publication_dict) + assert publication + assert len(publication.public_models) == 2 + assert publication.public_models["model.test.model_three"].public_dependencies == [ + "model.test.model_one" + ] + + +dependencies_yml = """ +projects: + - name: finance + environment: dev + - name: marketing + environment: dev +""" + + +class TestDependenciesYml: + @pytest.fixture(scope="class") + def models(self): + return { + "model_one.sql": model_one_sql, + "model_two.sql": model_two_sql, + "model_three.sql": model_three_sql, + "models.yml": models_yml, + } + + def test_dependencies(self, project): + write_file(dependencies_yml, "dependencies.yml") + + results = run_dbt(["run"]) + assert len(results) == 3