Skip to content

Commit

Permalink
Simplify existing cases with DataflowPlan walker.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed May 15, 2024
1 parent 4d89bc2 commit ac4c0d3
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 439 deletions.
167 changes: 33 additions & 134 deletions metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,15 @@
from typing import List, Optional, Sequence

from metricflow_semantics.specs.spec_classes import MetricSpec
from typing_extensions import override

from metricflow.dataflow.dataflow_plan import (
DataflowPlanNode,
DataflowPlanNodeVisitor,
)
from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.dfs_walker import DataflowDagWalker
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.write_to_dataframe import WriteToResultDataframeNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataflow.optimizer.source_scan.matching_linkable_specs import MatchingLinkableSpecsTransform

logger = logging.getLogger(__name__)
Expand All @@ -50,7 +36,7 @@ def checked_combined_branch(self) -> DataflowPlanNode: # noqa: D102
return self.combined_branch


class ComputeMetricsBranchCombiner(DataflowPlanNodeVisitor[ComputeMetricsBranchCombinerResult]):
class ComputeMetricsBranchCombiner(DataflowDagWalker[ComputeMetricsBranchCombinerResult]):
"""Combines branches where the leaf node is a ComputeMetricsNode.
This considers two branches, a left branch and a right branch. The left branch is supplied via the argument in the
Expand Down Expand Up @@ -127,10 +113,7 @@ class ComputeMetricsBranchCombiner(DataflowPlanNodeVisitor[ComputeMetricsBranchC

def __init__(self, left_branch_node: DataflowPlanNode) -> None: # noqa: D107
self._current_left_node: DataflowPlanNode = left_branch_node
self._log_level = logging.DEBUG

def _log_visit_node_type(self, node: DataflowPlanNode) -> None:
logger.log(level=self._log_level, msg=f"Visiting {node}")
super().__init__(visit_log_level=logging.DEBUG, default_action_recursion=False)

def _log_combine_failure(
self,
Expand All @@ -139,7 +122,7 @@ def _log_combine_failure(
combine_failure_reason: str,
) -> None:
logger.log(
level=self._log_level,
level=self._visit_log_level,
msg=f"Because {combine_failure_reason}, unable to combine nodes "
f"left_node={left_node} right_node={right_node}",
)
Expand All @@ -151,22 +134,22 @@ def _log_combine_success(
combined_node: DataflowPlanNode,
) -> None:
logger.log(
level=self._log_level,
level=self._visit_log_level,
msg=f"Combined left_node={left_node} right_node={right_node} combined_node: {combined_node}",
)

def _combine_parent_branches(self, current_right_node: DataflowPlanNode) -> Optional[Sequence[DataflowPlanNode]]:
if len(self._current_left_node.parent_nodes) != len(current_right_node.parent_nodes):
def _combine_parent_branches(self, node: DataflowPlanNode) -> Optional[Sequence[DataflowPlanNode]]:
if len(self._current_left_node.parent_nodes) != len(node.parent_nodes):
self._log_combine_failure(
left_node=self._current_left_node,
right_node=current_right_node,
right_node=node,
combine_failure_reason="parent counts are unequal",
)
return None

results_of_visiting_parent_nodes: List[ComputeMetricsBranchCombinerResult] = []

for i, right_node_parent_node in enumerate(current_right_node.parent_nodes):
for i, right_node_parent_node in enumerate(node.parent_nodes):
left_position_before_recursion = self._current_left_node
self._current_left_node = self._current_left_node.parent_nodes[i]
results_of_visiting_parent_nodes.append(right_node_parent_node.accept(self))
Expand All @@ -177,51 +160,54 @@ def _combine_parent_branches(self, current_right_node: DataflowPlanNode) -> Opti
if result.combined_branch is None:
self._log_combine_failure(
left_node=self._current_left_node,
right_node=current_right_node,
right_node=node,
combine_failure_reason="not all parents could be combined",
)
return None
combined_parents.append(result.combined_branch)

return combined_parents

def _default_handler(self, current_right_node: DataflowPlanNode) -> ComputeMetricsBranchCombinerResult:
combined_parent_nodes = self._combine_parent_branches(current_right_node)
def _default_action(
self, current_node: DataflowPlanNode, inputs: Sequence[ComputeMetricsBranchCombinerResult]
) -> ComputeMetricsBranchCombinerResult:
combined_parent_nodes = self._combine_parent_branches(current_node)
if combined_parent_nodes is None:
return ComputeMetricsBranchCombinerResult()

new_parent_nodes = combined_parent_nodes

# If the parent nodes were combined, and the left node is the same as the right node, then the left and right
# nodes can be combined.
if self._current_left_node.functionally_identical(current_right_node):
combined_node = current_right_node.with_new_parents(new_parent_nodes)
# If the parent nodes were combined, and the left node is the same as the right node, then the left and
# right nodes can be combined.
if self._current_left_node.functionally_identical(current_node):
combined_node = current_node.with_new_parents(new_parent_nodes)
self._log_combine_success(
left_node=self._current_left_node, right_node=current_right_node, combined_node=combined_node
left_node=self._current_left_node, right_node=current_node, combined_node=combined_node
)
return ComputeMetricsBranchCombinerResult(combined_node)

self._log_combine_failure(
left_node=self._current_left_node,
right_node=current_right_node,
right_node=current_node,
combine_failure_reason="there are functional differences",
)
return ComputeMetricsBranchCombinerResult()

def visit_source_node(self, node: ReadSqlSourceNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_join_on_entities_node( # noqa: D102
self, node: JoinOnEntitiesNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)
@override
def default_visit_action(
self, current_node: DataflowPlanNode, inputs: Sequence[ComputeMetricsBranchCombinerResult]
) -> ComputeMetricsBranchCombinerResult:
self.log_visit_start(current_node, inputs)
result = None
try:
result = self._default_action(current_node, inputs)
return result
finally:
self.log_visit_end(current_node, result)

def visit_aggregate_measures_node( # noqa: D102
self, node: AggregateMeasuresNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
) -> ComputeMetricsBranchCombinerResult:
current_right_node = node

combined_parent_nodes = self._combine_parent_branches(current_right_node)
Expand Down Expand Up @@ -271,7 +257,6 @@ def visit_aggregate_measures_node( # noqa: D102

def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
current_right_node = node
self._log_visit_node_type(current_right_node)
combined_parent_nodes = self._combine_parent_branches(current_right_node)
if combined_parent_nodes is None:
return ComputeMetricsBranchCombinerResult()
Expand Down Expand Up @@ -317,43 +302,9 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> ComputeMetrics
)
return ComputeMetricsBranchCombinerResult(combined_node)

def _handle_unsupported_node(self, current_right_node: DataflowPlanNode) -> ComputeMetricsBranchCombinerResult:
self._log_combine_failure(
left_node=self._current_left_node,
right_node=current_right_node,
combine_failure_reason=(
f"right node is of type {current_right_node.__class__.__name__} which is not yet handled"
),
)
return ComputeMetricsBranchCombinerResult()

def visit_order_by_limit_node(self, node: OrderByLimitNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._handle_unsupported_node(node)

def visit_where_constraint_node( # noqa: D102
self, node: WhereConstraintNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_write_to_result_dataframe_node( # noqa: D102
self, node: WriteToResultDataframeNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._handle_unsupported_node(node)

def visit_write_to_result_table_node( # noqa: D102
self, node: WriteToResultTableNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._handle_unsupported_node(node)

def visit_filter_elements_node( # noqa: D102
self, node: FilterElementsNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)

current_right_node = node
results_of_visiting_parent_nodes = self._combine_parent_branches(current_right_node)
if results_of_visiting_parent_nodes is None:
Expand Down Expand Up @@ -394,55 +345,3 @@ def visit_filter_elements_node( # noqa: D102
combined_node=combined_node,
)
return ComputeMetricsBranchCombinerResult(combined_node)

def visit_combine_aggregated_outputs_node( # noqa: D102
self, node: CombineAggregatedOutputsNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._handle_unsupported_node(node)

def visit_constrain_time_range_node( # noqa: D102
self, node: ConstrainTimeRangeNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_join_over_time_range_node( # noqa: D102
self, node: JoinOverTimeRangeNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_semi_additive_join_node( # noqa: D102
self, node: SemiAdditiveJoinNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_metric_time_dimension_transform_node( # noqa: D102
self, node: MetricTimeDimensionTransformNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_join_to_time_spine_node( # noqa: D102
self, node: JoinToTimeSpineNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_add_generated_uuid_column_node( # noqa: D102
self, node: AddGeneratedUuidColumnNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_join_conversion_events_node( # noqa: D102
self, node: JoinConversionEventsNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)
Loading

0 comments on commit ac4c0d3

Please sign in to comment.