-
Notifications
You must be signed in to change notification settings - Fork 98
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a method to figure out common nodes in a dataflow plan (#1520)
This PR adds a method to figure out nodes in a dataflow plan that appear more than once. i.e. a node that is the parent of multiple nodes. These common nodes indicate operations where a computation is reused, e.g. a metric that is used in the computation of multiple derived metrics in a query. These nodes will be later used to generate CTEs.
- Loading branch information
Showing
9 changed files
with
457 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from __future__ import annotations | ||
|
||
from collections import defaultdict | ||
from typing import Dict, FrozenSet, Mapping, Sequence, Set | ||
|
||
from typing_extensions import override | ||
|
||
from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode | ||
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler | ||
|
||
|
||
class DataflowPlanAnalyzer: | ||
"""Class to determine more complex properties of the dataflow plan. | ||
These could also be made as member methods of the dataflow plan, but this requires resolving some circular | ||
dependency issues to break out the functionality into separate files. | ||
""" | ||
|
||
@staticmethod | ||
def find_common_branches(dataflow_plan: DataflowPlan) -> Sequence[DataflowPlanNode]: | ||
"""Starting from the sink node, find the common branches that exist in the associated DAG. | ||
Returns a sorted sequence for reproducibility. | ||
""" | ||
counting_visitor = _CountDataflowNodeVisitor() | ||
dataflow_plan.sink_node.accept(counting_visitor) | ||
|
||
node_to_common_count = counting_visitor.get_node_counts() | ||
|
||
common_nodes = [] | ||
for node, count in node_to_common_count.items(): | ||
if count > 1: | ||
common_nodes.append(node) | ||
|
||
common_branches_visitor = _FindLargestCommonBranchesVisitor(frozenset(common_nodes)) | ||
|
||
return tuple(sorted(dataflow_plan.sink_node.accept(common_branches_visitor))) | ||
|
||
|
||
class _CountDataflowNodeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]): | ||
"""Helper visitor to build a dict from a node in the plan to the number of times it appears in the plan.""" | ||
|
||
def __init__(self) -> None: | ||
self._node_to_count: Dict[DataflowPlanNode, int] = defaultdict(int) | ||
|
||
def get_node_counts(self) -> Mapping[DataflowPlanNode, int]: | ||
return self._node_to_count | ||
|
||
@override | ||
def _default_handler(self, node: DataflowPlanNode) -> None: | ||
for parent_node in node.parent_nodes: | ||
parent_node.accept(self) | ||
self._node_to_count[node] += 1 | ||
|
||
|
||
class _FindLargestCommonBranchesVisitor(DataflowPlanNodeVisitorWithDefaultHandler[FrozenSet[DataflowPlanNode]]): | ||
"""Given the nodes that are known to appear in the DAG multiple times, find the common branches. | ||
To get the largest common branches, (e.g. for `A -> B -> C -> D` and `B -> C -> D`, both `B -> C -> D` | ||
and `C -> D` can be considered common branches, and we want the largest one), this uses preorder traversal and | ||
returns the first common node that is seen. | ||
""" | ||
|
||
def __init__(self, common_nodes: FrozenSet[DataflowPlanNode]) -> None: | ||
self._common_nodes = common_nodes | ||
|
||
@override | ||
def _default_handler(self, node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode]: | ||
# Traversal starts from the leaf node and then goes to the parent branches. By doing this check first, we don't | ||
# return smaller common branches that are a part of a larger common branch. | ||
if node in self._common_nodes: | ||
return frozenset({node}) | ||
|
||
common_branch_leaf_nodes: Set[DataflowPlanNode] = set() | ||
|
||
for parent_node in node.parent_nodes: | ||
common_branch_leaf_nodes.update(parent_node.accept(self)) | ||
|
||
return frozenset(common_branch_leaf_nodes) |
Oops, something went wrong.