Skip to content

Commit

Permalink
/* PR_START p--smr 06 */ Create DataflowPlanLookup.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed May 15, 2024
1 parent 4feb4d6 commit a784bdf
Showing 1 changed file with 16 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
logger = logging.getLogger(__name__)


class ReadSqlSourceNodeCounter(DataflowDagWalker[int]):
class _ReadSqlSourceNodeCounter(DataflowDagWalker[int]):
"""Counts the number of ReadSqlSourceNodes in the dataflow plan."""

@override
Expand All @@ -43,8 +43,16 @@ def default_visit_action(self, current_node: DataflowPlanNode, inputs: Sequence[
def visit_source_node(self, node: ReadSqlSourceNode) -> int: # noqa: D102
return 1

def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int: # noqa: D102
return dataflow_plan.checked_sink_node.accept(self)

class DataflowPlanLookup:
"""A lookup class to get assorted properties about the dataflow plan."""

def __init__(self, dataflow_plan: DataflowPlan) -> None: # noqa: D107
self._dataflow_plan_sink_node = dataflow_plan.checked_sink_node

def source_node_count(self) -> int:
"""Counts the number of `ReadSqlSourceNodes` in the dataflow plan."""
return self._dataflow_plan_sink_node.accept(_ReadSqlSourceNodeCounter())


def check_optimization( # noqa: D103
Expand All @@ -70,8 +78,8 @@ def check_optimization( # noqa: D103
dag_graph=dataflow_plan,
)

source_counter = ReadSqlSourceNodeCounter()
assert source_counter.count_source_nodes(dataflow_plan) == expected_num_sources_in_unoptimized
dataflow_plan_lookup = DataflowPlanLookup(dataflow_plan)
assert dataflow_plan_lookup.source_node_count() == expected_num_sources_in_unoptimized

optimizer = SourceScanOptimizer()
optimized_dataflow_plan = optimizer.optimize(dataflow_plan)
Expand All @@ -88,7 +96,9 @@ def check_optimization( # noqa: D103
mf_test_configuration=mf_test_configuration,
dag_graph=optimized_dataflow_plan,
)
assert source_counter.count_source_nodes(optimized_dataflow_plan) == expected_num_sources_in_optimized

optimized_dataflow_plan_lookup = DataflowPlanLookup(optimized_dataflow_plan)
assert optimized_dataflow_plan_lookup.source_node_count() == expected_num_sources_in_optimized


@pytest.mark.sql_engine_snapshot
Expand Down

0 comments on commit a784bdf

Please sign in to comment.