From 5c314d48213a8f8ebd8c52fec2ae2ebea1e3e769 Mon Sep 17 00:00:00 2001
From: Paul Yang
Date: Thu, 9 May 2024 00:15:42 -0700
Subject: [PATCH] Simplify `OptimizeBranchResult`.

---
 .../source_scan/source_scan_optimizer.py | 51 +++----------------
 1 file changed, 7 insertions(+), 44 deletions(-)

diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
index 81673ae9fa..d669a7c8dc 100644
--- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
+++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
@@ -41,18 +41,7 @@
 
 @dataclass(frozen=True)
 class OptimizeBranchResult:  # noqa: D101
-    optimized_branch: Optional[DataflowPlanNode] = None
-    sink_node: Optional[DataflowPlanNode] = None
-
-    @property
-    def checked_base_output(self) -> DataflowPlanNode:  # noqa: D102
-        assert self.optimized_branch, f"Expected the result of traversal to produce a {DataflowPlanNode}"
-        return self.optimized_branch
-
-    @property
-    def checked_sink_node(self) -> DataflowPlanNode:  # noqa: D102
-        assert self.sink_node, f"Expected the result of traversal to produce a {DataflowPlanNode}"
-        return self.sink_node
+    optimized_branch: DataflowPlanNode
 
 
 @dataclass(frozen=True)
@@ -133,19 +122,7 @@ def _default_base_output_handler(
         )
         # Parents should always be DataflowPlanNode
         return OptimizeBranchResult(
-            optimized_branch=node.with_new_parents(tuple(x.checked_base_output for x in optimized_parents))
-        )
-
-    def _default_sink_node_handler(
-        self,
-        node: DataflowPlanNode,
-    ) -> OptimizeBranchResult:
-        optimized_parents: Sequence[OptimizeBranchResult] = tuple(
-            parent_node.accept(self) for parent_node in node.parent_nodes
-        )
-        # Parents should always be DataflowPlanNode
-        return OptimizeBranchResult(
-            sink_node=node.with_new_parents(tuple(x.checked_base_output for x in optimized_parents))
+            optimized_branch=node.with_new_parents(tuple(x.optimized_branch for x in optimized_parents))
         )
 
     def visit_source_node(self, node: ReadSqlSourceNode) -> OptimizeBranchResult:  # noqa: D102
@@ -188,11 +165,11 @@ def visit_write_to_result_dataframe_node(  # noqa: D102
         self, node: WriteToResultDataframeNode
     ) -> OptimizeBranchResult:  # noqa: D102
         self._log_visit_node_type(node)
-        return self._default_sink_node_handler(node)
+        return self._default_base_output_handler(node)
 
     def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> OptimizeBranchResult:  # noqa: D102
         self._log_visit_node_type(node)
-        return self._default_sink_node_handler(node)
+        return self._default_base_output_handler(node)
 
     def visit_filter_elements_node(self, node: FilterElementsNode) -> OptimizeBranchResult:  # noqa: D102
         self._log_visit_node_type(node)
@@ -248,15 +225,6 @@ def visit_combine_aggregated_outputs_node(  # noqa: D102
         # Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG.
         for parent_branch in node.parent_nodes:
             result: OptimizeBranchResult = parent_branch.accept(self)
-
-            assert result.sink_node is None, (
-                f"Traversing the parents of of {node.__class__.__name__} should not have produced any "
-                f"{DataflowPlanNode.__class__.__name__} nodes"
-            )
-
-            assert (
-                result.optimized_branch is not None
-            ), f"Traversing the parents of a CombineAggregatedOutputsNode should always produce a DataflowPlanNode. Got: {result}"
             optimized_parent_branches.append(result.optimized_branch)
 
         # Try to combine (using ComputeMetricsBranchCombiner) as many parent branches as possible in a
@@ -320,17 +288,12 @@ def optimize(self, dataflow_plan: DataflowPlan) -> DataflowPlan:  # noqa: D102
             msg=f"Optimized:\n\n"
             f"{dataflow_plan.checked_sink_node.structure_text()}\n\n"
            f"to:\n\n"
-            f"{optimized_result.checked_sink_node.structure_text()}",
+            f"{optimized_result.optimized_branch.structure_text()}",
         )
 
-        if optimized_result.sink_node:
-            return DataflowPlan(
-                plan_id=DagId.from_id_prefix(StaticIdPrefix.OPTIMIZED_DATAFLOW_PLAN_PREFIX),
-                sink_nodes=[optimized_result.sink_node],
-            )
-        logger.log(level=self._log_level, msg="Optimizer didn't produce a result, so returning the same plan")
         return DataflowPlan(
-            sink_nodes=[dataflow_plan.checked_sink_node],
+            plan_id=DagId.from_id_prefix(StaticIdPrefix.OPTIMIZED_DATAFLOW_PLAN_PREFIX),
+            sink_nodes=[optimized_result.optimized_branch],
         )
 
     def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> OptimizeBranchResult:  # noqa: D102