diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs index cb4adcecdc8c..00f7366655e0 100644 --- a/src/batch/src/executor/hash_agg.rs +++ b/src/batch/src/executor/hash_agg.rs @@ -188,6 +188,8 @@ pub struct HashAggExecutor { chunk_size: usize, mem_context: MemoryContext, enable_spill: bool, + /// The upper bound of memory usage for this executor. + memory_upper_bound: Option, shutdown_rx: ShutdownToken, _phantom: PhantomData, } @@ -205,7 +207,7 @@ impl HashAggExecutor { enable_spill: bool, shutdown_rx: ShutdownToken, ) -> Self { - Self::new_with_init_agg_state( + Self::new_inner( aggs, group_key_columns, group_key_types, @@ -216,12 +218,13 @@ impl HashAggExecutor { chunk_size, mem_context, enable_spill, + None, shutdown_rx, ) } #[allow(clippy::too_many_arguments)] - fn new_with_init_agg_state( + fn new_inner( aggs: Arc>, group_key_columns: Vec, group_key_types: Vec, @@ -232,6 +235,7 @@ impl HashAggExecutor { chunk_size: usize, mem_context: MemoryContext, enable_spill: bool, + memory_upper_bound: Option, shutdown_rx: ShutdownToken, ) -> Self { HashAggExecutor { @@ -245,6 +249,7 @@ impl HashAggExecutor { chunk_size, mem_context, enable_spill, + memory_upper_bound, shutdown_rx, _phantom: PhantomData, } @@ -461,6 +466,22 @@ impl AggSpillManager { Ok(Self::read_stream(r)) } + async fn estimate_partition_size(&self, partition: usize) -> Result { + let agg_state_partition_file_name = format!("agg-state-p{}", partition); + let agg_state_size = self + .op + .stat(&agg_state_partition_file_name) + .await? + .content_length(); + let input_partition_file_name = format!("input-chunks-p{}", partition); + let input_size = self + .op + .stat(&input_partition_file_name) + .await? + .content_length(); + Ok(agg_state_size + input_size) + } + async fn clear_partition(&mut self, partition: usize) -> Result<()> { let agg_state_partition_file_name = format!("agg-state-p{}", partition); self.op.delete(&agg_state_partition_file_name).await?; @@ -470,11 +491,18 @@ impl AggSpillManager { } } +const SPILL_AT_LEAST_MEMORY: u64 = 1024 * 1024; + impl HashAggExecutor { #[try_stream(boxed, ok = DataChunk, error = BatchError)] async fn do_execute(self: Box) { let child_schema = self.child.schema().clone(); let mut need_to_spill = false; + // If the memory upper bound is less than 1MB, we don't need to check memory usage. + let check_memory = match self.memory_upper_bound { + Some(upper_bound) => upper_bound > SPILL_AT_LEAST_MEMORY, + None => true, + }; // hash map for each agg groups let mut groups = AggHashMap::::with_hasher_in( @@ -508,7 +536,7 @@ impl HashAggExecutor { groups.try_insert(key, agg_states).unwrap(); } - if !self.mem_context.add(memory_usage_diff) { + if !self.mem_context.add(memory_usage_diff) && check_memory { warn!("not enough memory to load one partition agg state after spill which is not a normal case, so keep going"); } } @@ -553,7 +581,7 @@ impl HashAggExecutor { } } // update memory usage - if !self.mem_context.add(memory_usage_diff) { + if !self.mem_context.add(memory_usage_diff) && check_memory { if self.enable_spill { need_to_spill = true; break; @@ -624,26 +652,28 @@ impl HashAggExecutor { // Process each partition one by one. for i in 0..agg_spill_manager.partition_num { + let partition_size = agg_spill_manager.estimate_partition_size(i).await?; + let agg_state_stream = agg_spill_manager.read_agg_state_partition(i).await?; let input_stream = agg_spill_manager.read_input_partition(i).await?; - let sub_hash_agg_executor: HashAggExecutor = - HashAggExecutor::new_with_init_agg_state( - self.aggs.clone(), - self.group_key_columns.clone(), - self.group_key_types.clone(), + let sub_hash_agg_executor: HashAggExecutor = HashAggExecutor::new_inner( + self.aggs.clone(), + self.group_key_columns.clone(), + self.group_key_types.clone(), + self.schema.clone(), + Box::new(WrapStreamExecutor::new(child_schema.clone(), input_stream)), + Some(Box::new(WrapStreamExecutor::new( self.schema.clone(), - Box::new(WrapStreamExecutor::new(child_schema.clone(), input_stream)), - Some(Box::new(WrapStreamExecutor::new( - self.schema.clone(), - agg_state_stream, - ))), - format!("{}-sub{}", self.identity.clone(), i), - self.chunk_size, - self.mem_context.clone(), - self.enable_spill, - self.shutdown_rx.clone(), - ); + agg_state_stream, + ))), + format!("{}-sub{}", self.identity.clone(), i), + self.chunk_size, + self.mem_context.clone(), + self.enable_spill, + Some(partition_size), + self.shutdown_rx.clone(), + ); debug!( "create sub_hash_agg {} for hash_agg {} to spill",