Skip to content

Commit

Permalink
/* PR_START p--cte 19 */ Add method to group nodes by type.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 13, 2024
1 parent 11ec9f8 commit 23ba0da
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions metricflow/dataflow/dataflow_plan_analyzer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, FrozenSet, Mapping, Sequence, Set

from metricflow_semantics.collection_helpers.merger import Mergeable
from typing_extensions import override

from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode


class DataflowPlanAnalyzer:
Expand Down Expand Up @@ -36,6 +39,12 @@ def find_common_branches(dataflow_plan: DataflowPlan) -> Sequence[DataflowPlanNo

return tuple(sorted(dataflow_plan.sink_node.accept(common_branches_visitor)))

@staticmethod
def group_nodes_by_type(dataflow_plan: DataflowPlan) -> DataflowPlanNodeSet:
"""Grouops dataflow plan nodes by type."""
grouping_visitor = _GroupNodesByTypeVisitor()
return dataflow_plan.sink_node.accept(grouping_visitor)


class _CountDataflowNodeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[None]):
"""Helper visitor to build a dict from a node in the plan to the number of times it appears in the plan."""
Expand Down Expand Up @@ -77,3 +86,41 @@ def _default_handler(self, node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode
common_branch_leaf_nodes.update(parent_node.accept(self))

return frozenset(common_branch_leaf_nodes)


@dataclass(frozen=True)
class DataflowPlanNodeSet(Mergeable):
"""Contains a set of dataflow plan nodes with fields for different types.
`ComputeMetricsNode` is the only node of interest for current use cases, but fields for other types can be added
later.
"""

compute_metric_nodes: FrozenSet[ComputeMetricsNode]

def merge(self, other: DataflowPlanNodeSet) -> DataflowPlanNodeSet:
return DataflowPlanNodeSet(
compute_metric_nodes=self.compute_metric_nodes.union(other.compute_metric_nodes),
)

@classmethod
def empty_instance(cls) -> DataflowPlanNodeSet:
return DataflowPlanNodeSet(
compute_metric_nodes=frozenset(),
)


class _GroupNodesByTypeVisitor(DataflowPlanNodeVisitorWithDefaultHandler[DataflowPlanNodeSet]):
"""Groups dataflow nodes by type."""

@override
def _default_handler(self, node: DataflowPlanNode) -> DataflowPlanNodeSet:
node_sets = []
for parent_node in node.parent_nodes:
node_sets.append(parent_node.accept(self))

return DataflowPlanNodeSet.merge_iterable(node_sets)

@override
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> DataflowPlanNodeSet:
return self._default_handler(node).merge(DataflowPlanNodeSet(frozenset({node})))

0 comments on commit 23ba0da

Please sign in to comment.