Skip to content

Commit

Permalink
Update SourceScanOptimizer to memoize results.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Dec 11, 2024
1 parent aa6ec15 commit e233c28
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 25 deletions.
4 changes: 4 additions & 0 deletions metricflow/dataflow/dataflow_plan_analyzer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

import logging
from collections import defaultdict
from typing import Dict, FrozenSet, Mapping, Sequence, Set

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from typing_extensions import override

from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler

logger = logging.getLogger(__name__)


class DataflowPlanAnalyzer:
"""Class to determine more complex properties of the dataflow plan.
Expand Down
15 changes: 11 additions & 4 deletions metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(self, left_branch_node: DataflowPlanNode) -> None: # noqa: D107
self._current_left_node: DataflowPlanNode = left_branch_node

def _log_visit_node_type(self, node: DataflowPlanNode) -> None:
logger.debug(LazyFormat(lambda: f"Visiting {node}"))
logger.debug(lambda: f"Visiting {node.node_id}")

def _log_combine_failure(
self,
Expand All @@ -142,8 +142,10 @@ def _log_combine_failure(
) -> None:
logger.debug(
LazyFormat(
lambda: f"Because {combine_failure_reason}, unable to combine nodes "
f"left_node={left_node} right_node={right_node}",
"Unable to combine nodes",
combine_failure_reason=combine_failure_reason,
left_node=left_node.node_id,
right_node=right_node.node_id,
)
)

Expand All @@ -154,7 +156,12 @@ def _log_combine_success(
combined_node: DataflowPlanNode,
) -> None:
logger.debug(
LazyFormat(lambda: f"Combined left_node={left_node} right_node={right_node} combined_node: {combined_node}")
LazyFormat(
"Successfully combined nodes",
left_node=left_node.node_id,
right_node=right_node.node_id,
combined_node=combined_node.node_id,
)
)

def _combine_parent_branches(self, current_right_node: DataflowPlanNode) -> Optional[Sequence[DataflowPlanNode]]:
Expand Down
71 changes: 50 additions & 21 deletions metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from dataclasses import dataclass
from typing import List, Optional, Sequence
from typing import Dict, List, Optional, Sequence

from metricflow_semantics.dag.id_prefix import StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId
Expand Down Expand Up @@ -110,20 +110,34 @@ class SourceScanOptimizer(
parents.
"""

def __init__(self) -> None:
    """Initialize the optimizer with an empty memoization cache.

    The cache maps each visited `DataflowPlanNode` to the
    `OptimizeBranchResult` produced for it, so branches that appear
    under multiple children of the DAG are only optimized once.
    """
    self._node_to_result: Dict[DataflowPlanNode, OptimizeBranchResult] = dict()

def _log_visit_node_type(self, node: DataflowPlanNode) -> None:
logger.debug(LazyFormat(lambda: f"Visiting {node}"))
logger.debug(LazyFormat(lambda: f"Visiting {node.node_id}"))

def _default_base_output_handler(
self,
node: DataflowPlanNode,
) -> OptimizeBranchResult:
memoized_result = self._node_to_result.get(node)
if memoized_result is not None:
return memoized_result

optimized_parents: Sequence[OptimizeBranchResult] = tuple(
parent_node.accept(self) for parent_node in node.parent_nodes
)
# Parents should always be DataflowPlanNode
return OptimizeBranchResult(
optimized_branch=node.with_new_parents(tuple(x.optimized_branch for x in optimized_parents))
)

# If no optimization is done, use the same nodes so that common operations can be identified for CTE generation.
if tuple(node.parent_nodes) == optimized_parents:
result = OptimizeBranchResult(optimized_branch=node)
else:
result = OptimizeBranchResult(
optimized_branch=node.with_new_parents(tuple(x.optimized_branch for x in optimized_parents))
)

self._node_to_result[node] = result
return result

def visit_source_node(self, node: ReadSqlSourceNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
Expand All @@ -144,18 +158,26 @@ def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> Opti
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
# Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG.

memoized_result = self._node_to_result.get(node)
if memoized_result is not None:
return memoized_result

optimized_parent_result: OptimizeBranchResult = node.parent_node.accept(self)
if optimized_parent_result.optimized_branch is not None:
return OptimizeBranchResult(
result = OptimizeBranchResult(
optimized_branch=ComputeMetricsNode.create(
parent_node=optimized_parent_result.optimized_branch,
metric_specs=node.metric_specs,
for_group_by_source_node=node.for_group_by_source_node,
aggregated_to_elements=node.aggregated_to_elements,
)
)
else:
result = OptimizeBranchResult(optimized_branch=node)

return OptimizeBranchResult(optimized_branch=node)
self._node_to_result[node] = result
return result

def visit_order_by_limit_node(self, node: OrderByLimitNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
Expand Down Expand Up @@ -220,11 +242,16 @@ def visit_combine_aggregated_outputs_node( # noqa: D102
self, node: CombineAggregatedOutputsNode
) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
# The parent node of the CombineAggregatedOutputsNode can be either ComputeMetricsNodes or CombineAggregatedOutputsNodes

memoized_result = self._node_to_result.get(node)
if memoized_result is not None:
return memoized_result

# The parent node of the CombineAggregatedOutputsNode can be either ComputeMetricsNodes or
# CombineAggregatedOutputsNodes.
# Stores the result of running this optimizer on each parent branch separately.
optimized_parent_branches = []
logger.debug(LazyFormat(lambda: f"{node} has {len(node.parent_nodes)} parent branches"))
logger.debug(LazyFormat(lambda: f"{node.node_id} has {len(node.parent_nodes)} parent branches"))

# Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG.
for parent_branch in node.parent_nodes:
Expand Down Expand Up @@ -257,14 +284,17 @@ def visit_combine_aggregated_outputs_node( # noqa: D102
logger.debug(lambda: f"Got {len(combined_parent_branches)} branches after combination")
assert len(combined_parent_branches) > 0

# If we were able to reduce the parent branches of the CombineAggregatedOutputsNode into a single one, there's no need
# for a CombineAggregatedOutputsNode.
# If we were able to reduce the parent branches of the CombineAggregatedOutputsNode into a single one, there's
# no need for a CombineAggregatedOutputsNode.
if len(combined_parent_branches) == 1:
return OptimizeBranchResult(optimized_branch=combined_parent_branches[0])
result = OptimizeBranchResult(optimized_branch=combined_parent_branches[0])
else:
result = OptimizeBranchResult(
optimized_branch=CombineAggregatedOutputsNode.create(parent_nodes=combined_parent_branches)
)

return OptimizeBranchResult(
optimized_branch=CombineAggregatedOutputsNode.create(parent_nodes=combined_parent_branches)
)
self._node_to_result[node] = result
return result

def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
Expand All @@ -289,11 +319,10 @@ def optimize(self, dataflow_plan: DataflowPlan) -> DataflowPlan: # noqa: D102

logger.debug(
LazyFormat(
lambda: f"Optimized:\n\n"
f"{dataflow_plan.sink_node.structure_text()}\n\n"
f"to:\n\n"
f"{optimized_result.optimized_branch.structure_text()}",
),
"Optimized dataflow plan",
original_plan=dataflow_plan.sink_node.structure_text(),
optimized_plan=optimized_result.optimized_branch.structure_text(),
)
)

return DataflowPlan(
Expand Down

0 comments on commit e233c28

Please sign in to comment.