diff --git a/metricflow/dataflow/dataflow_plan_analyzer.py b/metricflow/dataflow/dataflow_plan_analyzer.py
index 062327784..71a9ea584 100644
--- a/metricflow/dataflow/dataflow_plan_analyzer.py
+++ b/metricflow/dataflow/dataflow_plan_analyzer.py
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
+import logging
 from collections import defaultdict
 from typing import Dict, FrozenSet, Mapping, Sequence, Set
 
+from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
 from typing_extensions import override
 
 from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
 from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler
 
+logger = logging.getLogger(__name__)
+
 
 class DataflowPlanAnalyzer:
     """Class to determine more complex properties of the dataflow plan.
diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py
index e87387385..cad82f215 100644
--- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py
+++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py
@@ -132,7 +132,7 @@ def __init__(self, left_branch_node: DataflowPlanNode) -> None:  # noqa: D107
         self._current_left_node: DataflowPlanNode = left_branch_node
 
     def _log_visit_node_type(self, node: DataflowPlanNode) -> None:
-        logger.debug(LazyFormat(lambda: f"Visiting {node}"))
+        logger.debug(LazyFormat(lambda: f"Visiting {node.node_id}"))
 
     def _log_combine_failure(
         self,
@@ -142,8 +142,10 @@ def _log_combine_failure(
     ) -> None:
         logger.debug(
             LazyFormat(
-                lambda: f"Because {combine_failure_reason}, unable to combine nodes "
-                f"left_node={left_node} right_node={right_node}",
+                "Unable to combine nodes",
+                combine_failure_reason=combine_failure_reason,
+                left_node=left_node.node_id,
+                right_node=right_node.node_id,
             )
         )
 
@@ -154,7 +156,12 @@ def _log_combine_success(
         combined_node: DataflowPlanNode,
     ) -> None:
         logger.debug(
-            LazyFormat(lambda: f"Combined left_node={left_node} right_node={right_node} combined_node: {combined_node}")
+            LazyFormat(
+                "Successfully combined nodes",
+                left_node=left_node.node_id,
+                right_node=right_node.node_id,
+                combined_node=combined_node.node_id,
+            )
         )
 
     def _combine_parent_branches(self, current_right_node: DataflowPlanNode) -> Optional[Sequence[DataflowPlanNode]]:
diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
index dd1fe0446..5d9121ed6 100644
--- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
+++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
@@ -2,7 +2,7 @@
 
 import logging
 from dataclasses import dataclass
-from typing import List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence
 
 from metricflow_semantics.dag.id_prefix import StaticIdPrefix
 from metricflow_semantics.dag.mf_dag import DagId
@@ -110,20 +110,34 @@ class SourceScanOptimizer(
     parents.
     """
 
+    def __init__(self) -> None:  # noqa: D107
+        self._node_to_result: Dict[DataflowPlanNode, OptimizeBranchResult] = {}
+
     def _log_visit_node_type(self, node: DataflowPlanNode) -> None:
-        logger.debug(LazyFormat(lambda: f"Visiting {node}"))
+        logger.debug(LazyFormat(lambda: f"Visiting {node.node_id}"))
 
     def _default_base_output_handler(
         self,
         node: DataflowPlanNode,
     ) -> OptimizeBranchResult:
+        memoized_result = self._node_to_result.get(node)
+        if memoized_result is not None:
+            return memoized_result
+
         optimized_parents: Sequence[OptimizeBranchResult] = tuple(
             parent_node.accept(self) for parent_node in node.parent_nodes
         )
-        # Parents should always be DataflowPlanNode
-        return OptimizeBranchResult(
-            optimized_branch=node.with_new_parents(tuple(x.optimized_branch for x in optimized_parents))
-        )
+
+        optimized_parent_nodes = tuple(x.optimized_branch for x in optimized_parents)
+
+        # If no optimization is done, use the same nodes so that common operations can be identified for CTE generation.
+        if tuple(node.parent_nodes) == optimized_parent_nodes:
+            result = OptimizeBranchResult(optimized_branch=node)
+        else:
+            result = OptimizeBranchResult(optimized_branch=node.with_new_parents(optimized_parent_nodes))
+
+        self._node_to_result[node] = result
+        return result
 
     def visit_source_node(self, node: ReadSqlSourceNode) -> OptimizeBranchResult:  # noqa: D102
         self._log_visit_node_type(node)
