From 7306f319f36b93d3c47715bf96063812b3ed1b7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 27 Nov 2023 13:35:35 +0100 Subject: [PATCH] Dynamically optimize aggregate based on shuffle stats --- .../src/state/execution_graph/execution_stage.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ballista/scheduler/src/state/execution_graph/execution_stage.rs b/ballista/scheduler/src/state/execution_graph/execution_stage.rs index fcac54d5d..f082fe435 100644 --- a/ballista/scheduler/src/state/execution_graph/execution_stage.rs +++ b/ballista/scheduler/src/state/execution_graph/execution_stage.rs @@ -22,6 +22,7 @@ use std::iter::FromIterator; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; +use datafusion::physical_optimizer::aggregate_statistics::AggregateStatistics; use datafusion::physical_optimizer::join_selection::JoinSelection; use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::display::DisplayableExecutionPlan; @@ -348,9 +349,13 @@ impl UnresolvedStage { &input_locations, )?; - // Optimize join order based on new resolved statistics + // Optimize join order and statistics based on new resolved statistics let optimize_join = JoinSelection::new(); - let plan = optimize_join.optimize(plan, SessionConfig::default().options())?; + let config = SessionConfig::default(); + let plan = optimize_join.optimize(plan, config.options())?; + let optimize_aggregate = AggregateStatistics::new(); + let plan = + optimize_aggregate.optimize(plan, SessionConfig::default().options())?; Ok(ResolvedStage::new( self.stage_id,