diff --git a/metricflow-semantics/metricflow_semantics/visitor.py b/metricflow-semantics/metricflow_semantics/visitor.py
index bdc220f09..6f795c784 100644
--- a/metricflow-semantics/metricflow_semantics/visitor.py
+++ b/metricflow-semantics/metricflow_semantics/visitor.py
@@ -3,7 +3,7 @@
 from abc import ABC
 from typing import TypeVar
 
-VisitorOutputT = TypeVar("VisitorOutputT")
+VisitorOutputT = TypeVar("VisitorOutputT", covariant=True)
 
 
 class Visitable(ABC):
diff --git a/metricflow/dataflow/dataflow_plan_analyzer.py b/metricflow/dataflow/dataflow_plan_analyzer.py
new file mode 100644
index 000000000..be01385ec
--- /dev/null
+++ b/metricflow/dataflow/dataflow_plan_analyzer.py
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from typing import Dict, FrozenSet, Mapping, Sequence, Set
+
+from typing_extensions import override
+
+from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
+from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler
+
+
+class DataflowPlanAnalyzer:
+    """Class to determine more complex properties of the dataflow plan.
+
+    These could also be made as member methods of the dataflow plan, but this requires resolving some circular
+    dependency issues to break out the functionality into separate files.
+    """
+
+    @staticmethod
+    def find_common_branches(dataflow_plan: DataflowPlan) -> Sequence[DataflowPlanNode]:
+        """Starting from the sink node, find the common branches that exist in the associated DAG.
+
+        A node is treated as common when it is visited more than once while walking from the sink node through
+        `parent_nodes`. Only the first common node found along each path from the sink is returned, so common
+        nodes that sit behind another common node are not reported separately.
+
+        Returns a sorted sequence for reproducibility.
+        """
+        counting_visitor = _CountCommonDataflowNodeVisitor()
+        dataflow_plan.sink_node.accept(counting_visitor)
+
+        node_to_common_count = counting_visitor.get_node_counts()
+        common_nodes = frozenset(node for node, count in node_to_common_count.items() if count > 1)
+
+        common_branches_visitor = _FindLargestCommonBranchesVisitor(common_nodes)
+
+        return tuple(sorted(dataflow_plan.sink_node.accept(common_branches_visitor)))
+
+
+class _CountCommonDataflowNodeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]):
+    """Helper visitor to build a dict from a node in the plan to the number of times it is visited.
+
+    The traversal is not memoized, so a node reachable from the starting node through several paths is visited
+    (and counted) once per path.
+    """
+
+    def __init__(self) -> None:
+        self._node_to_count: Dict[DataflowPlanNode, int] = defaultdict(int)
+
+    def get_node_counts(self) -> Mapping[DataflowPlanNode, int]:
+        """Return the visit counts accumulated so far, keyed by node."""
+        return self._node_to_count
+
+    @override
+    def _default_handler(self, node: DataflowPlanNode) -> None:
+        for parent_node in node.parent_nodes:
+            parent_node.accept(self)
+        self._node_to_count[node] += 1
+
+
+class _FindLargestCommonBranchesVisitor(DataflowPlanNodeVisitorWithDefaultHandler[FrozenSet[DataflowPlanNode]]):
+    """Given the nodes that are known to appear in the DAG multiple times, find the common branches.
+
+    To get the largest common branches, (e.g. for `A -> B -> C -> D` and `B -> C -> D`, both `B -> C -> D`
+    and `C -> D` can be considered common branches), this uses preorder traversal and returns the first common node
+    that is seen.
+    """
+
+    def __init__(self, common_nodes: FrozenSet[DataflowPlanNode]) -> None:
+        self._common_nodes = common_nodes
+
+    @override
+    def _default_handler(self, node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode]:
+        # Stop descending at the first common node so that only the largest common branch is reported.
+        if node in self._common_nodes:
+            return frozenset({node})
+
+        common_branch_leaf_nodes: Set[DataflowPlanNode] = set()
+
+        for parent_node in node.parent_nodes:
+            common_branch_leaf_nodes.update(parent_node.accept(self))
+
+        return frozenset(common_branch_leaf_nodes)
diff --git a/tests_metricflow/sql/test_common_dataflow_branches.py b/tests_metricflow/sql/test_common_dataflow_branches.py
new file mode 100644
index 000000000..d15c59f31
--- /dev/null
+++ b/tests_metricflow/sql/test_common_dataflow_branches.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+import logging
+
+from _pytest.fixtures import FixtureRequest
+from metricflow_semantics.mf_logging.pretty_print import mf_pformat_dict
+from metricflow_semantics.query.query_parser import MetricFlowQueryParser
+from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
+from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
+from metricflow_semantics.test_helpers.snapshot_helpers import (
+    assert_str_snapshot_equal,
+)
+
+from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
+from metricflow.dataflow.dataflow_plan_analyzer import DataflowPlanAnalyzer
+
+logger = logging.getLogger(__name__)
+
+
+def test_shared_metric_query(
+    request: FixtureRequest,
+    mf_test_configuration: MetricFlowTestConfiguration,
+    column_association_resolver: ColumnAssociationResolver,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    query_parser: MetricFlowQueryParser,
+) -> None:
+    """Check common branches in a query that uses derived metrics defined from metrics that are also in the query."""
+    parse_result = query_parser.parse_and_validate_query(
+        metric_names=("bookings", "bookings_per_booker"),
+        group_by_names=("metric_time",),
+    )
+    dataflow_plan = dataflow_plan_builder.build_plan(parse_result.query_spec)
+
+    obj_dict = {
+        "dataflow_plan": dataflow_plan.structure_text(),
+    }
+
+    common_branch_leaf_nodes = DataflowPlanAnalyzer.find_common_branches(dataflow_plan)
+    # `find_common_branches` already returns a sorted sequence, so the nodes are enumerated as-is.
+    for i, common_branch_leaf_node in enumerate(common_branch_leaf_nodes):
+        obj_dict[f"common_branch_{i}"] = common_branch_leaf_node.structure_text()
+
+    assert_str_snapshot_equal(
+        request=request,
+        mf_test_configuration=mf_test_configuration,
+        snapshot_id="result",
+        snapshot_str=mf_pformat_dict(
+            obj_dict=obj_dict,
+            preserve_raw_strings=True,
+        ),
+    )