Skip to content

Commit

Permalink
Add method to figure out the common branches in a dataflow plan.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 10, 2024
1 parent 3b170d1 commit daa0ed6
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 1 deletion.
2 changes: 1 addition & 1 deletion metricflow-semantics/metricflow_semantics/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC
from typing import TypeVar

VisitorOutputT = TypeVar("VisitorOutputT")
VisitorOutputT = TypeVar("VisitorOutputT", covariant=True)


class Visitable(ABC):
Expand Down
77 changes: 77 additions & 0 deletions metricflow/dataflow/dataflow_plan_analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

from collections import defaultdict
from typing import Dict, FrozenSet, Mapping, Sequence, Set

from typing_extensions import override

from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler


class DataflowPlanAnalyzer:
    """Class to determine more complex properties of the dataflow plan.

    These could also be made as member methods of the dataflow plan, but this requires resolving some circular
    dependency issues to break out the functionality into separate files.
    """

    @staticmethod
    def find_common_branches(dataflow_plan: DataflowPlan) -> Sequence[DataflowPlanNode]:
        """Find the common branches in the DAG reachable from the plan's sink node.

        Works in two passes over the plan: the first tallies how often each node is reached, the second collects
        the nodes returned by the branch-finding visitor for the set of nodes that were reached more than once.

        Returns:
            A sorted tuple of nodes (sorted for reproducibility).
        """
        # Pass 1: tally how many times each node is reached from the sink.
        occurrence_counter = _CountCommonDataflowNodeVisitor()
        dataflow_plan.sink_node.accept(occurrence_counter)

        # Nodes with a tally above one are the candidates for common branches.
        repeated_nodes = frozenset(
            plan_node for plan_node, occurrences in occurrence_counter.get_node_counts().items() if occurrences > 1
        )

        # Pass 2: walk the plan again and collect the common branches among the candidates.
        branch_finder = _FindLargestCommonBranchesVisitor(repeated_nodes)
        largest_common_branches = list(dataflow_plan.sink_node.accept(branch_finder))
        largest_common_branches.sort()
        return tuple(largest_common_branches)


class _CountCommonDataflowNodeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]):
    """Helper visitor that tallies, per node, how many times the default handler runs for it.

    There is no "already visited" check, so the parents of a node are walked again every time that node is handled.
    """

    def __init__(self) -> None:
        # defaultdict(int) lets the first increment for a node start from zero.
        self._node_to_count: Dict[DataflowPlanNode, int] = defaultdict(int)

    def get_node_counts(self) -> Mapping[DataflowPlanNode, int]:
        """Return the node -> count mapping accumulated so far (the live internal dict, not a copy)."""
        return self._node_to_count

    @override
    def _default_handler(self, node: DataflowPlanNode) -> None:
        """Visit every parent of `node`, then record one more occurrence of `node` itself."""
        for upstream_node in node.parent_nodes:
            upstream_node.accept(self)

        tally = self._node_to_count
        tally[node] = tally[node] + 1


class _FindLargestCommonBranchesVisitor(DataflowPlanNodeVisitorWithDefaultHandler[FrozenSet[DataflowPlanNode]]):
    """Given the nodes that are known to appear in the DAG multiple times, find the common branches.

    Several nested branches can all qualify as common (e.g. for `A -> B -> C -> D` and `B -> C -> D`, both
    `B -> C -> D` and `C -> D` can be considered common branches). To report only the largest ones, the traversal
    is preorder and stops descending at the first common node that is seen.
    """

    def __init__(self, common_nodes: FrozenSet[DataflowPlanNode]) -> None:
        # Nodes already known to occur more than once in the plan.
        self._common_nodes = common_nodes

    @override
    def _default_handler(self, node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode]:
        """Return `{node}` if `node` is common; otherwise the union of the results from all of its parents."""
        if node in self._common_nodes:
            # Stop here: anything further upstream belongs to this (larger) common branch.
            return frozenset({node})

        # With no parents the union is simply the empty frozenset.
        return frozenset().union(*(upstream_node.accept(self) for upstream_node in node.parent_nodes))
50 changes: 50 additions & 0 deletions tests_metricflow/sql/test_common_dataflow_branches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

import logging

from _pytest.fixtures import FixtureRequest
from metricflow_semantics.mf_logging.pretty_print import mf_pformat_dict
from metricflow_semantics.query.query_parser import MetricFlowQueryParser
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.test_helpers.snapshot_helpers import (
assert_str_snapshot_equal,
)

from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.dataflow_plan_analyzer import DataflowPlanAnalyzer

# Module-level logger named after this module; not referenced elsewhere in the visible part of this file.
logger = logging.getLogger(__name__)


def test_shared_metric_query(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    column_association_resolver: ColumnAssociationResolver,
    dataflow_plan_builder: DataflowPlanBuilder,
    query_parser: MetricFlowQueryParser,
) -> None:
    """Check common branches in a query that uses derived metrics defined from metrics that are also in the query.

    Snapshots the structure text of the full dataflow plan together with the structure text of each common
    branch reported by `DataflowPlanAnalyzer.find_common_branches`.
    """
    query_spec = query_parser.parse_and_validate_query(
        metric_names=("bookings", "bookings_per_booker"),
        group_by_names=("metric_time",),
    ).query_spec
    dataflow_plan = dataflow_plan_builder.build_plan(query_spec)

    # The full plan goes first so that it leads the snapshot; the common branches follow in sorted order.
    snapshot_sections = {"dataflow_plan": dataflow_plan.structure_text()}
    snapshot_sections.update(
        (f"common_branch_{index}", branch_node.structure_text())
        for index, branch_node in enumerate(sorted(DataflowPlanAnalyzer.find_common_branches(dataflow_plan)))
    )

    snapshot_text = mf_pformat_dict(
        obj_dict=snapshot_sections,
        preserve_raw_strings=True,
    )
    assert_str_snapshot_equal(
        request=request,
        mf_test_configuration=mf_test_configuration,
        snapshot_id="result",
        snapshot_str=snapshot_text,
    )

0 comments on commit daa0ed6

Please sign in to comment.