diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index 640a50be3f..4685ddad47 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -5,14 +5,18 @@ import logging import typing from abc import ABC, abstractmethod -from typing import Generic, Optional, Sequence, Set, Type, TypeVar +from typing import FrozenSet, Generic, Optional, Sequence, Set, Type, TypeVar +import more_itertools from metricflow_semantics.dag.id_prefix import StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag, NodeId from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec from metricflow_semantics.visitor import Visitable, VisitorOutputT if typing.TYPE_CHECKING: + from dbt_semantic_interfaces.references import SemanticModelReference + from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec + from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode @@ -60,6 +64,22 @@ def parent_nodes(self) -> Sequence[DataflowPlanNode]: """Return the nodes where data for this node comes from.""" return self._parent_nodes + @property + def _input_semantic_model(self) -> Optional[SemanticModelReference]: + """Return the semantic model serving as direct input for this node, if one exists.""" + return None + + def as_plan(self) -> DataflowPlan: + """Converter method for taking an arbitrary mode and producing an associated DataflowPlan. + + This is useful for doing lookups for plan-level properties at points in the call stack where we only have + a subgraph of a complete plan. For example, the total number of nodes represented by this node and all of + its parents would be a property of a given subgraph of the DAG. Rather than doing recursive property walks + inside of each node, we make those properties of the DataflowPlan, and this node-level converter makes + such properties easily accessible. + """ + return DataflowPlan(sink_nodes=(self,)) + @abstractmethod def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: """Called when a visitor needs to visit this node.""" @@ -188,3 +208,24 @@ def __init__(self, sink_nodes: Sequence[DataflowPlanNode], plan_id: Optional[Dag @property def sink_node(self) -> DataflowPlanNode: # noqa: D102 return self._sink_nodes[0] + + def __complete_subgraph(self, node: DataflowPlanNode) -> Sequence[DataflowPlanNode]: + """Node accessor for retrieving a flattened sequence of all nodes in the subgraph upstream of the input node. + + Useful for gathering nodes for subtype-agnostic operations, such as common property access or simple counts. + """ + flattened_parent_subgraphs = tuple( + more_itertools.collapse(self.__complete_subgraph(parent_node) for parent_node in node.parent_nodes) + ) + return (node,) + flattened_parent_subgraphs + + @property + def source_semantic_models(self) -> FrozenSet[SemanticModelReference]: + """Return the complete set of source semantic models for this DataflowPlan.""" + return frozenset( + [ + node._input_semantic_model + for node in self.__complete_subgraph(self.sink_node) + if node._input_semantic_model is not None + ] + ) diff --git a/metricflow/dataflow/nodes/read_sql_source.py b/metricflow/dataflow/nodes/read_sql_source.py index a8285db5a2..0010a77661 100644 --- a/metricflow/dataflow/nodes/read_sql_source.py +++ b/metricflow/dataflow/nodes/read_sql_source.py @@ -1,12 +1,14 @@ from __future__ import annotations import textwrap -from typing import Sequence +from typing import Optional, Sequence import jinja2 +from dbt_semantic_interfaces.references import SemanticModelReference from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.visitor import VisitorOutputT +from typing_extensions import override from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor from metricflow.dataset.sql_dataset import SqlDataSet @@ -31,6 +33,12 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_source_node(self) + @override + @property + def _input_semantic_model(self) -> Optional[SemanticModelReference]: + """Return the semantic model serving as direct input for this node, if one exists.""" + return self.data_set.semantic_model_reference + @property def data_set(self) -> SqlDataSet: """Return the data set that this source represents and is passed to the child nodes.""" diff --git a/tests_metricflow/dataflow/test_dataflow_plan.py b/tests_metricflow/dataflow/test_dataflow_plan.py new file mode 100644 index 0000000000..1fba5b5662 --- /dev/null +++ b/tests_metricflow/dataflow/test_dataflow_plan.py @@ -0,0 +1,55 @@ +"""Tests for operations on dataflow plans and dataflow plan nodes.""" + +from __future__ import annotations + +from dbt_semantic_interfaces.references import EntityReference, SemanticModelReference +from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec +from metricflow_semantics.specs.spec_classes import ( + DimensionSpec, + MetricSpec, +) + +from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder + + +def test_source_semantic_models_accessor( + dataflow_plan_builder: DataflowPlanBuilder, +) -> None: + """Tests source semantic models access for a simple query plan.""" + dataflow_plan = dataflow_plan_builder.build_plan( + MetricFlowQuerySpec( + metric_specs=(MetricSpec(element_name="bookings"),), + ) + ) + + assert dataflow_plan.source_semantic_models == frozenset( + [SemanticModelReference(semantic_model_name="bookings_source")] + ) + + +def test_multi_hop_joined_source_semantic_models_accessor( + dataflow_plan_builder: DataflowPlanBuilder, +) -> None: + """Tests source semantic models access for a multi-hop join plan.""" + dataflow_plan = dataflow_plan_builder.build_plan( + MetricFlowQuerySpec( + metric_specs=(MetricSpec(element_name="bookings"),), + dimension_specs=( + DimensionSpec( + element_name="home_state_latest", + entity_links=( + EntityReference(element_name="listing"), + EntityReference(element_name="user"), + ), + ), + ), + ) + ) + + assert dataflow_plan.source_semantic_models == frozenset( + [ + SemanticModelReference(semantic_model_name="bookings_source"), + SemanticModelReference(semantic_model_name="listings_latest"), + SemanticModelReference(semantic_model_name="users_latest"), + ] + )