From 2cdbb117c10fab12890b006232690be771fbf6e4 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Thu, 25 Aug 2022 19:42:48 +0800 Subject: [PATCH] feat(stream,agg): support `array_agg` for streaming mode (#4895) * remove an outdated todo Signed-off-by: Richard Chien * support streaming array_agg Signed-off-by: Richard Chien * add unittest Signed-off-by: Richard Chien * considering NULL value Signed-off-by: Richard Chien * add e2e test for streaming array_agg Signed-off-by: Richard Chien Signed-off-by: Richard Chien Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- e2e_test/streaming/array_agg.slt | 57 ++ .../src/optimizer/plan_node/logical_agg.rs | 1 - .../managed_state/aggregation/array_agg.rs | 568 ++++++++++++++++++ .../executor/managed_state/aggregation/mod.rs | 16 +- 4 files changed, 635 insertions(+), 7 deletions(-) create mode 100644 e2e_test/streaming/array_agg.slt create mode 100644 src/stream/src/executor/managed_state/aggregation/array_agg.rs diff --git a/e2e_test/streaming/array_agg.slt b/e2e_test/streaming/array_agg.slt new file mode 100644 index 0000000000000..42c02a05c47ab --- /dev/null +++ b/e2e_test/streaming/array_agg.slt @@ -0,0 +1,57 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +create table t (a varchar, b int, c int); + +statement ok +insert into t values ('aaa', 1, 1), ('bbb', 0, 2), ('ccc', 0, 5), ('ddd', 1, 4); + +statement ok +create materialized view mv1 as select array_agg(c) as res from t; + +statement ok +create materialized view mv2 as select array_agg(a order by b asc, a desc) as res from t; + +statement ok +flush; + +query T +select u from (select unnest(res) from mv1) p(u) order by u; +---- +1 +2 +4 +5 + +query T +select * from mv2; +---- +{ccc,bbb,ddd,aaa} + +statement ok +insert into t values ('x', 1, 2), ('y', 3, 6); + +query T +select u from (select unnest(res) from mv1) p(u) order by u; +---- +1 +2 +2 +4 +5 +6 + +query T +select * from mv2; +---- +{ccc,bbb,x,ddd,aaa,y} + 
+statement ok +drop materialized view mv1; + +statement ok +drop materialized view mv2; + +statement ok +drop table t; diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 677c7f40cd1a9..4800112932566 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -451,7 +451,6 @@ impl LogicalAgg { table_catalogs.push(state_table); column_mappings_vec.push(column_mapping); } - // TODO: fill column mapping later (#3485). (table_catalogs, column_mappings_vec) } diff --git a/src/stream/src/executor/managed_state/aggregation/array_agg.rs b/src/stream/src/executor/managed_state/aggregation/array_agg.rs new file mode 100644 index 0000000000000..c8ea1779fe5e0 --- /dev/null +++ b/src/stream/src/executor/managed_state/aggregation/array_agg.rs @@ -0,0 +1,568 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::marker::PhantomData; +use std::sync::Arc; + +use async_trait::async_trait; +use futures::pin_mut; +use futures_async_stream::for_await; +use risingwave_common::array::stream_chunk::Ops; +use risingwave_common::array::Op::{Delete, Insert, UpdateDelete, UpdateInsert}; +use risingwave_common::array::{ArrayImpl, ListValue, Row}; +use risingwave_common::buffer::Bitmap; +use risingwave_common::types::Datum; +use risingwave_common::util::ordered::OrderedRow; +use risingwave_common::util::sort_util::OrderType; +use risingwave_storage::table::streaming_table::state_table::StateTable; +use risingwave_storage::StateStore; + +use super::{Cache, ManagedTableState}; +use crate::common::StateTableColumnMapping; +use crate::executor::aggregation::AggCall; +use crate::executor::error::StreamExecutorResult; +use crate::executor::managed_state::iter_state_table; +use crate::executor::PkIndices; + +pub struct ManagedArrayAggState { + _phantom_data: PhantomData, + + /// Group key to aggregate with group. + /// None for simple agg, Some for group key of hash agg. + group_key: Option, + + /// Contains the column mapping between upstream schema and state table. + state_table_col_mapping: Arc, + + /// The column to aggregate in state table. + state_table_agg_col_idx: usize, + + /// The columns to order by in state table. + state_table_order_col_indices: Vec, + + /// The order types of `state_table_order_col_indices`. + state_table_order_types: Vec, + + /// In-memory all-or-nothing cache. + cache: Cache, + + /// Whether the cache is fully synced to state table. 
+ cache_synced: bool, +} + +impl ManagedArrayAggState { + pub fn new( + agg_call: AggCall, + group_key: Option<&Row>, + pk_indices: PkIndices, + col_mapping: Arc, + row_count: usize, + ) -> Self { + // map agg column to state table column index + let state_table_agg_col_idx = col_mapping + .upstream_to_state_table(agg_call.args.val_indices()[0]) + .expect("the column to be aggregate must appear in the state table"); + // map order by columns to state table column indices + let (state_table_order_col_indices, state_table_order_types) = agg_call + .order_pairs + .iter() + .map(|o| { + ( + col_mapping + .upstream_to_state_table(o.column_idx) + .expect("the column to be order by must appear in the state table"), + o.order_type, + ) + }) + .chain(pk_indices.iter().map(|idx| { + ( + col_mapping + .upstream_to_state_table(*idx) + .expect("the pk columns must appear in the state table"), + OrderType::Ascending, + ) + })) + .unzip(); + Self { + _phantom_data: PhantomData, + group_key: group_key.cloned(), + state_table_col_mapping: col_mapping, + state_table_agg_col_idx, + state_table_order_col_indices, + state_table_order_types, + cache: Cache::new(usize::MAX), + cache_synced: row_count == 0, // if there is no row, the cache is synced initially + } + } + + fn state_row_to_cache_entry(&self, state_row: &Row) -> (OrderedRow, Datum) { + let cache_key = OrderedRow::new( + state_row.by_indices(&self.state_table_order_col_indices), + &self.state_table_order_types, + ); + let cache_data = state_row[self.state_table_agg_col_idx].clone(); + (cache_key, cache_data) + } + + fn apply_chunk_inner( + &mut self, + ops: Ops<'_>, + visibility: Option<&Bitmap>, + columns: &[&ArrayImpl], + state_table: &mut StateTable, + ) -> StreamExecutorResult<()> { + debug_assert!(super::verify_batch(ops, visibility, columns)); + + for (i, op) in ops.iter().enumerate() { + let visible = visibility.map(|x| x.is_set(i).unwrap()).unwrap_or(true); + if !visible { + continue; + } + + let state_row = Row::new( 
+ self.state_table_col_mapping + .upstream_columns() + .iter() + .map(|col_idx| columns[*col_idx].datum_at(i)) + .collect(), + ); + let (cache_key, cache_data) = self.state_row_to_cache_entry(&state_row); + + match op { + Insert | UpdateInsert => { + if self.cache_synced { + self.cache.insert(cache_key, cache_data); + } + state_table.insert(state_row)?; + } + Delete | UpdateDelete => { + if self.cache_synced { + self.cache.remove(cache_key); + } + state_table.delete(state_row)?; + } + } + } + + Ok(()) + } + + async fn get_output_inner( + &mut self, + epoch: u64, + state_table: &StateTable, + ) -> StreamExecutorResult { + if !self.cache_synced { + let all_data_iter = + iter_state_table(state_table, epoch, self.group_key.as_ref()).await?; + pin_mut!(all_data_iter); + + self.cache.clear(); + #[for_await] + for state_row in all_data_iter { + let state_row = state_row?; + let (cache_key, cache_data) = self.state_row_to_cache_entry(&state_row); + self.cache.insert(cache_key, cache_data.clone()); + } + self.cache_synced = true; + } + + let mut values = Vec::with_capacity(self.cache.len()); + for cache_data in self.cache.iter_values() { + values.push(cache_data.clone()); + } + Ok(Some(ListValue::new(values).into())) + } +} + +#[async_trait] +impl ManagedTableState for ManagedArrayAggState { + async fn apply_chunk( + &mut self, + ops: Ops<'_>, + visibility: Option<&Bitmap>, + columns: &[&ArrayImpl], // contains all upstream columns + _epoch: u64, + state_table: &mut StateTable, + ) -> StreamExecutorResult<()> { + self.apply_chunk_inner(ops, visibility, columns, state_table) + } + + async fn get_output( + &mut self, + epoch: u64, + state_table: &StateTable, + ) -> StreamExecutorResult { + self.get_output_inner(epoch, state_table).await + } + + fn is_dirty(&self) -> bool { + false + } + + fn flush(&mut self, _state_table: &mut StateTable) -> StreamExecutorResult<()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use itertools::Itertools; + use 
risingwave_common::array::{Row, StreamChunk, StreamChunkTestExt}; + use risingwave_common::catalog::{ColumnDesc, ColumnId, TableId}; + use risingwave_common::types::{DataType, ScalarImpl}; + use risingwave_common::util::sort_util::{OrderPair, OrderType}; + use risingwave_expr::expr::AggKind; + use risingwave_storage::memory::MemoryStateStore; + use risingwave_storage::table::streaming_table::state_table::StateTable; + + use super::ManagedArrayAggState; + use crate::common::StateTableColumnMapping; + use crate::executor::aggregation::{AggArgs, AggCall}; + use crate::executor::managed_state::aggregation::ManagedTableState; + use crate::executor::StreamExecutorResult; + + #[tokio::test] + async fn test_array_agg_state_simple_agg_without_order() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: varchar, b: int32, c: int32, _row_id: int64) + + let input_pk_indices = vec![3]; + let agg_call = AggCall { + kind: AggKind::ArrayAgg, + args: AggArgs::Unary(DataType::Varchar, 0), // array_agg(a) + return_type: DataType::List { + datatype: Box::new(DataType::Varchar), + }, + order_pairs: vec![], + append_only: false, + filter: None, + }; + + // see `LogicalAgg::infer_internal_table_catalog` for the construction of state table + let table_id = TableId::new(6666); + let columns = vec![ + ColumnDesc::unnamed(ColumnId::new(0), DataType::Int64), // _row_id + ColumnDesc::unnamed(ColumnId::new(1), DataType::Varchar), // a + ]; + let state_table_col_mapping = Arc::new(StateTableColumnMapping::new(vec![3, 0])); + let mut state_table = StateTable::new_without_distribution( + MemoryStateStore::new(), + table_id, + columns, + vec![OrderType::Ascending], + vec![0], // [_row_id] + ); + + let mut agg_state = + ManagedArrayAggState::new(agg_call, None, input_pk_indices, state_table_col_mapping, 0); + + let mut epoch = 0; + + let chunk = StreamChunk::from_pretty( + " T i i I + + a 1 8 123 + + b 5 2 128 + - b 5 2 128 + + . 
7 6 129 + + c 1 3 130", + ); + let (ops, columns, visibility) = chunk.into_inner(); + let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + agg_state + .apply_chunk( + &ops, + visibility.as_ref(), + &column_refs, + epoch, + &mut state_table, + ) + .await?; + + epoch += 1; + agg_state.flush(&mut state_table)?; + state_table.commit(epoch).await.unwrap(); + + let res = agg_state.get_output(epoch, &state_table).await?; + match res { + Some(ScalarImpl::List(res)) => { + let res = res + .values() + .iter() + .map(|v| v.as_ref().map(ScalarImpl::as_utf8).cloned()) + .collect_vec(); + assert_eq!(res.len(), 3); + assert!(res.contains(&Some("a".to_string()))); + assert!(res.contains(&Some("c".to_string()))); + assert!(res.contains(&None)); + } + _ => panic!("unexpected output"), + } + + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_state_simple_agg_with_order() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: varchar, b: int32, c: int32, _row_id: int64) + // where `a` is the column to aggregate + + let input_pk_indices = vec![3]; + let agg_call = AggCall { + kind: AggKind::ArrayAgg, + args: AggArgs::Unary(DataType::Int32, 1), // array_agg(b) + return_type: DataType::Int32, + order_pairs: vec![ + OrderPair::new(2, OrderType::Ascending), // c ASC + OrderPair::new(0, OrderType::Descending), // a DESC + ], + append_only: false, + filter: None, + }; + + let table_id = TableId::new(6666); + let columns = vec![ + ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // c + ColumnDesc::unnamed(ColumnId::new(1), DataType::Varchar), // a + ColumnDesc::unnamed(ColumnId::new(2), DataType::Int64), // _row_id + ColumnDesc::unnamed(ColumnId::new(3), DataType::Int32), // b + ]; + let state_table_col_mapping = Arc::new(StateTableColumnMapping::new(vec![2, 0, 3, 1])); + let mut state_table = StateTable::new_without_distribution( + MemoryStateStore::new(), + table_id, + columns, + vec![ + OrderType::Ascending, // c ASC + 
OrderType::Descending, // a DESC + OrderType::Ascending, // _row_id ASC + ], + vec![0, 1, 2], // [c, a, _row_id] + ); + + let mut agg_state = + ManagedArrayAggState::new(agg_call, None, input_pk_indices, state_table_col_mapping, 0); + + let mut epoch = 0; + + { + let chunk = StreamChunk::from_pretty( + " T i i I + + a 1 8 123 + + b 5 2 128 + - b 5 2 128 + + c 2 3 130", + ); + let (ops, columns, visibility) = chunk.into_inner(); + let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + agg_state + .apply_chunk( + &ops, + visibility.as_ref(), + &column_refs, + epoch, + &mut state_table, + ) + .await?; + + agg_state.flush(&mut state_table)?; + state_table.commit(epoch).await.unwrap(); + epoch += 1; + + let res = agg_state.get_output(epoch, &state_table).await?; + match res { + Some(ScalarImpl::List(res)) => { + let res = res + .values() + .iter() + .map(|v| v.as_ref().map(ScalarImpl::as_int32).cloned()) + .collect_vec(); + assert_eq!(res, vec![Some(2), Some(1)]); + } + _ => panic!("unexpected output"), + } + } + + { + let chunk = StreamChunk::from_pretty( + " T i i I + + d 0 8 134 + + e 2 2 137", + ); + let (ops, columns, visibility) = chunk.into_inner(); + let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + agg_state + .apply_chunk( + &ops, + visibility.as_ref(), + &column_refs, + epoch, + &mut state_table, + ) + .await?; + + agg_state.flush(&mut state_table)?; + state_table.commit(epoch).await.unwrap(); + epoch += 1; + + let res = agg_state.get_output(epoch, &state_table).await?; + match res { + Some(ScalarImpl::List(res)) => { + let res = res + .values() + .iter() + .map(|v| v.as_ref().map(ScalarImpl::as_int32).cloned()) + .collect_vec(); + assert_eq!(res, vec![Some(2), Some(2), Some(0), Some(1)]); + } + _ => panic!("unexpected output"), + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_state_grouped_agg_with_order() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: 
varchar, b: int32, c: int32, _row_id: int64) + + let input_pk_indices = vec![3]; + let agg_call = AggCall { + kind: AggKind::ArrayAgg, + args: AggArgs::Unary(DataType::Varchar, 0), + return_type: DataType::Varchar, + order_pairs: vec![ + OrderPair::new(1, OrderType::Ascending), // b ASC + ], + append_only: false, + filter: None, + }; + + let table_id = TableId::new(6666); + let columns = vec![ + ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // group by c + ColumnDesc::unnamed(ColumnId::new(1), DataType::Int32), // order by b + ColumnDesc::unnamed(ColumnId::new(2), DataType::Int64), // _row_id + ColumnDesc::unnamed(ColumnId::new(3), DataType::Varchar), // a + ]; + let state_table_col_mapping = Arc::new(StateTableColumnMapping::new(vec![2, 1, 3, 0])); + let mut state_table = StateTable::new_without_distribution( + MemoryStateStore::new(), + table_id, + columns, + vec![ + OrderType::Ascending, // c ASC + OrderType::Ascending, // b ASC + OrderType::Ascending, // _row_id ASC + ], + vec![0, 1, 2], // [c, b, _row_id] + ); + + let mut agg_state = ManagedArrayAggState::new( + agg_call, + Some(&Row::new(vec![Some(8.into())])), + input_pk_indices, + state_table_col_mapping, + 0, + ); + + let mut epoch = 0; + + { + let chunk = StreamChunk::from_pretty( + " T i i I + + a 1 8 123 + + b 5 8 128 + + c 1 3 130 D // hide this row", + ); + let (ops, columns, visibility) = chunk.into_inner(); + let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + agg_state + .apply_chunk( + &ops, + visibility.as_ref(), + &column_refs, + epoch, + &mut state_table, + ) + .await?; + + agg_state.flush(&mut state_table)?; + state_table.commit(epoch).await.unwrap(); + epoch += 1; + + let res = agg_state.get_output(epoch, &state_table).await?; + match res { + Some(ScalarImpl::List(res)) => { + let res = res + .values() + .iter() + .map(|v| v.as_ref().map(ScalarImpl::as_utf8).cloned()) + .collect_vec(); + assert_eq!(res, vec![Some("a".to_string()), 
Some("b".to_string())]); + } + _ => panic!("unexpected output"), + } + } + + { + let chunk = StreamChunk::from_pretty( + " T i i I + + d 0 2 134 D // hide this row + + e 2 8 137", + ); + let (ops, columns, visibility) = chunk.into_inner(); + let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + agg_state + .apply_chunk( + &ops, + visibility.as_ref(), + &column_refs, + epoch, + &mut state_table, + ) + .await?; + + agg_state.flush(&mut state_table)?; + state_table.commit(epoch).await.unwrap(); + epoch += 1; + + let res = agg_state.get_output(epoch, &state_table).await?; + match res { + Some(ScalarImpl::List(res)) => { + let res = res + .values() + .iter() + .map(|v| v.as_ref().map(ScalarImpl::as_utf8).cloned()) + .collect_vec(); + assert_eq!( + res, + vec![ + Some("a".to_string()), + Some("e".to_string()), + Some("b".to_string()) + ] + ); + } + _ => panic!("unexpected output"), + } + } + + Ok(()) + } +} diff --git a/src/stream/src/executor/managed_state/aggregation/mod.rs b/src/stream/src/executor/managed_state/aggregation/mod.rs index 6e6b32a51bf17..672857a69bb65 100644 --- a/src/stream/src/executor/managed_state/aggregation/mod.rs +++ b/src/stream/src/executor/managed_state/aggregation/mod.rs @@ -30,7 +30,8 @@ pub use value::*; use crate::common::StateTableColumnMapping; use crate::executor::aggregation::AggCall; -use crate::executor::error::{StreamExecutorError, StreamExecutorResult}; +use crate::executor::error::StreamExecutorResult; +use crate::executor::managed_state::aggregation::array_agg::ManagedArrayAggState; use crate::executor::managed_state::aggregation::string_agg::ManagedStringAggState; use crate::executor::PkIndices; @@ -38,8 +39,8 @@ use crate::executor::PkIndices; // TODO: estimate a good cache size instead of hard-coding const EXTREME_CACHE_SIZE: usize = 1024; +mod array_agg; mod extreme; - mod string_agg; mod value; @@ -227,10 +228,13 @@ impl ManagedStateImpl { state_table_col_mapping, row_count.unwrap(), )))), - 
AggKind::ArrayAgg => Err(StreamExecutorError::not_implemented( - "ArrayAgg is not implemented yet", - 4657, - )), + AggKind::ArrayAgg => Ok(Self::Table(Box::new(ManagedArrayAggState::new( + agg_call, + group_key, + pk_indices, + state_table_col_mapping, + row_count.unwrap(), + )))), } } }