diff --git a/.changes/unreleased/Under the Hood-20240516-144603.yaml b/.changes/unreleased/Under the Hood-20240516-144603.yaml new file mode 100644 index 0000000000..c1f49c763e --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240516-144603.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Make source semantic models available from DataflowPlanNode instances +time: 2024-05-16T14:46:03.707367-07:00 +custom: + Author: tlento + Issue: "1218" diff --git a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py index bfe85ae8bc..ceea1ddef9 100644 --- a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py +++ b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py @@ -89,6 +89,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper): VALUES_GROUP_BY_ITEM_RESOLUTION_NODE = "vr" DATAFLOW_PLAN_PREFIX = "dfp" + DATAFLOW_PLAN_SUBGRAPH_PREFIX = "dfpsub" OPTIMIZED_DATAFLOW_PLAN_PREFIX = "dfpo" SQL_QUERY_PLAN_PREFIX = "sqp" EXEC_PLAN_PREFIX = "ep" diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index 640a50be3f..e110df0719 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -5,14 +5,18 @@ import logging import typing from abc import ABC, abstractmethod -from typing import Generic, Optional, Sequence, Set, Type, TypeVar +from typing import FrozenSet, Generic, Optional, Sequence, Set, Type, TypeVar +import more_itertools from metricflow_semantics.dag.id_prefix import StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag, NodeId from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec from metricflow_semantics.visitor import Visitable, VisitorOutputT if typing.TYPE_CHECKING: + from dbt_semantic_interfaces.references import SemanticModelReference + from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec + from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode @@ -60,6 +64,24 @@ def parent_nodes(self) -> Sequence[DataflowPlanNode]: """Return the nodes where data for this node comes from.""" return self._parent_nodes + @property + def _input_semantic_model(self) -> Optional[SemanticModelReference]: + """Return the semantic model serving as direct input for this node, if one exists.""" + return None + + def as_plan(self) -> DataflowPlan: + """Converter method for taking an arbitrary mode and producing an associated DataflowPlan. + + This is useful for doing lookups for plan-level properties at points in the call stack where we only have + a subgraph of a complete plan. For example, the total number of nodes represented by this node and all of + its parents would be a property of a given subgraph of the DAG. Rather than doing recursive property walks + inside of each node, we make those properties of the DataflowPlan, and this node-level converter makes + such properties easily accessible. + """ + return DataflowPlan( + sink_nodes=(self,), plan_id=DagId.from_id_prefix(id_prefix=StaticIdPrefix.DATAFLOW_PLAN_SUBGRAPH_PREFIX) + ) + @abstractmethod def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: """Called when a visitor needs to visit this node.""" @@ -188,3 +210,27 @@ def __init__(self, sink_nodes: Sequence[DataflowPlanNode], plan_id: Optional[Dag @property def sink_node(self) -> DataflowPlanNode: # noqa: D102 return self._sink_nodes[0] + + @staticmethod + def __all_nodes_in_subgraph(node: DataflowPlanNode) -> Sequence[DataflowPlanNode]: + """Node accessor for retrieving a flattened sequence of all nodes in the subgraph upstream of the input node. + + Useful for gathering nodes for subtype-agnostic operations, such as common property access or simple counts. + """ + flattened_parent_subgraphs = tuple( + more_itertools.collapse( + DataflowPlan.__all_nodes_in_subgraph(parent_node) for parent_node in node.parent_nodes + ) + ) + return (node,) + flattened_parent_subgraphs + + @property + def source_semantic_models(self) -> FrozenSet[SemanticModelReference]: + """Return the complete set of source semantic models for this DataflowPlan.""" + return frozenset( + [ + node._input_semantic_model + for node in DataflowPlan.__all_nodes_in_subgraph(self.sink_node) + if node._input_semantic_model is not None + ] + ) diff --git a/metricflow/dataflow/nodes/read_sql_source.py b/metricflow/dataflow/nodes/read_sql_source.py index a8285db5a2..0010a77661 100644 --- a/metricflow/dataflow/nodes/read_sql_source.py +++ b/metricflow/dataflow/nodes/read_sql_source.py @@ -1,12 +1,14 @@ from __future__ import annotations import textwrap -from typing import Sequence +from typing import Optional, Sequence import jinja2 +from dbt_semantic_interfaces.references import SemanticModelReference from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.visitor import VisitorOutputT +from typing_extensions import override from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor from metricflow.dataset.sql_dataset import SqlDataSet @@ -31,6 +33,12 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_source_node(self) + @override + @property + def _input_semantic_model(self) -> Optional[SemanticModelReference]: + """Return the semantic model serving as direct input for this node, if one exists.""" + return self.data_set.semantic_model_reference + @property def data_set(self) -> SqlDataSet: """Return the data set that this source represents and is passed to the child nodes.""" diff --git a/tests_metricflow/dataflow/test_dataflow_plan.py b/tests_metricflow/dataflow/test_dataflow_plan.py new file mode 100644 index 0000000000..1fba5b5662 --- /dev/null +++ b/tests_metricflow/dataflow/test_dataflow_plan.py @@ -0,0 +1,55 @@ +"""Tests for operations on dataflow plans and dataflow plan nodes.""" + +from __future__ import annotations + +from dbt_semantic_interfaces.references import EntityReference, SemanticModelReference +from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec +from metricflow_semantics.specs.spec_classes import ( + DimensionSpec, + MetricSpec, +) + +from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder + + +def test_source_semantic_models_accessor( + dataflow_plan_builder: DataflowPlanBuilder, +) -> None: + """Tests source semantic models access for a simple query plan.""" + dataflow_plan = dataflow_plan_builder.build_plan( + MetricFlowQuerySpec( + metric_specs=(MetricSpec(element_name="bookings"),), + ) + ) + + assert dataflow_plan.source_semantic_models == frozenset( + [SemanticModelReference(semantic_model_name="bookings_source")] + ) + + +def test_multi_hop_joined_source_semantic_models_accessor( + dataflow_plan_builder: DataflowPlanBuilder, +) -> None: + """Tests source semantic models access for a multi-hop join plan.""" + dataflow_plan = dataflow_plan_builder.build_plan( + MetricFlowQuerySpec( + metric_specs=(MetricSpec(element_name="bookings"),), + dimension_specs=( + DimensionSpec( + element_name="home_state_latest", + entity_links=( + EntityReference(element_name="listing"), + EntityReference(element_name="user"), + ), + ), + ), + ) + ) + + assert dataflow_plan.source_semantic_models == frozenset( + [ + SemanticModelReference(semantic_model_name="bookings_source"), + SemanticModelReference(semantic_model_name="listings_latest"), + SemanticModelReference(semantic_model_name="users_latest"), + ] + )