From 872c0c06f75b5a6d5ce2280da0a110f649e6593a Mon Sep 17 00:00:00 2001 From: tlento Date: Wed, 15 May 2024 14:59:47 -0700 Subject: [PATCH] Make source_semantic_models property accessible from a DataflowPlanNode This change ultimately adds the source_semantic_models property to the DataflowPlan and adds a hook for enabling access to it from any arbitrary DataflowPlanNode. We currently have two use-cases for this, one in the cloud codebase that needs the semantic model inputs for a dataflow plan, and the upcoming predicate pushdown evaluation which needs the semantic model inputs for a given DataflowPlanNode. An earlier version of this change added the property directly to the DataflowPlanNode, which would satisfy both use cases above. The issue with having this property assigned directly to a DataflowPlanNode is that the property might be considered both a node-level and graph-level attribute, so it's not clear where to put the accessor. The solution we came up with for this was to allow access to a DataflowPlan DAG object built from the node, which would effectively encapsulate the subgraph represented by the node and its ancestors. Then we can access these subgraph properties through the DataflowPlan while making it clear to the caller that what they are asking for is a subgraph-level, rather than a node-level attribute. --- metricflow/dataflow/dataflow_plan.py | 43 ++++++++++++++- metricflow/dataflow/nodes/read_sql_source.py | 10 +++- .../dataflow/test_dataflow_plan.py | 55 +++++++++++++++++++ 3 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 tests_metricflow/dataflow/test_dataflow_plan.py diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index 640a50be3f..4685ddad47 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -5,14 +5,18 @@ import logging import typing from abc import ABC, abstractmethod -from typing import Generic, Optional, Sequence, Set, Type, TypeVar +from typing import FrozenSet, Generic, Optional, Sequence, Set, Type, TypeVar +import more_itertools from metricflow_semantics.dag.id_prefix import StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag, NodeId from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec from metricflow_semantics.visitor import Visitable, VisitorOutputT if typing.TYPE_CHECKING: + from dbt_semantic_interfaces.references import SemanticModelReference + from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec + from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode @@ -60,6 +64,22 @@ def parent_nodes(self) -> Sequence[DataflowPlanNode]: """Return the nodes where data for this node comes from.""" return self._parent_nodes + @property + def _input_semantic_model(self) -> Optional[SemanticModelReference]: + """Return the semantic model serving as direct input for this node, if one exists.""" + return None + + def as_plan(self) -> DataflowPlan: + """Converter method for taking an arbitrary mode and producing an associated DataflowPlan. + + This is useful for doing lookups for plan-level properties at points in the call stack where we only have + a subgraph of a complete plan. For example, the total number of nodes represented by this node and all of + its parents would be a property of a given subgraph of the DAG. Rather than doing recursive property walks + inside of each node, we make those properties of the DataflowPlan, and this node-level converter makes + such properties easily accessible. + """ + return DataflowPlan(sink_nodes=(self,)) + @abstractmethod def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: """Called when a visitor needs to visit this node.""" @@ -188,3 +208,24 @@ def __init__(self, sink_nodes: Sequence[DataflowPlanNode], plan_id: Optional[Dag @property def sink_node(self) -> DataflowPlanNode: # noqa: D102 return self._sink_nodes[0] + + def __complete_subgraph(self, node: DataflowPlanNode) -> Sequence[DataflowPlanNode]: + """Node accessor for retrieving a flattened sequence of all nodes in the subgraph upstream of the input node. + + Useful for gathering nodes for subtype-agnostic operations, such as common property access or simple counts. + """ + flattened_parent_subgraphs = tuple( + more_itertools.collapse(self.__complete_subgraph(parent_node) for parent_node in node.parent_nodes) + ) + return (node,) + flattened_parent_subgraphs + + @property + def source_semantic_models(self) -> FrozenSet[SemanticModelReference]: + """Return the complete set of source semantic models for this DataflowPlan.""" + return frozenset( + [ + node._input_semantic_model + for node in self.__complete_subgraph(self.sink_node) + if node._input_semantic_model is not None + ] + ) diff --git a/metricflow/dataflow/nodes/read_sql_source.py b/metricflow/dataflow/nodes/read_sql_source.py index a8285db5a2..0010a77661 100644 --- a/metricflow/dataflow/nodes/read_sql_source.py +++ b/metricflow/dataflow/nodes/read_sql_source.py @@ -1,12 +1,14 @@ from __future__ import annotations import textwrap -from typing import Sequence +from typing import Optional, Sequence import jinja2 +from dbt_semantic_interfaces.references import SemanticModelReference from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.visitor import VisitorOutputT +from typing_extensions import override from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor from metricflow.dataset.sql_dataset import SqlDataSet @@ -31,6 +33,12 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_source_node(self) + @override + @property + def _input_semantic_model(self) -> Optional[SemanticModelReference]: + """Return the semantic model serving as direct input for this node, if one exists.""" + return self.data_set.semantic_model_reference + @property def data_set(self) -> SqlDataSet: """Return the data set that this source represents and is passed to the child nodes.""" diff --git a/tests_metricflow/dataflow/test_dataflow_plan.py b/tests_metricflow/dataflow/test_dataflow_plan.py new file mode 100644 index 0000000000..1fba5b5662 --- /dev/null +++ b/tests_metricflow/dataflow/test_dataflow_plan.py @@ -0,0 +1,55 @@ +"""Tests for operations on dataflow plans and dataflow plan nodes.""" + +from __future__ import annotations + +from dbt_semantic_interfaces.references import EntityReference, SemanticModelReference +from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec +from metricflow_semantics.specs.spec_classes import ( + DimensionSpec, + MetricSpec, +) + +from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder + + +def test_source_semantic_models_accessor( + dataflow_plan_builder: DataflowPlanBuilder, +) -> None: + """Tests source semantic models access for a simple query plan.""" + dataflow_plan = dataflow_plan_builder.build_plan( + MetricFlowQuerySpec( + metric_specs=(MetricSpec(element_name="bookings"),), + ) + ) + + assert dataflow_plan.source_semantic_models == frozenset( + [SemanticModelReference(semantic_model_name="bookings_source")] + ) + + +def test_multi_hop_joined_source_semantic_models_accessor( + dataflow_plan_builder: DataflowPlanBuilder, +) -> None: + """Tests source semantic models access for a multi-hop join plan.""" + dataflow_plan = dataflow_plan_builder.build_plan( + MetricFlowQuerySpec( + metric_specs=(MetricSpec(element_name="bookings"),), + dimension_specs=( + DimensionSpec( + element_name="home_state_latest", + entity_links=( + EntityReference(element_name="listing"), + EntityReference(element_name="user"), + ), + ), + ), + ) + ) + + assert dataflow_plan.source_semantic_models == frozenset( + [ + SemanticModelReference(semantic_model_name="bookings_source"), + SemanticModelReference(semantic_model_name="listings_latest"), + SemanticModelReference(semantic_model_name="users_latest"), + ] + )