Skip to content

Commit

Permalink
Make source_semantic_models property accessible from a DataflowPlanNode
Browse files Browse the repository at this point in the history
This change ultimately adds the source_semantic_models property to
the DataflowPlan and adds a hook for enabling access to it from any
arbitrary DataflowPlanNode.

We currently have two use-cases for this, one in the cloud codebase
that needs the semantic model inputs for a dataflow plan, and the
upcoming predicate pushdown evaluation which needs the semantic model
inputs for a given DataflowPlanNode.

An earlier version of this change added the property directly to the
DataflowPlanNode, which would satisfy both use cases above. The issue
with having this property assigned directly to a DataflowPlanNode
is that the property might be considered both a node-level and graph-level
attribute, so it's not clear where to put the accessor.

The solution we came up with for this was to allow access to a DataflowPlan
DAG object built from the node, which would effectively encapsulate the
subgraph represented by the node and its ancestors. Then we can access
these subgraph properties through the DataflowPlan while making it clear
to the caller that what they are asking for is a subgraph-level, rather
than a node-level attribute.
  • Loading branch information
tlento committed May 16, 2024
1 parent 8d77d33 commit 872c0c0
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 2 deletions.
43 changes: 42 additions & 1 deletion metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
import logging
import typing
from abc import ABC, abstractmethod
from typing import Generic, Optional, Sequence, Set, Type, TypeVar
from typing import FrozenSet, Generic, Optional, Sequence, Set, Type, TypeVar

import more_itertools
from metricflow_semantics.dag.id_prefix import StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag, NodeId
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec
from metricflow_semantics.visitor import Visitable, VisitorOutputT

if typing.TYPE_CHECKING:
from dbt_semantic_interfaces.references import SemanticModelReference
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec

from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
Expand Down Expand Up @@ -60,6 +64,22 @@ def parent_nodes(self) -> Sequence[DataflowPlanNode]:
"""Return the nodes where data for this node comes from."""
return self._parent_nodes

@property
def _input_semantic_model(self) -> Optional[SemanticModelReference]:
"""Return the semantic model serving as direct input for this node, if one exists."""
return None

def as_plan(self) -> DataflowPlan:
"""Converter method for taking an arbitrary mode and producing an associated DataflowPlan.
This is useful for doing lookups for plan-level properties at points in the call stack where we only have
a subgraph of a complete plan. For example, the total number of nodes represented by this node and all of
its parents would be a property of a given subgraph of the DAG. Rather than doing recursive property walks
inside of each node, we make those properties of the DataflowPlan, and this node-level converter makes
such properties easily accessible.
"""
return DataflowPlan(sink_nodes=(self,))

@abstractmethod
def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
"""Called when a visitor needs to visit this node."""
Expand Down Expand Up @@ -188,3 +208,24 @@ def __init__(self, sink_nodes: Sequence[DataflowPlanNode], plan_id: Optional[Dag
@property
def sink_node(self) -> DataflowPlanNode: # noqa: D102
return self._sink_nodes[0]

def __complete_subgraph(self, node: DataflowPlanNode) -> Sequence[DataflowPlanNode]:
"""Node accessor for retrieving a flattened sequence of all nodes in the subgraph upstream of the input node.
Useful for gathering nodes for subtype-agnostic operations, such as common property access or simple counts.
"""
flattened_parent_subgraphs = tuple(
more_itertools.collapse(self.__complete_subgraph(parent_node) for parent_node in node.parent_nodes)
)
return (node,) + flattened_parent_subgraphs

@property
def source_semantic_models(self) -> FrozenSet[SemanticModelReference]:
"""Return the complete set of source semantic models for this DataflowPlan."""
return frozenset(
[
node._input_semantic_model
for node in self.__complete_subgraph(self.sink_node)
if node._input_semantic_model is not None
]
)
10 changes: 9 additions & 1 deletion metricflow/dataflow/nodes/read_sql_source.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import textwrap
from typing import Sequence
from typing import Optional, Sequence

import jinja2
from dbt_semantic_interfaces.references import SemanticModelReference
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.visitor import VisitorOutputT
from typing_extensions import override

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataset.sql_dataset import SqlDataSet
Expand All @@ -31,6 +33,12 @@ def id_prefix(cls) -> IdPrefix: # noqa: D102
def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_source_node(self)

@override
@property
def _input_semantic_model(self) -> Optional[SemanticModelReference]:
"""Return the semantic model serving as direct input for this node, if one exists."""
return self.data_set.semantic_model_reference

@property
def data_set(self) -> SqlDataSet:
"""Return the data set that this source represents and is passed to the child nodes."""
Expand Down
55 changes: 55 additions & 0 deletions tests_metricflow/dataflow/test_dataflow_plan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Tests for operations on dataflow plans and dataflow plan nodes."""

from __future__ import annotations

from dbt_semantic_interfaces.references import EntityReference, SemanticModelReference
from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec
from metricflow_semantics.specs.spec_classes import (
DimensionSpec,
MetricSpec,
)

from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder


def test_source_semantic_models_accessor(
dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
"""Tests source semantic models access for a simple query plan."""
dataflow_plan = dataflow_plan_builder.build_plan(
MetricFlowQuerySpec(
metric_specs=(MetricSpec(element_name="bookings"),),
)
)

assert dataflow_plan.source_semantic_models == frozenset(
[SemanticModelReference(semantic_model_name="bookings_source")]
)


def test_multi_hop_joined_source_semantic_models_accessor(
dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
"""Tests source semantic models access for a multi-hop join plan."""
dataflow_plan = dataflow_plan_builder.build_plan(
MetricFlowQuerySpec(
metric_specs=(MetricSpec(element_name="bookings"),),
dimension_specs=(
DimensionSpec(
element_name="home_state_latest",
entity_links=(
EntityReference(element_name="listing"),
EntityReference(element_name="user"),
),
),
),
)
)

assert dataflow_plan.source_semantic_models == frozenset(
[
SemanticModelReference(semantic_model_name="bookings_source"),
SemanticModelReference(semantic_model_name="listings_latest"),
SemanticModelReference(semantic_model_name="users_latest"),
]
)

0 comments on commit 872c0c0

Please sign in to comment.