From 80f1d58b866b022fd2d5aac92dd2e5bae2c6ceb7 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 11 Oct 2023 16:27:26 +0800 Subject: [PATCH] feat(expr): support streaming `bit_and` and `bit_or` aggregate (#12758) Signed-off-by: Runji Wang --- src/common/src/array/list_array.rs | 30 ++- src/expr/core/src/aggregate/def.rs | 15 +- src/expr/impl/src/aggregate/bit_and.rs | 191 ++++++++++++++++++ src/expr/impl/src/aggregate/bit_or.rs | 149 ++++++++++++++ src/expr/impl/src/aggregate/bit_xor.rs | 52 +++++ src/expr/impl/src/aggregate/general.rs | 27 --- src/expr/impl/src/aggregate/mod.rs | 3 + src/expr/impl/src/scalar/array_access.rs | 2 +- .../src/table_function/generate_subscripts.rs | 2 +- src/expr/macro/src/gen.rs | 19 +- src/expr/macro/src/lib.rs | 3 + src/expr/macro/src/parse.rs | 12 +- 12 files changed, 455 insertions(+), 50 deletions(-) create mode 100644 src/expr/impl/src/aggregate/bit_and.rs create mode 100644 src/expr/impl/src/aggregate/bit_or.rs create mode 100644 src/expr/impl/src/aggregate/bit_xor.rs diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index 6445ac8a156d3..7eaaffff98534 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use core::fmt; use std::cmp::Ordering; +use std::fmt; use std::fmt::Debug; use std::future::Future; use std::hash::Hash; use std::mem::size_of; +use std::ops::{Index, IndexMut}; use bytes::{Buf, BufMut}; use either::Either; @@ -359,6 +360,20 @@ impl Ord for ListValue { } } +impl Index for ListValue { + type Output = Datum; + + fn index(&self, index: usize) -> &Self::Output { + &self.values[index] + } +} + +impl IndexMut for ListValue { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.values[index] + } +} + // Used to display ListValue in explain for better readibilty. pub fn display_for_explain(list: &ListValue) -> String { // Example of ListValue display: ARRAY[1, 2, null] @@ -485,7 +500,7 @@ impl<'a> ListRef<'a> { } /// Get the element at the given index. Returns `None` if the index is out of bounds. - pub fn elem_at(self, index: usize) -> Option> { + pub fn get(self, index: usize) -> Option> { iter_elems_ref!(self, it, { let mut it = it; it.nth(index) @@ -551,12 +566,9 @@ impl Ord for ListRef<'_> { impl Debug for ListRef<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - iter_elems_ref!(*self, it, { - for v in it { - Debug::fmt(&v, f)?; - } - Ok(()) - }) + let mut f = f.debug_list(); + iter_elems_ref!(*self, it, { f.entries(it) }); + f.finish() } } @@ -1020,7 +1032,7 @@ mod tests { ); // Get 2nd value from ListRef - let scalar = list_ref.elem_at(1).unwrap(); + let scalar = list_ref.get(1).unwrap(); assert_eq!(scalar, Some(types::ScalarRefImpl::Int32(5))); } } diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index 2d6763130cc4a..39d4c158c10d7 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -308,9 +308,7 @@ pub mod agg_kinds { #[macro_export] macro_rules! unimplemented_in_stream { () => { - AggKind::BitAnd - | AggKind::BitOr - | AggKind::JsonbAgg + AggKind::JsonbAgg | AggKind::JsonbObjectAgg | AggKind::PercentileCont | AggKind::PercentileDisc @@ -408,6 +406,8 @@ pub mod agg_kinds { // after we support general merge in stateless_simple_agg | AggKind::BoolAnd | AggKind::BoolOr + | AggKind::BitAnd + | AggKind::BitOr }; } pub use simply_cannot_two_phase; @@ -420,6 +420,8 @@ pub mod agg_kinds { AggKind::Sum | AggKind::Sum0 | AggKind::Count + | AggKind::BitAnd + | AggKind::BitOr | AggKind::BitXor | AggKind::BoolAnd | AggKind::BoolOr @@ -452,12 +454,7 @@ impl AggKind { /// Get the total phase agg kind from the partial phase agg kind. pub fn partial_to_total(self) -> Option { match self { - AggKind::BitAnd - | AggKind::BitOr - | AggKind::BitXor - | AggKind::Min - | AggKind::Max - | AggKind::Sum => Some(self), + AggKind::BitXor | AggKind::Min | AggKind::Max | AggKind::Sum => Some(self), AggKind::Sum0 | AggKind::Count => Some(AggKind::Sum0), agg_kinds::simply_cannot_two_phase!() => None, agg_kinds::rewritten!() => None, diff --git a/src/expr/impl/src/aggregate/bit_and.rs b/src/expr/impl/src/aggregate/bit_and.rs new file mode 100644 index 0000000000000..879f81704b14a --- /dev/null +++ b/src/expr/impl/src/aggregate/bit_and.rs @@ -0,0 +1,191 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::marker::PhantomData; +use std::ops::BitAnd; + +use risingwave_common::types::{ListRef, ListValue, ScalarImpl}; +use risingwave_expr::aggregate; + +/// Computes the bitwise AND of all non-null input values. +/// +/// # Example +/// +/// ```slt +/// statement ok +/// create table t (a int2, b int4, c int8); +/// +/// query III +/// select bit_and(a), bit_and(b), bit_and(c) from t; +/// ---- +/// NULL NULL NULL +/// +/// statement ok +/// insert into t values +/// (6, 6, 6), +/// (3, 3, 3), +/// (null, null, null); +/// +/// query III +/// select bit_and(a), bit_and(b), bit_and(c) from t; +/// ---- +/// 2 2 2 +/// +/// statement ok +/// drop table t; +/// ``` +// XXX: state = "ref" is required so that +// for the first non-null value, the state is set to that value. +#[aggregate("bit_and(*int) -> auto", state = "ref")] +fn bit_and_append_only(state: T, input: T) -> T +where + T: BitAnd, +{ + state.bitand(input) +} + +/// Computes the bitwise AND of all non-null input values. +/// +/// # Example +/// +/// ```slt +/// statement ok +/// create table t (a int2, b int4, c int8); +/// +/// statement ok +/// create materialized view mv as +/// select bit_and(a) a, bit_and(b) b, bit_and(c) c from t; +/// +/// query III +/// select * from mv; +/// ---- +/// NULL NULL NULL +/// +/// statement ok +/// insert into t values +/// (6, 6, 6), +/// (3, 3, 3), +/// (null, null, null); +/// +/// query III +/// select * from mv; +/// ---- +/// 2 2 2 +/// +/// statement ok +/// delete from t where a = 3; +/// +/// query III +/// select * from mv; +/// ---- +/// 6 6 6 +/// +/// statement ok +/// drop materialized view mv; +/// +/// statement ok +/// drop table t; +/// ``` +#[derive(Debug, Default, Clone)] +struct BitAndUpdatable { + _phantom: PhantomData, +} + +#[aggregate("bit_and(int2) -> int2", state = "int8[]", generic = "i16")] +#[aggregate("bit_and(int4) -> int4", state = "int8[]", generic = "i32")] +#[aggregate("bit_and(int8) -> int8", state = "int8[]", generic = "i64")] +impl BitAndUpdatable { + // state is the number of 0s for each bit. + + fn create_state(&self) -> ListValue { + ListValue::new(vec![Some(ScalarImpl::Int64(0)); T::BITS]) + } + + fn accumulate(&self, mut state: ListValue, input: T) -> ListValue { + for i in 0..T::BITS { + if !input.get_bit(i) { + let Some(ScalarImpl::Int64(count)) = &mut state[i] else { + panic!("invalid state"); + }; + *count += 1; + } + } + state + } + + fn retract(&self, mut state: ListValue, input: T) -> ListValue { + for i in 0..T::BITS { + if !input.get_bit(i) { + let Some(ScalarImpl::Int64(count)) = &mut state[i] else { + panic!("invalid state"); + }; + *count -= 1; + } + } + state + } + + fn finalize(&self, state: ListRef<'_>) -> T { + let mut result = T::default(); + for i in 0..T::BITS { + let count = state.get(i).unwrap().unwrap().into_int64(); + if count == 0 { + result.set_bit(i); + } + } + result + } +} + +pub trait Bits: Default { + const BITS: usize; + fn get_bit(&self, i: usize) -> bool; + fn set_bit(&mut self, i: usize); +} + +impl Bits for i16 { + const BITS: usize = 16; + + fn get_bit(&self, i: usize) -> bool { + (*self >> i) & 1 == 1 + } + + fn set_bit(&mut self, i: usize) { + *self |= 1 << i; + } +} + +impl Bits for i32 { + const BITS: usize = 32; + + fn get_bit(&self, i: usize) -> bool { + (*self >> i) & 1 == 1 + } + + fn set_bit(&mut self, i: usize) { + *self |= 1 << i; + } +} + +impl Bits for i64 { + const BITS: usize = 64; + + fn get_bit(&self, i: usize) -> bool { + (*self >> i) & 1 == 1 + } + + fn set_bit(&mut self, i: usize) { + *self |= 1 << i; + } +} diff --git a/src/expr/impl/src/aggregate/bit_or.rs b/src/expr/impl/src/aggregate/bit_or.rs new file mode 100644 index 0000000000000..1bf205f335e8b --- /dev/null +++ b/src/expr/impl/src/aggregate/bit_or.rs @@ -0,0 +1,149 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::marker::PhantomData; +use std::ops::BitOr; + +use risingwave_common::types::{ListRef, ListValue, ScalarImpl}; +use risingwave_expr::aggregate; + +use super::bit_and::Bits; + +/// Computes the bitwise OR of all non-null input values. +/// +/// # Example +/// +/// ```slt +/// statement ok +/// create table t (a int2, b int4, c int8); +/// +/// query III +/// select bit_or(a), bit_or(b), bit_or(c) from t; +/// ---- +/// NULL NULL NULL +/// +/// statement ok +/// insert into t values +/// (1, 1, 1), +/// (2, 2, 2), +/// (null, null, null); +/// +/// query III +/// select bit_or(a), bit_or(b), bit_or(c) from t; +/// ---- +/// 3 3 3 +/// +/// statement ok +/// drop table t; +/// ``` +#[aggregate("bit_or(*int) -> auto")] +fn bit_or_append_only(state: T, input: T) -> T +where + T: BitOr, +{ + state.bitor(input) +} + +/// Computes the bitwise OR of all non-null input values. +/// +/// # Example +/// +/// ```slt +/// statement ok +/// create table t (a int2, b int4, c int8); +/// +/// statement ok +/// create materialized view mv as +/// select bit_or(a) a, bit_or(b) b, bit_or(c) c from t; +/// +/// query III +/// select * from mv; +/// ---- +/// NULL NULL NULL +/// +/// statement ok +/// insert into t values +/// (6, 6, 6), +/// (3, 3, 3), +/// (null, null, null); +/// +/// query III +/// select * from mv; +/// ---- +/// 7 7 7 +/// +/// statement ok +/// delete from t where a = 3; +/// +/// query III +/// select * from mv; +/// ---- +/// 6 6 6 +/// +/// statement ok +/// drop materialized view mv; +/// +/// statement ok +/// drop table t; +/// ``` +#[derive(Debug, Default, Clone)] +struct BitOrUpdatable { + _phantom: PhantomData, +} + +#[aggregate("bit_or(int2) -> int2", state = "int8[]", generic = "i16")] +#[aggregate("bit_or(int4) -> int4", state = "int8[]", generic = "i32")] +#[aggregate("bit_or(int8) -> int8", state = "int8[]", generic = "i64")] +impl BitOrUpdatable { + // state is the number of 1s for each bit. + + fn create_state(&self) -> ListValue { + ListValue::new(vec![Some(ScalarImpl::Int64(0)); T::BITS]) + } + + fn accumulate(&self, mut state: ListValue, input: T) -> ListValue { + for i in 0..T::BITS { + if input.get_bit(i) { + let Some(ScalarImpl::Int64(count)) = &mut state[i] else { + panic!("invalid state"); + }; + *count += 1; + } + } + state + } + + fn retract(&self, mut state: ListValue, input: T) -> ListValue { + for i in 0..T::BITS { + if input.get_bit(i) { + let Some(ScalarImpl::Int64(count)) = &mut state[i] else { + panic!("invalid state"); + }; + *count -= 1; + } + } + state + } + + fn finalize(&self, state: ListRef<'_>) -> T { + let mut result = T::default(); + for i in 0..T::BITS { + let count = state.get(i).unwrap().unwrap().into_int64(); + if count != 0 { + result.set_bit(i); + } + } + result + } +} diff --git a/src/expr/impl/src/aggregate/bit_xor.rs b/src/expr/impl/src/aggregate/bit_xor.rs new file mode 100644 index 0000000000000..d86098105daf7 --- /dev/null +++ b/src/expr/impl/src/aggregate/bit_xor.rs @@ -0,0 +1,52 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::BitXor; + +use risingwave_expr::aggregate; + +/// Computes the bitwise XOR of all non-null input values. +/// +/// # Example +/// +/// ```slt +/// statement ok +/// create table t (a int2, b int4, c int8); +/// +/// query III +/// select bit_xor(a), bit_xor(b), bit_xor(c) from t; +/// ---- +/// NULL NULL NULL +/// +/// statement ok +/// insert into t values +/// (3, 3, 3), +/// (6, 6, 6), +/// (null, null, null); +/// +/// query III +/// select bit_xor(a), bit_xor(b), bit_xor(c) from t; +/// ---- +/// 5 5 5 +/// +/// statement ok +/// drop table t; +/// ``` +#[aggregate("bit_xor(*int) -> auto")] +fn bit_xor(state: T, input: T, _retract: bool) -> T +where + T: BitXor, +{ + state.bitxor(input) +} diff --git a/src/expr/impl/src/aggregate/general.rs b/src/expr/impl/src/aggregate/general.rs index dfdf1967d554c..de1331c524063 100644 --- a/src/expr/impl/src/aggregate/general.rs +++ b/src/expr/impl/src/aggregate/general.rs @@ -13,7 +13,6 @@ // limitations under the License. use std::convert::From; -use std::ops::{BitAnd, BitOr, BitXor}; use num_traits::{CheckedAdd, CheckedSub}; use risingwave_expr::{aggregate, ExprError, Result}; @@ -53,32 +52,6 @@ fn max(state: T, input: T) -> T { state.max(input) } -// XXX: state = "ref" is required so that -// for the first non-null value, the state is set to that value. -#[aggregate("bit_and(*int) -> auto", state = "ref")] -fn bit_and(state: T, input: T) -> T -where - T: BitAnd, -{ - state.bitand(input) -} - -#[aggregate("bit_or(*int) -> auto")] -fn bit_or(state: T, input: T) -> T -where - T: BitOr, -{ - state.bitor(input) -} - -#[aggregate("bit_xor(*int) -> auto")] -fn bit_xor(state: T, input: T, _retract: bool) -> T -where - T: BitXor, -{ - state.bitxor(input) -} - #[aggregate("first_value(*) -> auto", state = "ref")] fn first_value(state: T, _: T) -> T { state diff --git a/src/expr/impl/src/aggregate/mod.rs b/src/expr/impl/src/aggregate/mod.rs index 22cfa7ce8b588..d1373acae31b2 100644 --- a/src/expr/impl/src/aggregate/mod.rs +++ b/src/expr/impl/src/aggregate/mod.rs @@ -14,6 +14,9 @@ mod approx_count_distinct; mod array_agg; +mod bit_and; +mod bit_or; +mod bit_xor; mod bool_and; mod bool_or; mod general; diff --git a/src/expr/impl/src/scalar/array_access.rs b/src/expr/impl/src/scalar/array_access.rs index 929eb19b9318b..2ac39be99ac8c 100644 --- a/src/expr/impl/src/scalar/array_access.rs +++ b/src/expr/impl/src/scalar/array_access.rs @@ -23,7 +23,7 @@ fn array_access(list: ListRef<'_>, index: i32) -> Option> { return None; } // returns `NULL` if index is out of bounds - list.elem_at(index as usize - 1).flatten() + list.get(index as usize - 1).flatten() } #[cfg(test)] diff --git a/src/expr/impl/src/table_function/generate_subscripts.rs b/src/expr/impl/src/table_function/generate_subscripts.rs index 53123489d7976..c3ecd6afabf5b 100644 --- a/src/expr/impl/src/table_function/generate_subscripts.rs +++ b/src/expr/impl/src/table_function/generate_subscripts.rs @@ -130,7 +130,7 @@ fn generate_subscripts_inner(array: ListRef<'_>, dim: i32) -> (i32, i32) { ..=0 => nothing, 1 => (1, array.len() as i32 + 1), // Although RW's array can be zig-zag, we just look at the first element. - 2.. => match array.elem_at(0) { + 2.. => match array.get(0) { Some(Some(ScalarRefImpl::List(list))) => generate_subscripts_inner(list, dim - 1), _ => nothing, }, diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 083f184add5e7..9155853df5b7b 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -675,10 +675,17 @@ impl FunctionAttr { } 1 => { let first_state = if self.init_state.is_some() { + // for count, the state will never be None quote! { unreachable!() } } else if let Some(s) = &self.state && s == "ref" { // for min/max/first/last, the state is the first value quote! { Some(v0) } + } else if let AggregateFnOrImpl::Impl(impl_) = user_fn && impl_.create_state.is_some() { + // use user-defined create_state function + quote! {{ + let state = self.function.create_state(); + #next_state + }} } else { quote! {{ let state = #state_type::default(); @@ -712,14 +719,22 @@ impl FunctionAttr { AggregateFnOrImpl::Fn(_) => quote! {}, AggregateFnOrImpl::Impl(i) => { let struct_name = format_ident!("{}", i.struct_name); - quote! { function: #struct_name, } + let generic = self.generic.as_ref().map(|g| { + let g = format_ident!("{g}"); + quote! { <#g> } + }); + quote! { function: #struct_name #generic, } } }; let function_new = match user_fn { AggregateFnOrImpl::Fn(_) => quote! {}, AggregateFnOrImpl::Impl(i) => { let struct_name = format_ident!("{}", i.struct_name); - quote! { function: #struct_name::default(), } + let generic = self.generic.as_ref().map(|g| { + let g = format_ident!("{g}"); + quote! { ::<#g> } + }); + quote! { function: #struct_name #generic :: default(), } } }; diff --git a/src/expr/macro/src/lib.rs b/src/expr/macro/src/lib.rs index cb57d0cf75383..24760d06f4341 100644 --- a/src/expr/macro/src/lib.rs +++ b/src/expr/macro/src/lib.rs @@ -495,6 +495,8 @@ struct FunctionAttr { prebuild: Option, /// Type inference function. type_infer: Option, + /// Generic type. + generic: Option, /// Whether the function is volatile. volatile: bool, /// Whether the function is deprecated. @@ -536,6 +538,7 @@ struct AggregateImpl { #[allow(dead_code)] // TODO(wrj): add merge to trait merge: Option, finalize: Option, + create_state: Option, #[allow(dead_code)] // TODO(wrj): support encode encode_state: Option, #[allow(dead_code)] // TODO(wrj): support decode diff --git a/src/expr/macro/src/parse.rs b/src/expr/macro/src/parse.rs index be7a6c86df624..24cc6942afcee 100644 --- a/src/expr/macro/src/parse.rs +++ b/src/expr/macro/src/parse.rs @@ -75,6 +75,8 @@ impl Parse for FunctionAttr { parsed.prebuild = Some(get_value()?); } else if meta.path().is_ident("type_infer") { parsed.type_infer = Some(get_value()?); + } else if meta.path().is_ident("generic") { + parsed.generic = Some(get_value()?); } else if meta.path().is_ident("volatile") { parsed.volatile = true; } else if meta.path().is_ident("deprecated") { @@ -141,12 +143,18 @@ impl Parse for AggregateImpl { _ => None, }) }; + let self_path = itemimpl.self_ty.to_token_stream().to_string(); + let struct_name = match self_path.split_once('<') { + Some((path, _)) => path.trim().into(), // remove generic parameters + None => self_path, + }; Ok(AggregateImpl { - struct_name: itemimpl.self_ty.to_token_stream().to_string(), + struct_name, accumulate: parse_function("accumulate").expect("expect accumulate function"), retract: parse_function("retract"), merge: parse_function("merge"), finalize: parse_function("finalize"), + create_state: parse_function("create_state"), encode_state: parse_function("encode_state"), decode_state: parse_function("decode_state"), }) @@ -155,6 +163,8 @@ impl Parse for AggregateImpl { impl Parse for AggregateFnOrImpl { fn parse(input: ParseStream<'_>) -> Result { + // consume attributes + let _ = input.call(syn::Attribute::parse_outer)?; if input.peek(Token![impl]) { Ok(AggregateFnOrImpl::Impl(input.parse()?)) } else {