Skip to content

Commit

Permalink
Simplify OptimizeBranchResult.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed May 11, 2024
1 parent 9aa3a53 commit 5c314d4
Showing 1 changed file with 7 additions and 44 deletions.
51 changes: 7 additions & 44 deletions metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,7 @@

@dataclass(frozen=True)
class OptimizeBranchResult: # noqa: D101
optimized_branch: Optional[DataflowPlanNode] = None
sink_node: Optional[DataflowPlanNode] = None

@property
def checked_base_output(self) -> DataflowPlanNode: # noqa: D102
assert self.optimized_branch, f"Expected the result of traversal to produce a {DataflowPlanNode}"
return self.optimized_branch

@property
def checked_sink_node(self) -> DataflowPlanNode: # noqa: D102
assert self.sink_node, f"Expected the result of traversal to produce a {DataflowPlanNode}"
return self.sink_node
optimized_branch: DataflowPlanNode


@dataclass(frozen=True)
Expand Down Expand Up @@ -133,19 +122,7 @@ def _default_base_output_handler(
)
# Parents should always be DataflowPlanNode
return OptimizeBranchResult(
optimized_branch=node.with_new_parents(tuple(x.checked_base_output for x in optimized_parents))
)

def _default_sink_node_handler(
self,
node: DataflowPlanNode,
) -> OptimizeBranchResult:
optimized_parents: Sequence[OptimizeBranchResult] = tuple(
parent_node.accept(self) for parent_node in node.parent_nodes
)
# Parents should always be DataflowPlanNode
return OptimizeBranchResult(
sink_node=node.with_new_parents(tuple(x.checked_base_output for x in optimized_parents))
optimized_branch=node.with_new_parents(tuple(x.optimized_branch for x in optimized_parents))
)

def visit_source_node(self, node: ReadSqlSourceNode) -> OptimizeBranchResult: # noqa: D102
Expand Down Expand Up @@ -188,11 +165,11 @@ def visit_write_to_result_dataframe_node( # noqa: D102
self, node: WriteToResultDataframeNode
) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_sink_node_handler(node)
return self._default_base_output_handler(node)

def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_sink_node_handler(node)
return self._default_base_output_handler(node)

def visit_filter_elements_node(self, node: FilterElementsNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
Expand Down Expand Up @@ -248,15 +225,6 @@ def visit_combine_aggregated_outputs_node( # noqa: D102
# Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG.
for parent_branch in node.parent_nodes:
result: OptimizeBranchResult = parent_branch.accept(self)

assert result.sink_node is None, (
f"Traversing the parents of of {node.__class__.__name__} should not have produced any "
f"{DataflowPlanNode.__class__.__name__} nodes"
)

assert (
result.optimized_branch is not None
), f"Traversing the parents of a CombineAggregatedOutputsNode should always produce a DataflowPlanNode. Got: {result}"
optimized_parent_branches.append(result.optimized_branch)

# Try to combine (using ComputeMetricsBranchCombiner) as many parent branches as possible in a
Expand Down Expand Up @@ -320,17 +288,12 @@ def optimize(self, dataflow_plan: DataflowPlan) -> DataflowPlan: # noqa: D102
msg=f"Optimized:\n\n"
f"{dataflow_plan.checked_sink_node.structure_text()}\n\n"
f"to:\n\n"
f"{optimized_result.checked_sink_node.structure_text()}",
f"{optimized_result.optimized_branch.structure_text()}",
)

if optimized_result.sink_node:
return DataflowPlan(
plan_id=DagId.from_id_prefix(StaticIdPrefix.OPTIMIZED_DATAFLOW_PLAN_PREFIX),
sink_nodes=[optimized_result.sink_node],
)
logger.log(level=self._log_level, msg="Optimizer didn't produce a result, so returning the same plan")
return DataflowPlan(
sink_nodes=[dataflow_plan.checked_sink_node],
plan_id=DagId.from_id_prefix(StaticIdPrefix.OPTIMIZED_DATAFLOW_PLAN_PREFIX),
sink_nodes=[optimized_result.optimized_branch],
)

def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> OptimizeBranchResult: # noqa: D102
Expand Down

0 comments on commit 5c314d4

Please sign in to comment.