-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add method to figure out the common branches in a dataflow plan.
- Loading branch information
Showing
3 changed files
with
132 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from __future__ import annotations | ||
|
||
from collections import defaultdict | ||
from typing import Dict, FrozenSet, Mapping, Sequence, Set | ||
|
||
from typing_extensions import override | ||
|
||
from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode | ||
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler | ||
|
||
|
||
class DataflowPlanAnalyzer: | ||
"""CLass to determine more complex properties of the dataflow plan. | ||
These could also be made as member methods of the dataflow plan, but this requires resolving some circular | ||
dependency issues to break out the functionality into separate files. | ||
""" | ||
|
||
@staticmethod | ||
def find_common_branches(dataflow_plan: DataflowPlan) -> Sequence[DataflowPlanNode]: | ||
"""Starting from the sink node, find the common branches that exist in the associated DAG. | ||
Returns a sorted sequence for reproducibility. | ||
""" | ||
counting_visitor = _CountCommonDataflowNodeVisitor() | ||
dataflow_plan.sink_node.accept(counting_visitor) | ||
|
||
node_to_common_count = counting_visitor.get_node_counts() | ||
|
||
common_nodes = [] | ||
for node, count in node_to_common_count.items(): | ||
if count > 1: | ||
common_nodes.append(node) | ||
|
||
common_branches_visitor = _FindLargestCommonBranchesVisitor(frozenset(common_nodes)) | ||
|
||
return tuple(sorted(dataflow_plan.sink_node.accept(common_branches_visitor))) | ||
|
||
|
||
class _CountCommonDataflowNodeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]): | ||
"""Helper visitor to build a dict from a node in the plan to the number of times it appears in the plans.""" | ||
|
||
def __init__(self) -> None: | ||
self._node_to_count: Dict[DataflowPlanNode, int] = defaultdict(int) | ||
|
||
def get_node_counts(self) -> Mapping[DataflowPlanNode, int]: | ||
return self._node_to_count | ||
|
||
@override | ||
def _default_handler(self, node: DataflowPlanNode) -> None: | ||
for parent_node in node.parent_nodes: | ||
parent_node.accept(self) | ||
self._node_to_count[node] += 1 | ||
|
||
|
||
class _FindLargestCommonBranchesVisitor(DataflowPlanNodeVisitorWithDefaultHandler[FrozenSet[DataflowPlanNode]]): | ||
"""Given the nodes that are known to appear in the DAG multiple times, find the common branches. | ||
To get the largest common branches, (e.g. for `A -> B -> C -> D` and `B -> C -> D`, both `B -> C -> D` | ||
and `C -> D` can be considered common branches), this uses preorder traversal and returns the first common node | ||
that is seen. | ||
""" | ||
|
||
def __init__(self, common_nodes: FrozenSet[DataflowPlanNode]) -> None: | ||
self._common_nodes = common_nodes | ||
|
||
@override | ||
def _default_handler(self, node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode]: | ||
if node in self._common_nodes: | ||
return frozenset({node}) | ||
|
||
common_branch_leaf_nodes: Set[DataflowPlanNode] = set() | ||
|
||
for parent_node in node.parent_nodes: | ||
common_branch_leaf_nodes.update(parent_node.accept(self)) | ||
|
||
return frozenset(common_branch_leaf_nodes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
|
||
from _pytest.fixtures import FixtureRequest | ||
from metricflow_semantics.mf_logging.pretty_print import mf_pformat_dict | ||
from metricflow_semantics.query.query_parser import MetricFlowQueryParser | ||
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver | ||
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration | ||
from metricflow_semantics.test_helpers.snapshot_helpers import ( | ||
assert_str_snapshot_equal, | ||
) | ||
|
||
from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder | ||
from metricflow.dataflow.dataflow_plan_analyzer import DataflowPlanAnalyzer | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def test_shared_metric_query( | ||
request: FixtureRequest, | ||
mf_test_configuration: MetricFlowTestConfiguration, | ||
column_association_resolver: ColumnAssociationResolver, | ||
dataflow_plan_builder: DataflowPlanBuilder, | ||
query_parser: MetricFlowQueryParser, | ||
) -> None: | ||
"""For a known case, test that a metric computation node is identified as a common branch. | ||
A query for `bookings` and `bookings_per_booker` should have the computation for `bookings` as a common branch in | ||
the dataflow plan. | ||
""" | ||
parse_result = query_parser.parse_and_validate_query( | ||
metric_names=("bookings", "bookings_per_booker"), | ||
group_by_names=("metric_time",), | ||
) | ||
dataflow_plan = dataflow_plan_builder.build_plan(parse_result.query_spec) | ||
|
||
obj_dict = { | ||
"dataflow_plan": dataflow_plan.structure_text(), | ||
} | ||
|
||
common_branch_leaf_nodes = DataflowPlanAnalyzer.find_common_branches(dataflow_plan) | ||
for i, common_branch_leaf_node in enumerate(sorted(common_branch_leaf_nodes)): | ||
obj_dict[f"common_branch_{i}"] = common_branch_leaf_node.structure_text() | ||
|
||
assert_str_snapshot_equal( | ||
request=request, | ||
mf_test_configuration=mf_test_configuration, | ||
snapshot_id="result", | ||
snapshot_str=mf_pformat_dict( | ||
obj_dict=obj_dict, | ||
preserve_raw_strings=True, | ||
), | ||
) |