Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(batch,agg): support array_agg for batch mode #4862

Merged
merged 10 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions e2e_test/batch/aggregate/array_agg.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
create table t(v1 varchar, v2 int, v3 int)

statement ok
insert into t values ('aaa', 1, 1), ('bbb', 0, 2), ('ccc', 0, 5), ('ddd', 1, 4)

query T
select b from (select unnest(a) from (select array_agg(v3) as v3_arr from t) g(a)) p(b) order by b;
----
1
2
4
5

query T
select array_agg(v1 order by v3 desc) from t
----
{ccc,ddd,bbb,aaa}

query T
select array_agg(v1 order by v2 asc, v3 desc) from t
----
{ccc,bbb,ddd,aaa}

statement ok
drop table t
5 changes: 5 additions & 0 deletions src/expr/src/vector_op/agg/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use risingwave_common::util::sort_util::{OrderPair, OrderType};
use risingwave_pb::expr::AggCall;
use risingwave_pb::plan_common::OrderType as ProstOrderType;

use super::array_agg::create_array_agg_state;
use super::string_agg::StringAgg;
use crate::expr::{
build_from_prost, AggKind, Expression, ExpressionRef, InputRefExpression, LiteralExpression,
Expand Down Expand Up @@ -123,6 +124,10 @@ impl AggStateFactory {
order_col_types,
))
}
(AggKind::ArrayAgg, [arg]) => {
let agg_col_idx = arg.get_input()?.get_column_idx() as usize;
create_array_agg_state(return_type.clone(), agg_col_idx, order_pairs)?
}
(agg_kind, [arg]) => {
// other unary agg call
let input_type = DataType::from(arg.get_type()?);
Expand Down
264 changes: 264 additions & 0 deletions src/expr/src/vector_op/agg/array_agg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
// Copyright 2022 Singularity Data
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::array::{ArrayBuilder, ArrayBuilderImpl, DataChunk, ListValue, RowRef};
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::{DataType, Datum, Scalar};
use risingwave_common::util::ordered::OrderedRow;
use risingwave_common::util::sort_util::{OrderPair, OrderType};

use crate::vector_op::agg::aggregator::Aggregator;

#[derive(Clone)]
struct ArrayAggUnordered {
return_type: DataType,
agg_col_idx: usize,
values: Vec<Datum>,
}

impl ArrayAggUnordered {
fn new(return_type: DataType, agg_col_idx: usize) -> Self {
debug_assert!(matches!(return_type, DataType::List { datatype: _ }));
ArrayAggUnordered {
return_type,
agg_col_idx,
values: vec![],
}
}

fn push(&mut self, datum: Datum) {
self.values.push(datum);
}

fn get_result_and_reset(&mut self) -> ListValue {
ListValue::new(std::mem::take(&mut self.values))
}
}

impl Aggregator for ArrayAggUnordered {
fn return_type(&self) -> DataType {
self.return_type.clone()
}

fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> {
let array = input.column_at(self.agg_col_idx).array_ref();
self.push(array.datum_at(row_id));
Ok(())
}

fn update_multi(
&mut self,
input: &DataChunk,
start_row_id: usize,
end_row_id: usize,
) -> Result<()> {
self.values.reserve(end_row_id - start_row_id);
for row_id in start_row_id..end_row_id {
stdrc marked this conversation as resolved.
Show resolved Hide resolved
self.update_single(input, row_id)?;
}
Ok(())
}

fn output(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. Are we still using RwResult for aggregators? 🥵

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened a PR for this: #4873

if let ArrayBuilderImpl::List(builder) = builder {
builder
.append(Some(self.get_result_and_reset().as_scalar_ref()))
.map_err(Into::into)
} else {
Err(
ErrorCode::InternalError(format!("Builder fail to match {}.", stringify!(Utf8)))
.into(),
)
}
}
}

#[derive(Clone)]
struct ArrayAggOrdered {
return_type: DataType,
agg_col_idx: usize,
order_col_indices: Vec<usize>,
order_types: Vec<OrderType>,
unordered_values: Vec<(OrderedRow, Datum)>,
}

impl ArrayAggOrdered {
fn new(return_type: DataType, agg_col_idx: usize, order_pairs: Vec<OrderPair>) -> Self {
debug_assert!(matches!(return_type, DataType::List { datatype: _ }));
let (order_col_indices, order_types) = order_pairs
.into_iter()
.map(|p| (p.column_idx, p.order_type))
.unzip();
ArrayAggOrdered {
return_type,
agg_col_idx,
order_col_indices,
order_types,
unordered_values: vec![],
}
}

fn push_row(&mut self, row: RowRef) {
let key = OrderedRow::new(
row.row_by_indices(&self.order_col_indices),
&self.order_types,
);
let datum = row.value_at(self.agg_col_idx).map(|x| x.into_scalar_impl());
self.unordered_values.push((key, datum));
}

fn get_result_and_reset(&mut self) -> ListValue {
let mut rows = std::mem::take(&mut self.unordered_values);
rows.sort_unstable_by(|a, b| a.0.cmp(&b.0));
ListValue::new(rows.into_iter().map(|(_, datum)| datum).collect())
}
}

impl Aggregator for ArrayAggOrdered {
fn return_type(&self) -> DataType {
self.return_type.clone()
}

fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> {
let (row, vis) = input.row_at(row_id)?;
assert!(vis);
self.push_row(row);
Ok(())
}

fn update_multi(
&mut self,
input: &DataChunk,
start_row_id: usize,
end_row_id: usize,
) -> Result<()> {
self.unordered_values.reserve(end_row_id - start_row_id);
for row_id in start_row_id..end_row_id {
self.update_single(input, row_id)?;
}
Ok(())
}

fn output(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> {
if let ArrayBuilderImpl::List(builder) = builder {
builder
.append(Some(self.get_result_and_reset().as_scalar_ref()))
.map_err(Into::into)
} else {
Err(
ErrorCode::InternalError(format!("Builder fail to match {}.", stringify!(Utf8)))
.into(),
)
}
}
}

pub fn create_array_agg_state(
return_type: DataType,
agg_col_idx: usize,
order_pairs: Vec<OrderPair>,
) -> Result<Box<dyn Aggregator>> {
if order_pairs.is_empty() {
Ok(Box::new(ArrayAggUnordered::new(return_type, agg_col_idx)))
} else {
Ok(Box::new(ArrayAggOrdered::new(
return_type,
agg_col_idx,
order_pairs,
)))
}
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
use risingwave_common::array::Array;
use risingwave_common::test_prelude::DataChunkTestExt;
use risingwave_common::types::ScalarRef;

use super::*;

#[test]
fn test_array_agg_basic() -> Result<()> {
let chunk = DataChunk::from_pretty(
"i
123
456
789",
);
let return_type = DataType::List {
datatype: Box::new(DataType::Int32),
};
let mut agg = create_array_agg_state(return_type.clone(), 0, vec![])?;
let mut builder = return_type.create_array_builder(0);
agg.update_multi(&chunk, 0, chunk.cardinality())?;
agg.output(&mut builder)?;
let output = builder.finish()?;
let actual = output.into_list();
let actual = actual
.iter()
.map(|v| v.map(|s| s.to_owned_scalar()))
.collect_vec();
assert_eq!(
actual,
vec![Some(ListValue::new(vec![
Some(123.into()),
Some(456.into()),
Some(789.into())
]))]
);
Ok(())
}

#[test]
fn test_array_agg_with_order() -> Result<()> {
let chunk = DataChunk::from_pretty(
"i i
123 3
456 2
789 2
321 9",
);
let return_type = DataType::List {
datatype: Box::new(DataType::Int32),
};
let mut agg = create_array_agg_state(
return_type.clone(),
0,
vec![
OrderPair::new(1, OrderType::Ascending),
OrderPair::new(0, OrderType::Descending),
],
)?;
let mut builder = return_type.create_array_builder(0);
agg.update_multi(&chunk, 0, chunk.cardinality())?;
agg.output(&mut builder)?;
let output = builder.finish()?;
let actual = output.into_list();
let actual = actual
.iter()
.map(|v| v.map(|s| s.to_owned_scalar()))
.collect_vec();
assert_eq!(
actual,
vec![Some(ListValue::new(vec![
Some(789.into()),
Some(456.into()),
Some(123.into()),
Some(321.into())
]))]
);
Ok(())
}
}
1 change: 1 addition & 0 deletions src/expr/src/vector_op/agg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

mod aggregator;
mod approx_count_distinct;
mod array_agg;
mod count_star;
mod functions;
mod general_agg;
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/optimizer/plan_node/logical_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ impl LogicalAgg {
pub(crate) fn is_agg_result_affected_by_order(&self) -> bool {
self.agg_calls
.iter()
.any(|call| matches!(call.agg_kind, AggKind::StringAgg))
.any(|call| matches!(call.agg_kind, AggKind::StringAgg | AggKind::ArrayAgg))
}
}

Expand Down