From c7ad769d6bac8596579687ba99bc2cc85e64c9fb Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 29 May 2024 19:29:07 +0800 Subject: [PATCH] feat(batch): support spill hash agg for the batch query (#16771) --- Cargo.lock | 5 + src/batch/Cargo.toml | 5 + src/batch/benches/hash_agg.rs | 5 +- src/batch/src/error.rs | 7 + src/batch/src/executor/hash_agg.rs | 491 ++++++++++++++++++++++++++--- src/batch/src/executor/utils.rs | 25 ++ src/batch/src/lib.rs | 1 + src/batch/src/spill/mod.rs | 15 + src/batch/src/spill/spill_op.rs | 98 ++++++ src/common/src/config.rs | 8 + src/common/src/hash/key.rs | 2 +- src/config/docs.md | 1 + src/config/example.toml | 1 + src/expr/core/src/aggregate/mod.rs | 3 +- 14 files changed, 619 insertions(+), 48 deletions(-) create mode 100644 src/batch/src/spill/mod.rs create mode 100644 src/batch/src/spill/spill_op.rs diff --git a/Cargo.lock b/Cargo.lock index d0466cb72c1e7..ad387bcc1142e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10087,6 +10087,7 @@ dependencies = [ "assert_matches", "async-recursion", "async-trait", + "bytes", "criterion", "either", "foyer", @@ -10100,9 +10101,11 @@ dependencies = [ "madsim-tokio", "madsim-tonic", "memcomparable", + "opendal", "parking_lot 0.12.1", "paste", "prometheus", + "prost 0.12.1", "rand", "risingwave_common", "risingwave_common_estimate_size", @@ -10125,6 +10128,8 @@ dependencies = [ "tokio-stream", "tokio-util", "tracing", + "twox-hash", + "uuid", "workspace-hack", ] diff --git a/src/batch/Cargo.toml b/src/batch/Cargo.toml index 019c33253466b..2ca8ed1be4e77 100644 --- a/src/batch/Cargo.toml +++ b/src/batch/Cargo.toml @@ -20,6 +20,7 @@ arrow-schema = { workspace = true } assert_matches = "1" async-recursion = "1" async-trait = "0.1" +bytes = "1" either = "1" foyer = { workspace = true } futures = { version = "0.3", default-features = false, features = ["alloc"] } @@ -30,9 +31,11 @@ hytra = "0.1.2" icelake = { workspace = true } itertools = { workspace = true } memcomparable = "0.2" +opendal = "0.45.1" parking_lot = { workspace = true } paste = "1" prometheus = { version = "0.13", features = ["process"] } +prost = "0.12" rand = { workspace = true } risingwave_common = { workspace = true } risingwave_common_estimate_size = { workspace = true } @@ -62,6 +65,8 @@ tokio-stream = "0.1" tokio-util = { workspace = true } tonic = { workspace = true } tracing = "0.1" +twox-hash = "1" +uuid = { version = "1", features = ["v4"] } [target.'cfg(not(madsim))'.dependencies] workspace-hack = { path = "../workspace-hack" } diff --git a/src/batch/benches/hash_agg.rs b/src/batch/benches/hash_agg.rs index f37261f7563e4..e91564692dc95 100644 --- a/src/batch/benches/hash_agg.rs +++ b/src/batch/benches/hash_agg.rs @@ -13,6 +13,8 @@ // limitations under the License. pub mod utils; +use std::sync::Arc; + use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; use itertools::Itertools; use risingwave_batch::executor::aggregation::build as build_agg; @@ -96,7 +98,7 @@ fn create_hash_agg_executor( let schema = Schema { fields }; Box::new(HashAggExecutor::::new( - agg_init_states, + Arc::new(agg_init_states), group_key_columns, group_key_types, schema, @@ -104,6 +106,7 @@ fn create_hash_agg_executor( "HashAggExecutor".to_string(), CHUNK_SIZE, MemoryContext::none(), + false, ShutdownToken::empty(), )) } diff --git a/src/batch/src/error.rs b/src/batch/src/error.rs index 8033ddfb3479b..27f355aed48b3 100644 --- a/src/batch/src/error.rs +++ b/src/batch/src/error.rs @@ -139,6 +139,13 @@ pub enum BatchError { #[error("Not enough memory to run this query, batch memory limit is {0} bytes")] OutOfMemory(u64), + + #[error("Failed to spill out to disk")] + Spill( + #[from] + #[backtrace] + opendal::Error, + ), } // Serialize/deserialize error. diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs index a0e06c958fc59..cb4adcecdc8c7 100644 --- a/src/batch/src/executor/hash_agg.rs +++ b/src/batch/src/executor/hash_agg.rs @@ -12,27 +12,40 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::hash::BuildHasher; use std::marker::PhantomData; +use std::sync::Arc; +use anyhow::anyhow; +use bytes::Bytes; use futures_async_stream::try_stream; +use futures_util::AsyncReadExt; use hashbrown::hash_map::Entry; use itertools::Itertools; +use prost::Message; use risingwave_common::array::{DataChunk, StreamChunk}; +use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::hash::{HashKey, HashKeyDispatcher, PrecomputedBuildHasher}; use risingwave_common::memory::MemoryContext; -use risingwave_common::types::DataType; +use risingwave_common::row::{OwnedRow, Row, RowExt}; +use risingwave_common::types::{DataType, ToOwnedDatum}; +use risingwave_common::util::chunk_coalesce::DataChunkBuilder; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common_estimate_size::EstimateSize; use risingwave_expr::aggregate::{AggCall, AggregateState, BoxedAggregateFunction}; use risingwave_pb::batch_plan::plan_node::NodeBody; use risingwave_pb::batch_plan::HashAggNode; +use risingwave_pb::data::DataChunk as PbDataChunk; +use twox_hash::XxHash64; use crate::error::{BatchError, Result}; use crate::executor::aggregation::build as build_agg; use crate::executor::{ BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder, + WrapStreamExecutor, }; +use crate::spill::spill_op::{SpillOp, DEFAULT_SPILL_PARTITION_NUM}; use crate::task::{BatchTaskContext, ShutdownToken, TaskId}; type AggHashMap = hashbrown::HashMap, PrecomputedBuildHasher, A>; @@ -43,7 +56,7 @@ impl HashKeyDispatcher for HashAggExecutorBuilder { fn dispatch_impl(self) -> Self::Output { Box::new(HashAggExecutor::::new( - self.aggs, + Arc::new(self.aggs), self.group_key_columns, self.group_key_types, self.schema, @@ -51,6 +64,7 @@ impl HashKeyDispatcher for HashAggExecutorBuilder { self.identity, self.chunk_size, self.mem_context, + self.enable_spill, self.shutdown_rx, )) } @@ -70,6 +84,7 @@ pub struct HashAggExecutorBuilder { identity: String, chunk_size: usize, mem_context: MemoryContext, + enable_spill: bool, shutdown_rx: ShutdownToken, } @@ -81,6 +96,7 @@ impl HashAggExecutorBuilder { identity: String, chunk_size: usize, mem_context: MemoryContext, + enable_spill: bool, shutdown_rx: ShutdownToken, ) -> Result { let aggs: Vec<_> = hash_agg_node @@ -119,6 +135,7 @@ impl HashAggExecutorBuilder { identity, chunk_size, mem_context, + enable_spill, shutdown_rx, }; @@ -148,6 +165,7 @@ impl BoxedExecutorBuilder for HashAggExecutorBuilder { identity.clone(), source.context.get_config().developer.chunk_size, source.context.create_executor_mem_context(identity), + source.context.get_config().enable_spill, source.shutdown_rx.clone(), ) } @@ -156,7 +174,7 @@ impl BoxedExecutorBuilder for HashAggExecutorBuilder { /// `HashAggExecutor` implements the hash aggregate algorithm. pub struct HashAggExecutor { /// Aggregate functions. - aggs: Vec, + aggs: Arc>, /// Column indexes that specify a group group_key_columns: Vec, /// Data types of group key columns @@ -164,16 +182,19 @@ pub struct HashAggExecutor { /// Output schema schema: Schema, child: BoxedExecutor, + /// Used to initialize the state of the aggregation from the spilled files. + init_agg_state_executor: Option, identity: String, chunk_size: usize, mem_context: MemoryContext, + enable_spill: bool, shutdown_rx: ShutdownToken, _phantom: PhantomData, } impl HashAggExecutor { pub fn new( - aggs: Vec, + aggs: Arc>, group_key_columns: Vec, group_key_types: Vec, schema: Schema, @@ -181,6 +202,36 @@ impl HashAggExecutor { identity: String, chunk_size: usize, mem_context: MemoryContext, + enable_spill: bool, + shutdown_rx: ShutdownToken, + ) -> Self { + Self::new_with_init_agg_state( + aggs, + group_key_columns, + group_key_types, + schema, + child, + None, + identity, + chunk_size, + mem_context, + enable_spill, + shutdown_rx, + ) + } + + #[allow(clippy::too_many_arguments)] + fn new_with_init_agg_state( + aggs: Arc>, + group_key_columns: Vec, + group_key_types: Vec, + schema: Schema, + child: BoxedExecutor, + init_agg_state_executor: Option, + identity: String, + chunk_size: usize, + mem_context: MemoryContext, + enable_spill: bool, shutdown_rx: ShutdownToken, ) -> Self { HashAggExecutor { @@ -189,9 +240,11 @@ impl HashAggExecutor { group_key_types, schema, child, + init_agg_state_executor, identity, chunk_size, mem_context, + enable_spill, shutdown_rx, _phantom: PhantomData, } @@ -212,18 +265,259 @@ impl Executor for HashAggExecutor { } } +#[derive(Default, Clone, Copy)] +pub struct SpillBuildHasher(u64); + +impl BuildHasher for SpillBuildHasher { + type Hasher = XxHash64; + + fn build_hasher(&self) -> Self::Hasher { + XxHash64::with_seed(self.0) + } +} + +/// `AggSpillManager` is used to manage how to write spill data file and read them back. +/// The spill data first need to be partitioned. Each partition contains 2 files: `agg_state_file` and `input_chunks_file`. +/// The spill file consume a data chunk and serialize the chunk into a protobuf bytes. +/// Finally, spill file content will look like the below. +/// The file write pattern is append-only and the read pattern is sequential scan. +/// This can maximize the disk IO performance. +/// +/// ```text +/// [proto_len] +/// [proto_bytes] +/// ... +/// [proto_len] +/// [proto_bytes] +/// ``` +pub struct AggSpillManager { + op: SpillOp, + partition_num: usize, + agg_state_writers: Vec, + agg_state_readers: Vec, + agg_state_chunk_builder: Vec, + input_writers: Vec, + input_readers: Vec, + input_chunk_builders: Vec, + spill_build_hasher: SpillBuildHasher, + group_key_types: Vec, + child_data_types: Vec, + agg_data_types: Vec, + spill_chunk_size: usize, +} + +impl AggSpillManager { + fn new( + agg_identity: &String, + partition_num: usize, + group_key_types: Vec, + agg_data_types: Vec, + child_data_types: Vec, + spill_chunk_size: usize, + ) -> Result { + let suffix_uuid = uuid::Uuid::new_v4(); + let dir = format!("/{}-{}/", agg_identity, suffix_uuid); + let op = SpillOp::create(dir)?; + let agg_state_writers = Vec::with_capacity(partition_num); + let agg_state_readers = Vec::with_capacity(partition_num); + let agg_state_chunk_builder = Vec::with_capacity(partition_num); + let input_writers = Vec::with_capacity(partition_num); + let input_readers = Vec::with_capacity(partition_num); + let input_chunk_builders = Vec::with_capacity(partition_num); + // Use uuid to generate an unique hasher so that when recursive spilling happens they would use a different hasher to avoid data skew. + let spill_build_hasher = SpillBuildHasher(suffix_uuid.as_u64_pair().1); + Ok(Self { + op, + partition_num, + agg_state_writers, + agg_state_readers, + agg_state_chunk_builder, + input_writers, + input_readers, + input_chunk_builders, + spill_build_hasher, + group_key_types, + child_data_types, + agg_data_types, + spill_chunk_size, + }) + } + + async fn init_writers(&mut self) -> Result<()> { + for i in 0..self.partition_num { + let agg_state_partition_file_name = format!("agg-state-p{}", i); + let w = self.op.writer_with(&agg_state_partition_file_name).await?; + self.agg_state_writers.push(w); + + let partition_file_name = format!("input-chunks-p{}", i); + let w = self.op.writer_with(&partition_file_name).await?; + self.input_writers.push(w); + self.input_chunk_builders.push(DataChunkBuilder::new( + self.child_data_types.clone(), + self.spill_chunk_size, + )); + self.agg_state_chunk_builder.push(DataChunkBuilder::new( + self.group_key_types + .iter() + .cloned() + .chain(self.agg_data_types.iter().cloned()) + .collect(), + self.spill_chunk_size, + )); + } + Ok(()) + } + + async fn write_agg_state_row(&mut self, row: impl Row, hash_code: u64) -> Result<()> { + let partition = hash_code as usize % self.partition_num; + if let Some(output_chunk) = self.agg_state_chunk_builder[partition].append_one_row(row) { + let chunk_pb: PbDataChunk = output_chunk.to_protobuf(); + let buf = Message::encode_to_vec(&chunk_pb); + let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes()); + self.agg_state_writers[partition].write(len_bytes).await?; + self.agg_state_writers[partition].write(buf).await?; + } + Ok(()) + } + + async fn write_input_chunk(&mut self, chunk: DataChunk, hash_codes: Vec) -> Result<()> { + let (columns, vis) = chunk.into_parts_v2(); + for partition in 0..self.partition_num { + let new_vis = vis.clone() + & Bitmap::from_iter( + hash_codes + .iter() + .map(|hash_code| (*hash_code as usize % self.partition_num) == partition), + ); + let new_chunk = DataChunk::from_parts(columns.clone(), new_vis); + for output_chunk in self.input_chunk_builders[partition].append_chunk(new_chunk) { + let chunk_pb: PbDataChunk = output_chunk.to_protobuf(); + let buf = Message::encode_to_vec(&chunk_pb); + let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes()); + self.input_writers[partition].write(len_bytes).await?; + self.input_writers[partition].write(buf).await?; + } + } + Ok(()) + } + + async fn close_writers(&mut self) -> Result<()> { + for partition in 0..self.partition_num { + if let Some(output_chunk) = self.agg_state_chunk_builder[partition].consume_all() { + let chunk_pb: PbDataChunk = output_chunk.to_protobuf(); + let buf = Message::encode_to_vec(&chunk_pb); + let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes()); + self.agg_state_writers[partition].write(len_bytes).await?; + self.agg_state_writers[partition].write(buf).await?; + } + + if let Some(output_chunk) = self.input_chunk_builders[partition].consume_all() { + let chunk_pb: PbDataChunk = output_chunk.to_protobuf(); + let buf = Message::encode_to_vec(&chunk_pb); + let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes()); + self.input_writers[partition].write(len_bytes).await?; + self.input_writers[partition].write(buf).await?; + } + } + + for mut w in self.agg_state_writers.drain(..) { + w.close().await?; + } + for mut w in self.input_writers.drain(..) { + w.close().await?; + } + Ok(()) + } + + #[try_stream(boxed, ok = DataChunk, error = BatchError)] + async fn read_stream(mut reader: opendal::Reader) { + let mut buf = [0u8; 4]; + loop { + if let Err(err) = reader.read_exact(&mut buf).await { + if err.kind() == std::io::ErrorKind::UnexpectedEof { + break; + } else { + return Err(anyhow!(err).into()); + } + } + let len = u32::from_le_bytes(buf) as usize; + let mut buf = vec![0u8; len]; + reader.read_exact(&mut buf).await.map_err(|e| anyhow!(e))?; + let chunk_pb: PbDataChunk = Message::decode(buf.as_slice()).map_err(|e| anyhow!(e))?; + let chunk = DataChunk::from_protobuf(&chunk_pb)?; + yield chunk; + } + } + + async fn read_agg_state_partition(&mut self, partition: usize) -> Result { + let agg_state_partition_file_name = format!("agg-state-p{}", partition); + let r = self.op.reader_with(&agg_state_partition_file_name).await?; + Ok(Self::read_stream(r)) + } + + async fn read_input_partition(&mut self, partition: usize) -> Result { + let input_partition_file_name = format!("input-chunks-p{}", partition); + let r = self.op.reader_with(&input_partition_file_name).await?; + Ok(Self::read_stream(r)) + } + + async fn clear_partition(&mut self, partition: usize) -> Result<()> { + let agg_state_partition_file_name = format!("agg-state-p{}", partition); + self.op.delete(&agg_state_partition_file_name).await?; + let input_partition_file_name = format!("input-chunks-p{}", partition); + self.op.delete(&input_partition_file_name).await?; + Ok(()) + } +} + impl HashAggExecutor { #[try_stream(boxed, ok = DataChunk, error = BatchError)] async fn do_execute(self: Box) { + let child_schema = self.child.schema().clone(); + let mut need_to_spill = false; + // hash map for each agg groups let mut groups = AggHashMap::::with_hasher_in( PrecomputedBuildHasher, self.mem_context.global_allocator(), ); + if let Some(init_agg_state_executor) = self.init_agg_state_executor { + // `init_agg_state_executor` exists which means this is a sub `HashAggExecutor` used to consume spilling data. + // The spilled agg states by its parent executor need to be recovered first. + let mut init_agg_state_stream = init_agg_state_executor.execute(); + #[for_await] + for chunk in &mut init_agg_state_stream { + let chunk = chunk?; + let group_key_indices = (0..self.group_key_columns.len()).collect_vec(); + let keys = K::build_many(&group_key_indices, &chunk); + let mut memory_usage_diff = 0; + for (row_id, key) in keys.into_iter().enumerate() { + let mut agg_states = vec![]; + for i in 0..self.aggs.len() { + let agg = &self.aggs[i]; + let datum = chunk + .row_at(row_id) + .0 + .datum_at(self.group_key_columns.len() + i) + .to_owned_datum(); + let agg_state = agg.decode_state(datum)?; + memory_usage_diff += agg_state.estimated_size() as i64; + agg_states.push(agg_state); + } + groups.try_insert(key, agg_states).unwrap(); + } + + if !self.mem_context.add(memory_usage_diff) { + warn!("not enough memory to load one partition agg state after spill which is not a normal case, so keep going"); + } + } + } + + let mut input_stream = self.child.execute(); // consume all chunks to compute the agg result #[for_await] - for chunk in self.child.execute() { + for chunk in &mut input_stream { let chunk = StreamChunk::from(chunk?); let keys = K::build_many(self.group_key_columns.as_slice(), &chunk); let mut memory_usage_diff = 0; @@ -260,53 +554,158 @@ impl HashAggExecutor { } // update memory usage if !self.mem_context.add(memory_usage_diff) { - Err(BatchError::OutOfMemory(self.mem_context.mem_limit()))?; + if self.enable_spill { + need_to_spill = true; + break; + } else { + Err(BatchError::OutOfMemory(self.mem_context.mem_limit()))?; + } } } - // Don't use `into_iter` here, it may cause memory leak. - let mut result = groups.iter_mut(); - let cardinality = self.chunk_size; - loop { - let mut group_builders: Vec<_> = self - .group_key_types - .iter() - .map(|datatype| datatype.create_array_builder(cardinality)) - .collect(); - - let mut agg_builders: Vec<_> = self - .aggs - .iter() - .map(|agg| agg.return_type().create_array_builder(cardinality)) - .collect(); - - let mut has_next = false; - let mut array_len = 0; - for (key, states) in result.by_ref().take(cardinality) { - self.shutdown_rx.check()?; - has_next = true; - array_len += 1; - key.deserialize_to_builders(&mut group_builders[..], &self.group_key_types)?; - for ((agg, state), builder) in (self.aggs.iter()) - .zip_eq_fast(states) - .zip_eq_fast(&mut agg_builders) - { - let result = agg.get_result(state).await?; - builder.append(result); + if need_to_spill { + // A spilling version of aggregation based on the RFC: Spill Hash Aggregation https://github.com/risingwavelabs/rfcs/pull/89 + // When HashAggExecutor told memory is insufficient, AggSpillManager will start to partition the hash table and spill to disk. + // After spilling the hash table, AggSpillManager will consume all chunks from the input executor, + // partition and spill to disk with the same hash function as the hash table spilling. + // Finally, we would get e.g. 20 partitions. Each partition should contain a portion of the original hash table and input data. + // A sub HashAggExecutor would be used to consume each partition one by one. + // If memory is still not enough in the sub HashAggExecutor, it will partition its hash table and input recursively. + let mut agg_spill_manager = AggSpillManager::new( + &self.identity, + DEFAULT_SPILL_PARTITION_NUM, + self.group_key_types.clone(), + self.aggs.iter().map(|agg| agg.return_type()).collect(), + child_schema.data_types(), + self.chunk_size, + )?; + agg_spill_manager.init_writers().await?; + + let mut memory_usage_diff = 0; + // Spill agg states. + for (key, states) in groups { + let key_row = key.deserialize(&self.group_key_types)?; + let mut agg_datums = vec![]; + for (agg, state) in self.aggs.iter().zip_eq_fast(states) { + let encode_state = agg.encode_state(&state)?; + memory_usage_diff -= state.estimated_size() as i64; + agg_datums.push(encode_state); } + let agg_state_row = OwnedRow::from_iter(agg_datums.into_iter()); + let hash_code = agg_spill_manager.spill_build_hasher.hash_one(key); + agg_spill_manager + .write_agg_state_row(key_row.chain(agg_state_row), hash_code) + .await?; } - if !has_next { - break; // exit loop + + // Release memory occupied by agg hash map. + self.mem_context.add(memory_usage_diff); + + // Spill input chunks. + #[for_await] + for chunk in input_stream { + let chunk: DataChunk = chunk?; + let hash_codes = chunk.get_hash_values( + self.group_key_columns.as_slice(), + agg_spill_manager.spill_build_hasher, + ); + agg_spill_manager + .write_input_chunk( + chunk, + hash_codes + .into_iter() + .map(|hash_code| hash_code.value()) + .collect(), + ) + .await?; } - let columns = group_builders - .into_iter() - .chain(agg_builders) - .map(|b| b.finish().into()) - .collect::>(); + agg_spill_manager.close_writers().await?; + + // Process each partition one by one. + for i in 0..agg_spill_manager.partition_num { + let agg_state_stream = agg_spill_manager.read_agg_state_partition(i).await?; + let input_stream = agg_spill_manager.read_input_partition(i).await?; + + let sub_hash_agg_executor: HashAggExecutor = + HashAggExecutor::new_with_init_agg_state( + self.aggs.clone(), + self.group_key_columns.clone(), + self.group_key_types.clone(), + self.schema.clone(), + Box::new(WrapStreamExecutor::new(child_schema.clone(), input_stream)), + Some(Box::new(WrapStreamExecutor::new( + self.schema.clone(), + agg_state_stream, + ))), + format!("{}-sub{}", self.identity.clone(), i), + self.chunk_size, + self.mem_context.clone(), + self.enable_spill, + self.shutdown_rx.clone(), + ); + + debug!( + "create sub_hash_agg {} for hash_agg {} to spill", + sub_hash_agg_executor.identity, self.identity + ); + + let sub_hash_agg_stream = Box::new(sub_hash_agg_executor).execute(); + + #[for_await] + for chunk in sub_hash_agg_stream { + let chunk = chunk?; + yield chunk; + } - let output = DataChunk::new(columns, array_len); - yield output; + // Clear files of the current partition. + agg_spill_manager.clear_partition(i).await?; + } + } else { + // Don't use `into_iter` here, it may cause memory leak. + let mut result = groups.iter_mut(); + let cardinality = self.chunk_size; + loop { + let mut group_builders: Vec<_> = self + .group_key_types + .iter() + .map(|datatype| datatype.create_array_builder(cardinality)) + .collect(); + + let mut agg_builders: Vec<_> = self + .aggs + .iter() + .map(|agg| agg.return_type().create_array_builder(cardinality)) + .collect(); + + let mut has_next = false; + let mut array_len = 0; + for (key, states) in result.by_ref().take(cardinality) { + self.shutdown_rx.check()?; + has_next = true; + array_len += 1; + key.deserialize_to_builders(&mut group_builders[..], &self.group_key_types)?; + for ((agg, state), builder) in (self.aggs.iter()) + .zip_eq_fast(states) + .zip_eq_fast(&mut agg_builders) + { + let result = agg.get_result(state).await?; + builder.append(result); + } + } + if !has_next { + break; // exit loop + } + + let columns = group_builders + .into_iter() + .chain(agg_builders) + .map(|b| b.finish().into()) + .collect::>(); + + let output = DataChunk::new(columns, array_len); + yield output; + } } } } @@ -316,7 +715,6 @@ mod tests { use std::alloc::{AllocError, Allocator, Global, Layout}; use std::ptr::NonNull; use std::sync::atomic::{AtomicBool, Ordering}; - use std::sync::Arc; use futures_async_stream::for_await; use risingwave_common::metrics::LabelGuardedIntGauge; @@ -390,6 +788,7 @@ mod tests { "HashAggExecutor".to_string(), CHUNK_SIZE, mem_context.clone(), + false, ShutdownToken::empty(), ) .unwrap(); @@ -462,6 +861,7 @@ mod tests { "HashAggExecutor".to_string(), CHUNK_SIZE, MemoryContext::none(), + false, ShutdownToken::empty(), ) .unwrap(); @@ -577,6 +977,7 @@ mod tests { "HashAggExecutor".to_string(), CHUNK_SIZE, MemoryContext::none(), + false, shutdown_rx, ) .unwrap(); diff --git a/src/batch/src/executor/utils.rs b/src/batch/src/executor/utils.rs index 9c6f162f02268..4f724ec5416c8 100644 --- a/src/batch/src/executor/utils.rs +++ b/src/batch/src/executor/utils.rs @@ -99,3 +99,28 @@ impl DummyExecutor { #[try_stream(boxed, ok = DataChunk, error = BatchError)] async fn do_nothing() {} } + +pub struct WrapStreamExecutor { + schema: Schema, + stream: BoxedDataChunkStream, +} + +impl WrapStreamExecutor { + pub fn new(schema: Schema, stream: BoxedDataChunkStream) -> Self { + Self { schema, stream } + } +} + +impl Executor for WrapStreamExecutor { + fn schema(&self) -> &Schema { + &self.schema + } + + fn identity(&self) -> &str { + "WrapStreamExecutor" + } + + fn execute(self: Box) -> BoxedDataChunkStream { + self.stream + } +} diff --git a/src/batch/src/lib.rs b/src/batch/src/lib.rs index b8e6df1ac9538..2c072319e1c8a 100644 --- a/src/batch/src/lib.rs +++ b/src/batch/src/lib.rs @@ -39,6 +39,7 @@ pub mod execution; pub mod executor; pub mod monitor; pub mod rpc; +mod spill; pub mod task; pub mod worker_manager; diff --git a/src/batch/src/spill/mod.rs b/src/batch/src/spill/mod.rs new file mode 100644 index 0000000000000..6af1eae7429c6 --- /dev/null +++ b/src/batch/src/spill/mod.rs @@ -0,0 +1,15 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod spill_op; diff --git a/src/batch/src/spill/spill_op.rs b/src/batch/src/spill/spill_op.rs new file mode 100644 index 0000000000000..115a0c2d430e1 --- /dev/null +++ b/src/batch/src/spill/spill_op.rs @@ -0,0 +1,98 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::{Deref, DerefMut}; + +use opendal::layers::RetryLayer; +use opendal::services::Fs; +use opendal::Operator; +use thiserror_ext::AsReport; + +use crate::error::Result; + +const RW_BATCH_SPILL_DIR_ENV: &str = "RW_BATCH_SPILL_DIR"; +pub const DEFAULT_SPILL_PARTITION_NUM: usize = 20; +const DEFAULT_SPILL_DIR: &str = "/tmp/"; +const RW_MANAGED_SPILL_DIR: &str = "/rw_batch_spill/"; +const DEFAULT_IO_BUFFER_SIZE: usize = 256 * 1024; +const DEFAULT_IO_CONCURRENT_TASK: usize = 8; + +/// `SpillOp` is used to manage the spill directory of the spilling executor and it will drop the directory with a RAII style. +pub struct SpillOp { + pub op: Operator, +} + +impl SpillOp { + pub fn create(path: String) -> Result { + assert!(path.ends_with('/')); + + let spill_dir = + std::env::var(RW_BATCH_SPILL_DIR_ENV).unwrap_or_else(|_| DEFAULT_SPILL_DIR.to_string()); + let root = format!("/{}/{}/{}/", spill_dir, RW_MANAGED_SPILL_DIR, path); + + let mut builder = Fs::default(); + builder.root(&root); + + let op: Operator = Operator::new(builder)? + .layer(RetryLayer::default()) + .finish(); + Ok(SpillOp { op }) + } + + pub async fn writer_with(&self, name: &str) -> Result { + Ok(self + .op + .writer_with(name) + .buffer(DEFAULT_IO_BUFFER_SIZE) + .concurrent(DEFAULT_IO_CONCURRENT_TASK) + .await?) + } + + pub async fn reader_with(&self, name: &str) -> Result { + Ok(self + .op + .reader_with(name) + .buffer(DEFAULT_IO_BUFFER_SIZE) + .await?) + } +} + +impl Drop for SpillOp { + fn drop(&mut self) { + let op = self.op.clone(); + tokio::task::spawn(async move { + let result = op.remove_all("/").await; + if let Err(error) = result { + error!( + error = %error.as_report(), + "Failed to remove spill directory" + ); + } + }); + } +} + +impl DerefMut for SpillOp { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.op + } +} + +impl Deref for SpillOp { + type Target = Operator; + + fn deref(&self) -> &Self::Target { + &self.op + } +} diff --git a/src/common/src/config.rs b/src/common/src/config.rs index 8ae14702d3261..c8da0f6dce5e9 100644 --- a/src/common/src/config.rs +++ b/src/common/src/config.rs @@ -535,6 +535,10 @@ pub struct BatchConfig { /// A SQL option with a name containing any of these keywords will be redacted. #[serde(default = "default::batch::redact_sql_option_keywords")] pub redact_sql_option_keywords: Vec, + + /// Enable the spill out to disk feature for batch queries. + #[serde(default = "default::batch::enable_spill")] + pub enable_spill: bool, } /// The section `[streaming]` in `risingwave.toml`. @@ -1759,6 +1763,10 @@ pub mod default { false } + pub fn enable_spill() -> bool { + true + } + pub fn statement_timeout_in_sec() -> u32 { // 1 hour 60 * 60 diff --git a/src/common/src/hash/key.rs b/src/common/src/hash/key.rs index 4911e041370a9..c7e57173a3e74 100644 --- a/src/common/src/hash/key.rs +++ b/src/common/src/hash/key.rs @@ -236,7 +236,7 @@ impl From for HashCode { } impl HashCode { - pub fn value(self) -> u64 { + pub fn value(&self) -> u64 { self.value } } diff --git a/src/config/docs.md b/src/config/docs.md index 2f8c4ce2812b1..018c9dd41087c 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -8,6 +8,7 @@ This page is automatically generated by `./risedev generate-example-config` |--------|-------------|---------| | distributed_query_limit | This is the max number of queries per sql session. | | | enable_barrier_read | | false | +| enable_spill | Enable the spill out to disk feature for batch queries. | true | | frontend_compute_runtime_worker_threads | frontend compute runtime worker threads | 4 | | mask_worker_temporary_secs | This is the secs used to mask a worker unavailable temporarily. | 30 | | max_batch_queries_per_frontend_node | This is the max number of batch queries per frontend node. | | diff --git a/src/config/example.toml b/src/config/example.toml index fb2243535d6a4..00b1ef759e5f9 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -87,6 +87,7 @@ statement_timeout_in_sec = 3600 frontend_compute_runtime_worker_threads = 4 mask_worker_temporary_secs = 30 redact_sql_option_keywords = ["credential", "key", "password", "private", "secret", "token"] +enable_spill = true [batch.developer] batch_connector_message_buffer_size = 16 diff --git a/src/expr/core/src/aggregate/mod.rs b/src/expr/core/src/aggregate/mod.rs index 8ccd71b0e3e37..2a1119d6fe301 100644 --- a/src/expr/core/src/aggregate/mod.rs +++ b/src/expr/core/src/aggregate/mod.rs @@ -15,6 +15,7 @@ use std::fmt::Debug; use std::ops::Range; +use anyhow::anyhow; use downcast_rs::{impl_downcast, Downcast}; use itertools::Itertools; use risingwave_common::array::StreamChunk; @@ -60,7 +61,7 @@ pub trait AggregateFunction: Send + Sync + 'static { fn encode_state(&self, state: &AggregateState) -> Result { match state { AggregateState::Datum(d) => Ok(d.clone()), - _ => panic!("cannot encode state"), + AggregateState::Any(_) => Err(ExprError::Internal(anyhow!("cannot encode state"))), } }