diff --git a/src/batch/src/exchange_source.rs b/src/batch/src/exchange_source.rs index b602b14d5c018..409061594338d 100644 --- a/src/batch/src/exchange_source.rs +++ b/src/batch/src/exchange_source.rs @@ -15,9 +15,10 @@ use std::fmt::Debug; use std::future::Future; +use futures_async_stream::try_stream; use risingwave_common::array::DataChunk; -use crate::error::Result; +use crate::error::{BatchError, Result}; use crate::execution::grpc_exchange::GrpcExchangeSource; use crate::execution::local_exchange::LocalExchangeSource; use crate::executor::test_utils::FakeExchangeSource; @@ -54,4 +55,16 @@ impl ExchangeSourceImpl { ExchangeSourceImpl::Fake(fake) => fake.get_task_id(), } } + + #[try_stream(boxed, ok = DataChunk, error = BatchError)] + pub(crate) async fn take_data_stream(self) { + let mut source = self; + loop { + match source.take_data().await { + Ok(Some(chunk)) => yield chunk, + Ok(None) => break, + Err(e) => return Err(e), + } + } + } } diff --git a/src/batch/src/executor/merge_sort.rs b/src/batch/src/executor/merge_sort.rs new file mode 100644 index 0000000000000..1f5c8f3e5fc2c --- /dev/null +++ b/src/batch/src/executor/merge_sort.rs @@ -0,0 +1,195 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::mem; +use std::sync::Arc; + +use futures_async_stream::try_stream; +use futures_util::StreamExt; +use itertools::Itertools; +use risingwave_common::array::DataChunk; +use risingwave_common::catalog::Schema; +use risingwave_common::memory::{MemMonitoredHeap, MemoryContext, MonitoredGlobalAlloc}; +use risingwave_common::types::ToOwnedDatum; +use risingwave_common::util::sort_util::{ColumnOrder, HeapElem}; +use risingwave_common_estimate_size::EstimateSize; + +use super::{BoxedDataChunkStream, BoxedExecutor, Executor}; +use crate::error::{BatchError, Result}; + +pub struct MergeSortExecutor { + inputs: Vec, + column_orders: Arc>, + identity: String, + schema: Schema, + chunk_size: usize, + mem_context: MemoryContext, + min_heap: MemMonitoredHeap, + current_chunks: Vec, MonitoredGlobalAlloc>, +} + +impl Executor for MergeSortExecutor { + fn schema(&self) -> &Schema { + &self.schema + } + + fn identity(&self) -> &str { + &self.identity + } + + fn execute(self: Box) -> BoxedDataChunkStream { + self.do_execute() + } +} + +impl MergeSortExecutor { + #[try_stream(boxed, ok = DataChunk, error = BatchError)] + async fn do_execute(mut self: Box) { + let mut inputs = vec![]; + mem::swap(&mut inputs, &mut self.inputs); + let mut input_streams = inputs + .into_iter() + .map(|input| input.execute()) + .collect_vec(); + for (input_idx, input_stream) in input_streams.iter_mut().enumerate() { + match input_stream.next().await { + Some(chunk) => { + let chunk = chunk?; + self.current_chunks.push(Some(chunk)); + if let Some(chunk) = &self.current_chunks[input_idx] { + // We assume that we would always get a non-empty chunk from the upstream of + // exchange, therefore we are sure that there is at least + // one visible row. + let next_row_idx = chunk.next_visible_row_idx(0); + self.push_row_into_heap(input_idx, next_row_idx.unwrap()); + } + } + None => { + self.current_chunks.push(None); + } + } + } + + while !self.min_heap.is_empty() { + // It is possible that we cannot produce this much as + // we may run out of input data chunks from sources. + let mut want_to_produce = self.chunk_size; + + let mut builders: Vec<_> = self + .schema + .fields + .iter() + .map(|field| field.data_type.create_array_builder(self.chunk_size)) + .collect(); + let mut array_len = 0; + while want_to_produce > 0 && !self.min_heap.is_empty() { + let top_elem = self.min_heap.pop().unwrap(); + let child_idx = top_elem.chunk_idx(); + let cur_chunk = top_elem.chunk(); + let row_idx = top_elem.elem_idx(); + for (idx, builder) in builders.iter_mut().enumerate() { + let chunk_arr = cur_chunk.column_at(idx); + let chunk_arr = chunk_arr.as_ref(); + let datum = chunk_arr.value_at(row_idx).to_owned_datum(); + builder.append(&datum); + } + want_to_produce -= 1; + array_len += 1; + // check whether we have another row from the same chunk being popped + let possible_next_row_idx = cur_chunk.next_visible_row_idx(row_idx + 1); + match possible_next_row_idx { + Some(next_row_idx) => { + self.push_row_into_heap(child_idx, next_row_idx); + } + None => { + self.get_input_chunk(&mut input_streams, child_idx).await?; + if let Some(chunk) = &self.current_chunks[child_idx] { + let next_row_idx = chunk.next_visible_row_idx(0); + self.push_row_into_heap(child_idx, next_row_idx.unwrap()); + } + } + } + } + + let columns = builders + .into_iter() + .map(|builder| builder.finish().into()) + .collect::>(); + let chunk = DataChunk::new(columns, array_len); + yield chunk + } + } + + async fn get_input_chunk( + &mut self, + input_streams: &mut Vec, + input_idx: usize, + ) -> Result<()> { + assert!(input_idx < input_streams.len()); + let res = input_streams[input_idx].next().await; + let old = match res { + Some(chunk) => { + let chunk = chunk?; + assert_ne!(chunk.cardinality(), 0); + let new_chunk_size = chunk.estimated_heap_size() as i64; + let old = std::mem::replace(&mut self.current_chunks[input_idx], Some(chunk)); + self.mem_context.add(new_chunk_size); + old + } + None => std::mem::take(&mut self.current_chunks[input_idx]), + }; + + if let Some(chunk) = old { + // Reduce the heap size of retired chunk + self.mem_context.add(-(chunk.estimated_heap_size() as i64)); + } + + Ok(()) + } + + fn push_row_into_heap(&mut self, input_idx: usize, row_idx: usize) { + assert!(input_idx < self.current_chunks.len()); + let chunk_ref = self.current_chunks[input_idx].as_ref().unwrap(); + self.min_heap.push(HeapElem::new( + self.column_orders.clone(), + chunk_ref.clone(), + input_idx, + row_idx, + None, + )); + } +} + +impl MergeSortExecutor { + pub fn new( + inputs: Vec, + column_orders: Arc>, + schema: Schema, + identity: String, + chunk_size: usize, + mem_context: MemoryContext, + ) -> Self { + let inputs_num = inputs.len(); + Self { + inputs, + column_orders, + identity, + schema, + chunk_size, + min_heap: MemMonitoredHeap::with_capacity(inputs_num, mem_context.clone()), + current_chunks: Vec::with_capacity_in(inputs_num, mem_context.global_allocator()), + mem_context, + } + } +} diff --git a/src/batch/src/executor/merge_sort_exchange.rs b/src/batch/src/executor/merge_sort_exchange.rs index e2779967dbcbe..3b5647729db25 100644 --- a/src/batch/src/executor/merge_sort_exchange.rs +++ b/src/batch/src/executor/merge_sort_exchange.rs @@ -17,18 +17,15 @@ use std::sync::Arc; use futures_async_stream::try_stream; use risingwave_common::array::DataChunk; use risingwave_common::catalog::{Field, Schema}; -use risingwave_common::memory::{MemMonitoredHeap, MemoryContext, MonitoredGlobalAlloc}; -use risingwave_common::types::ToOwnedDatum; -use risingwave_common::util::sort_util::{ColumnOrder, HeapElem}; -use risingwave_common_estimate_size::EstimateSize; +use risingwave_common::memory::MemoryContext; +use risingwave_common::util::sort_util::ColumnOrder; use risingwave_pb::batch_plan::plan_node::NodeBody; use risingwave_pb::batch_plan::PbExchangeSource; use crate::error::{BatchError, Result}; -use crate::exchange_source::ExchangeSourceImpl; use crate::executor::{ BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, CreateSource, DefaultCreateSource, - Executor, ExecutorBuilder, + Executor, ExecutorBuilder, MergeSortExecutor, WrapStreamExecutor, }; use crate::task::{BatchTaskContext, TaskId}; @@ -38,23 +35,16 @@ pub type MergeSortExchangeExecutor = MergeSortExchangeExecutorImpl { context: C, - /// keeps one data chunk of each source if any - source_inputs: Vec, MonitoredGlobalAlloc>, column_orders: Arc>, - min_heap: MemMonitoredHeap, proto_sources: Vec, - sources: Vec, // impl /// Mock-able `CreateSource`. source_creators: Vec, schema: Schema, - #[expect(dead_code)] task_id: TaskId, identity: String, /// The maximum size of the chunk produced by executor at a time. chunk_size: usize, mem_ctx: MemoryContext, - #[expect(dead_code)] - alloc: MonitoredGlobalAlloc, } impl MergeSortExchangeExecutorImpl { @@ -70,69 +60,18 @@ impl MergeSortExchangeEx chunk_size: usize, ) -> Self { let mem_ctx = context.create_executor_mem_context(&identity); - let alloc = MonitoredGlobalAlloc::with_memory_context(mem_ctx.clone()); - - let source_inputs = { - let mut v = Vec::with_capacity_in(proto_sources.len(), alloc.clone()); - (0..proto_sources.len()).for_each(|_| v.push(None)); - v - }; - - let num_sources = proto_sources.len(); Self { context, - source_inputs, column_orders, - min_heap: MemMonitoredHeap::with_capacity(num_sources, mem_ctx.clone()), proto_sources, - sources: Vec::with_capacity(num_sources), source_creators, schema, task_id, identity, chunk_size, mem_ctx, - alloc, - } - } - - /// We assume that the source would always send `Some(chunk)` with cardinality > 0 - /// or `None`, but never `Some(chunk)` with cardinality == 0. - async fn get_source_chunk(&mut self, source_idx: usize) -> Result<()> { - assert!(source_idx < self.source_inputs.len()); - let res = self.sources[source_idx].take_data().await?; - let old = match res { - Some(chunk) => { - assert_ne!(chunk.cardinality(), 0); - let new_chunk_size = chunk.estimated_heap_size() as i64; - let old = std::mem::replace(&mut self.source_inputs[source_idx], Some(chunk)); - self.mem_ctx.add(new_chunk_size); - old - } - None => std::mem::take(&mut self.source_inputs[source_idx]), - }; - - if let Some(chunk) = old { - // Reduce the heap size of retired chunk - self.mem_ctx.add(-(chunk.estimated_heap_size() as i64)); } - - Ok(()) - } - - // Check whether there is indeed a chunk and there is a visible row sitting at `row_idx` - // in the chunk before calling this function. - fn push_row_into_heap(&mut self, source_idx: usize, row_idx: usize) { - assert!(source_idx < self.source_inputs.len()); - let chunk_ref = self.source_inputs[source_idx].as_ref().unwrap(); - self.min_heap.push(HeapElem::new( - self.column_orders.clone(), - chunk_ref.clone(), - source_idx, - row_idx, - None, - )); } } @@ -156,71 +95,31 @@ impl Executor /// `self.chunk_size` as the executor runs out of input from `sources`. impl MergeSortExchangeExecutorImpl { #[try_stream(boxed, ok = DataChunk, error = BatchError)] - async fn do_execute(mut self: Box) { + async fn do_execute(self: Box) { + let mut sources: Vec = vec![]; for source_idx in 0..self.proto_sources.len() { let new_source = self.source_creators[source_idx] .create_source(self.context.clone(), &self.proto_sources[source_idx]) .await?; - self.sources.push(new_source); - self.get_source_chunk(source_idx).await?; - if let Some(chunk) = &self.source_inputs[source_idx] { - // We assume that we would always get a non-empty chunk from the upstream of - // exchange, therefore we are sure that there is at least - // one visible row. - let next_row_idx = chunk.next_visible_row_idx(0); - self.push_row_into_heap(source_idx, next_row_idx.unwrap()); - } - } - // If there is no rows in the heap, - // we run out of input data chunks and emit `Done`. - while !self.min_heap.is_empty() { - // It is possible that we cannot produce this much as - // we may run out of input data chunks from sources. - let mut want_to_produce = self.chunk_size; + sources.push(Box::new(WrapStreamExecutor::new( + self.schema.clone(), + new_source.take_data_stream(), + ))); + } - let mut builders: Vec<_> = self - .schema() - .fields - .iter() - .map(|field| field.data_type.create_array_builder(self.chunk_size)) - .collect(); - let mut array_len = 0; - while want_to_produce > 0 && !self.min_heap.is_empty() { - let top_elem = self.min_heap.pop().unwrap(); - let child_idx = top_elem.chunk_idx(); - let cur_chunk = top_elem.chunk(); - let row_idx = top_elem.elem_idx(); - for (idx, builder) in builders.iter_mut().enumerate() { - let chunk_arr = cur_chunk.column_at(idx); - let chunk_arr = chunk_arr.as_ref(); - let datum = chunk_arr.value_at(row_idx).to_owned_datum(); - builder.append(&datum); - } - want_to_produce -= 1; - array_len += 1; - // check whether we have another row from the same chunk being popped - let possible_next_row_idx = cur_chunk.next_visible_row_idx(row_idx + 1); - match possible_next_row_idx { - Some(next_row_idx) => { - self.push_row_into_heap(child_idx, next_row_idx); - } - None => { - self.get_source_chunk(child_idx).await?; - if let Some(chunk) = &self.source_inputs[child_idx] { - let next_row_idx = chunk.next_visible_row_idx(0); - self.push_row_into_heap(child_idx, next_row_idx.unwrap()); - } - } - } - } + let merge_sort_executor = Box::new(MergeSortExecutor::new( + sources, + self.column_orders.clone(), + self.schema, + format!("MergeSortExecutor{}", &self.task_id.task_id), + self.chunk_size, + self.mem_ctx, + )); - let columns = builders - .into_iter() - .map(|builder| builder.finish().into()) - .collect::>(); - let chunk = DataChunk::new(columns, array_len); - yield chunk + #[for_await] + for chunk in merge_sort_executor.execute() { + yield chunk?; } } } diff --git a/src/batch/src/executor/mod.rs b/src/batch/src/executor/mod.rs index b77027327fe05..c19bc06c141b9 100644 --- a/src/batch/src/executor/mod.rs +++ b/src/batch/src/executor/mod.rs @@ -27,6 +27,7 @@ mod limit; mod log_row_seq_scan; mod managed; mod max_one_row; +mod merge_sort; mod merge_sort_exchange; mod order_by; mod project; @@ -60,6 +61,7 @@ pub use join::*; pub use limit::*; pub use managed::*; pub use max_one_row::*; +pub use merge_sort::*; pub use merge_sort_exchange::*; pub use order_by::*; pub use project::*;