@@ -144,9 +158,14 @@ def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> Opti
     def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranchResult:  # noqa: D102
         self._log_visit_node_type(node)
         # Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG.
+
+        memoized_result = self._node_to_result.get(node)
+        if memoized_result is not None:
+            return memoized_result
+
         optimized_parent_result: OptimizeBranchResult = node.parent_node.accept(self)
         if optimized_parent_result.optimized_branch is not None:
-            return OptimizeBranchResult(
+            result = OptimizeBranchResult(
                 optimized_branch=ComputeMetricsNode.create(
                     parent_node=optimized_parent_result.optimized_branch,
                     metric_specs=node.metric_specs,
@@ -154,8 +173,11 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranch
                     aggregated_to_elements=node.aggregated_to_elements,
                 )
             )
+        else:
+            result = OptimizeBranchResult(optimized_branch=node)
 
-        return OptimizeBranchResult(optimized_branch=node)
+        self._node_to_result[node] = result
+        return result
 
     def visit_order_by_limit_node(self, node: OrderByLimitNode) -> OptimizeBranchResult:  # noqa: D102
         self._log_visit_node_type(node)
@@ -220,11 +242,16 @@ def visit_combine_aggregated_outputs_node(  # noqa: D102
         self, node: CombineAggregatedOutputsNode
     ) -> OptimizeBranchResult:  # noqa: D102
         self._log_visit_node_type(node)
-        # The parent node of the CombineAggregatedOutputsNode can be either ComputeMetricsNodes or CombineAggregatedOutputsNodes
+        memoized_result = self._node_to_result.get(node)
+        if memoized_result is not None:
+            return memoized_result
+
+        # The parent node of the CombineAggregatedOutputsNode can be either ComputeMetricsNodes or
+        # CombineAggregatedOutputsNodes.
 
         # Stores the result of running this optimizer on each parent branch separately.
         optimized_parent_branches = []
-        logger.debug(LazyFormat(lambda: f"{node} has {len(node.parent_nodes)} parent branches"))
+        logger.debug(LazyFormat(lambda: f"{node.node_id} has {len(node.parent_nodes)} parent branches"))
 
         # Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG.
         for parent_branch in node.parent_nodes:
@@ -257,14 +284,17 @@ def visit_combine_aggregated_outputs_node(  # noqa: D102
-        logger.debug(lambda: f"Got {len(combined_parent_branches)} branches after combination")
+        logger.debug(LazyFormat(lambda: f"Got {len(combined_parent_branches)} branches after combination"))
         assert len(combined_parent_branches) > 0
 
-        # If we were able to reduce the parent branches of the CombineAggregatedOutputsNode into a single one, there's no need
-        # for a CombineAggregatedOutputsNode.
+        # If we were able to reduce the parent branches of the CombineAggregatedOutputsNode into a single one, there's
+        # no need for a CombineAggregatedOutputsNode.
         if len(combined_parent_branches) == 1:
-            return OptimizeBranchResult(optimized_branch=combined_parent_branches[0])
+            result = OptimizeBranchResult(optimized_branch=combined_parent_branches[0])
+        else:
+            result = OptimizeBranchResult(
+                optimized_branch=CombineAggregatedOutputsNode.create(parent_nodes=combined_parent_branches)
+            )
 
-        return OptimizeBranchResult(
-            optimized_branch=CombineAggregatedOutputsNode.create(parent_nodes=combined_parent_branches)
-        )
+        self._node_to_result[node] = result
+        return result
 
     def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> OptimizeBranchResult:  # noqa: D102
         self._log_visit_node_type(node)
@@ -289,11 +319,10 @@ def optimize(self, dataflow_plan: DataflowPlan) -> DataflowPlan:  # noqa: D102
 
         logger.debug(
             LazyFormat(
-                lambda: f"Optimized:\n\n"
-                f"{dataflow_plan.sink_node.structure_text()}\n\n"
-                f"to:\n\n"
-                f"{optimized_result.optimized_branch.structure_text()}",
-            ),
+                "Optimized dataflow plan",
+                original_plan=dataflow_plan.sink_node.structure_text(),
+                optimized_plan=optimized_result.optimized_branch.structure_text(),
+            )
         )
 
         return DataflowPlan(