diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index 34ba99fee7..c8229e0c1f 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -2,16 +2,18 @@ from __future__ import annotations +import itertools import logging import typing from abc import ABC, abstractmethod -from typing import Generic, Optional, Sequence, Set, Type, TypeVar +from typing import FrozenSet, Generic, Optional, Sequence, Set, Type, TypeVar from metricflow_semantics.dag.id_prefix import StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag, NodeId from metricflow_semantics.visitor import Visitable, VisitorOutputT if typing.TYPE_CHECKING: + from dbt_semantic_interfaces.references import SemanticModelReference from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode @@ -61,6 +63,11 @@ def parent_nodes(self) -> Sequence[DataflowPlanNode]: """Return the nodes where data for this node comes from.""" return self._parent_nodes + @property + def source_semantic_models(self) -> FrozenSet[SemanticModelReference]: + """Return the complete set of source semantic models for this node, collected recursively across all parents.""" + return frozenset(itertools.chain.from_iterable([parent.source_semantic_models for parent in self.parent_nodes])) + @abstractmethod def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: """Called when a visitor needs to visit this node.""" diff --git a/metricflow/dataflow/nodes/read_sql_source.py b/metricflow/dataflow/nodes/read_sql_source.py index 8c0ce2cc5d..92d4613dd7 100644 --- a/metricflow/dataflow/nodes/read_sql_source.py +++ b/metricflow/dataflow/nodes/read_sql_source.py @@ -1,12 +1,14 @@ from __future__ import annotations import textwrap -from typing import Sequence +from typing import FrozenSet, Sequence import jinja2 +from dbt_semantic_interfaces.references import SemanticModelReference from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.visitor import VisitorOutputT +from typing_extensions import override from metricflow.dataflow.dataflow_plan import BaseOutput, DataflowPlanNode, DataflowPlanNodeVisitor from metricflow.dataset.sql_dataset import SqlDataSet @@ -31,6 +33,15 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102 def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_source_node(self) + @override + @property + def source_semantic_models(self) -> FrozenSet[SemanticModelReference]: + return ( + frozenset([self.data_set.semantic_model_reference]) + if self.data_set.semantic_model_reference + else frozenset() + ) + @property def data_set(self) -> SqlDataSet: """Return the data set that this source represents and is passed to the child nodes.""" diff --git a/tests_metricflow/dataflow/test_dataflow_plan.py b/tests_metricflow/dataflow/test_dataflow_plan.py new file mode 100644 index 0000000000..eec55b9750 --- /dev/null +++ b/tests_metricflow/dataflow/test_dataflow_plan.py @@ -0,0 +1,57 @@ +"""Tests for operations on dataflow plans and dataflow plan nodes.""" + +from __future__ import annotations + +from dbt_semantic_interfaces.references import EntityReference, SemanticModelReference +from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec +from metricflow_semantics.specs.spec_classes import ( + DimensionSpec, + MetricSpec, +) + +from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder + + +def test_source_semantic_models_accessor( + dataflow_plan_builder: DataflowPlanBuilder, +) -> None: + """Tests source semantic models access for a simple query plan.""" + dataflow_plan = dataflow_plan_builder.build_plan( + MetricFlowQuerySpec( + metric_specs=(MetricSpec(element_name="bookings"),), + ) + ) + + assert len(dataflow_plan.sink_nodes) == 1, "Dataflow plan should have exactly one sink node." + assert dataflow_plan.sink_nodes[0].source_semantic_models == frozenset( + [SemanticModelReference(semantic_model_name="bookings_source")] + ) + + +def test_multi_hop_joined_source_semantic_models_accessor( + dataflow_plan_builder: DataflowPlanBuilder, +) -> None: + """Tests source semantic models access for a multi-hop join plan.""" + dataflow_plan = dataflow_plan_builder.build_plan( + MetricFlowQuerySpec( + metric_specs=(MetricSpec(element_name="bookings"),), + dimension_specs=( + DimensionSpec( + element_name="home_state_latest", + entity_links=( + EntityReference(element_name="listing"), + EntityReference(element_name="user"), + ), + ), + ), + ) + ) + + assert len(dataflow_plan.sink_nodes) == 1, "Dataflow plan should have exactly one sink node." + assert dataflow_plan.sink_nodes[0].source_semantic_models == frozenset( + [ + SemanticModelReference(semantic_model_name="bookings_source"), + SemanticModelReference(semantic_model_name="listings_latest"), + SemanticModelReference(semantic_model_name="users_latest"), + ] + )