Skip to content

Commit

Permalink
feat(batch): add spill at least memory for hash agg (#17021)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenzl25 authored May 31, 2024
1 parent 9edfd72 commit 1c1f349
Showing 1 changed file with 50 additions and 20 deletions.
70 changes: 50 additions & 20 deletions src/batch/src/executor/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ pub struct HashAggExecutor<K> {
chunk_size: usize,
mem_context: MemoryContext,
enable_spill: bool,
/// The upper bound of memory usage for this executor.
memory_upper_bound: Option<u64>,
shutdown_rx: ShutdownToken,
_phantom: PhantomData<K>,
}
Expand All @@ -205,7 +207,7 @@ impl<K> HashAggExecutor<K> {
enable_spill: bool,
shutdown_rx: ShutdownToken,
) -> Self {
Self::new_with_init_agg_state(
Self::new_inner(
aggs,
group_key_columns,
group_key_types,
Expand All @@ -216,12 +218,13 @@ impl<K> HashAggExecutor<K> {
chunk_size,
mem_context,
enable_spill,
None,
shutdown_rx,
)
}

#[allow(clippy::too_many_arguments)]
fn new_with_init_agg_state(
fn new_inner(
aggs: Arc<Vec<BoxedAggregateFunction>>,
group_key_columns: Vec<usize>,
group_key_types: Vec<DataType>,
Expand All @@ -232,6 +235,7 @@ impl<K> HashAggExecutor<K> {
chunk_size: usize,
mem_context: MemoryContext,
enable_spill: bool,
memory_upper_bound: Option<u64>,
shutdown_rx: ShutdownToken,
) -> Self {
HashAggExecutor {
Expand All @@ -245,6 +249,7 @@ impl<K> HashAggExecutor<K> {
chunk_size,
mem_context,
enable_spill,
memory_upper_bound,
shutdown_rx,
_phantom: PhantomData,
}
Expand Down Expand Up @@ -461,6 +466,22 @@ impl AggSpillManager {
Ok(Self::read_stream(r))
}

async fn estimate_partition_size(&self, partition: usize) -> Result<u64> {
let agg_state_partition_file_name = format!("agg-state-p{}", partition);
let agg_state_size = self
.op
.stat(&agg_state_partition_file_name)
.await?
.content_length();
let input_partition_file_name = format!("input-chunks-p{}", partition);
let input_size = self
.op
.stat(&input_partition_file_name)
.await?
.content_length();
Ok(agg_state_size + input_size)
}

async fn clear_partition(&mut self, partition: usize) -> Result<()> {
let agg_state_partition_file_name = format!("agg-state-p{}", partition);
self.op.delete(&agg_state_partition_file_name).await?;
Expand All @@ -470,11 +491,18 @@ impl AggSpillManager {
}
}

const SPILL_AT_LEAST_MEMORY: u64 = 1024 * 1024;

impl<K: HashKey + Send + Sync> HashAggExecutor<K> {
#[try_stream(boxed, ok = DataChunk, error = BatchError)]
async fn do_execute(self: Box<Self>) {
let child_schema = self.child.schema().clone();
let mut need_to_spill = false;
// If the memory upper bound is less than 1MB, we don't need to check memory usage.
let check_memory = match self.memory_upper_bound {
Some(upper_bound) => upper_bound > SPILL_AT_LEAST_MEMORY,
None => true,
};

// hash map for each agg groups
let mut groups = AggHashMap::<K, _>::with_hasher_in(
Expand Down Expand Up @@ -508,7 +536,7 @@ impl<K: HashKey + Send + Sync> HashAggExecutor<K> {
groups.try_insert(key, agg_states).unwrap();
}

if !self.mem_context.add(memory_usage_diff) {
if !self.mem_context.add(memory_usage_diff) && check_memory {
warn!("not enough memory to load one partition agg state after spill which is not a normal case, so keep going");
}
}
Expand Down Expand Up @@ -553,7 +581,7 @@ impl<K: HashKey + Send + Sync> HashAggExecutor<K> {
}
}
// update memory usage
if !self.mem_context.add(memory_usage_diff) {
if !self.mem_context.add(memory_usage_diff) && check_memory {
if self.enable_spill {
need_to_spill = true;
break;
Expand Down Expand Up @@ -624,26 +652,28 @@ impl<K: HashKey + Send + Sync> HashAggExecutor<K> {

// Process each partition one by one.
for i in 0..agg_spill_manager.partition_num {
let partition_size = agg_spill_manager.estimate_partition_size(i).await?;

let agg_state_stream = agg_spill_manager.read_agg_state_partition(i).await?;
let input_stream = agg_spill_manager.read_input_partition(i).await?;

let sub_hash_agg_executor: HashAggExecutor<K> =
HashAggExecutor::new_with_init_agg_state(
self.aggs.clone(),
self.group_key_columns.clone(),
self.group_key_types.clone(),
let sub_hash_agg_executor: HashAggExecutor<K> = HashAggExecutor::new_inner(
self.aggs.clone(),
self.group_key_columns.clone(),
self.group_key_types.clone(),
self.schema.clone(),
Box::new(WrapStreamExecutor::new(child_schema.clone(), input_stream)),
Some(Box::new(WrapStreamExecutor::new(
self.schema.clone(),
Box::new(WrapStreamExecutor::new(child_schema.clone(), input_stream)),
Some(Box::new(WrapStreamExecutor::new(
self.schema.clone(),
agg_state_stream,
))),
format!("{}-sub{}", self.identity.clone(), i),
self.chunk_size,
self.mem_context.clone(),
self.enable_spill,
self.shutdown_rx.clone(),
);
agg_state_stream,
))),
format!("{}-sub{}", self.identity.clone(), i),
self.chunk_size,
self.mem_context.clone(),
self.enable_spill,
Some(partition_size),
self.shutdown_rx.clone(),
);

debug!(
"create sub_hash_agg {} for hash_agg {} to spill",
Expand Down

0 comments on commit 1c1f349

Please sign in to comment.