Skip to content

Commit

Permalink
feat(expr): support streaming bit_and and bit_or aggregate (#12758)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Oct 11, 2023
1 parent 88ffd82 commit 80f1d58
Show file tree
Hide file tree
Showing 12 changed files with 455 additions and 50 deletions.
30 changes: 21 additions & 9 deletions src/common/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use core::fmt;
use std::cmp::Ordering;
use std::fmt;
use std::fmt::Debug;
use std::future::Future;
use std::hash::Hash;
use std::mem::size_of;
use std::ops::{Index, IndexMut};

use bytes::{Buf, BufMut};
use either::Either;
Expand Down Expand Up @@ -359,6 +360,20 @@ impl Ord for ListValue {
}
}

/// Lets a `ListValue` be read with `list[index]`, yielding a reference to the
/// element stored at that position of `values`.
impl Index<usize> for ListValue {
    // Every element of the list is a `Datum`.
    type Output = Datum;

    // Delegates directly to indexing on `self.values`.
    // NOTE(review): an out-of-range `index` presumably panics the same way
    // indexing `values` does -- confirm against the type of `values`.
    fn index(&self, index: usize) -> &Self::Output {
        &self.values[index]
    }
}

/// Lets an element of a `ListValue` be modified in place with `list[index] = ...`
/// or `&mut list[index]`.
impl IndexMut<usize> for ListValue {
    // Delegates directly to mutable indexing on `self.values`.
    // NOTE(review): an out-of-range `index` presumably panics the same way
    // indexing `values` does -- confirm against the type of `values`.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.values[index]
    }
}

// Used to display ListValue in explain for better readability.
pub fn display_for_explain(list: &ListValue) -> String {
// Example of ListValue display: ARRAY[1, 2, null]
Expand Down Expand Up @@ -485,7 +500,7 @@ impl<'a> ListRef<'a> {
}

/// Get the element at the given index. Returns `None` if the index is out of bounds.
pub fn elem_at(self, index: usize) -> Option<DatumRef<'a>> {
pub fn get(self, index: usize) -> Option<DatumRef<'a>> {
iter_elems_ref!(self, it, {
let mut it = it;
it.nth(index)
Expand Down Expand Up @@ -551,12 +566,9 @@ impl Ord for ListRef<'_> {

impl Debug for ListRef<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
iter_elems_ref!(*self, it, {
for v in it {
Debug::fmt(&v, f)?;
}
Ok(())
})
let mut f = f.debug_list();
iter_elems_ref!(*self, it, { f.entries(it) });
f.finish()
}
}

Expand Down Expand Up @@ -1020,7 +1032,7 @@ mod tests {
);

// Get 2nd value from ListRef
let scalar = list_ref.elem_at(1).unwrap();
let scalar = list_ref.get(1).unwrap();
assert_eq!(scalar, Some(types::ScalarRefImpl::Int32(5)));
}
}
15 changes: 6 additions & 9 deletions src/expr/core/src/aggregate/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,7 @@ pub mod agg_kinds {
#[macro_export]
macro_rules! unimplemented_in_stream {
() => {
AggKind::BitAnd
| AggKind::BitOr
| AggKind::JsonbAgg
AggKind::JsonbAgg
| AggKind::JsonbObjectAgg
| AggKind::PercentileCont
| AggKind::PercentileDisc
Expand Down Expand Up @@ -408,6 +406,8 @@ pub mod agg_kinds {
// after we support general merge in stateless_simple_agg
| AggKind::BoolAnd
| AggKind::BoolOr
| AggKind::BitAnd
| AggKind::BitOr
};
}
pub use simply_cannot_two_phase;
Expand All @@ -420,6 +420,8 @@ pub mod agg_kinds {
AggKind::Sum
| AggKind::Sum0
| AggKind::Count
| AggKind::BitAnd
| AggKind::BitOr
| AggKind::BitXor
| AggKind::BoolAnd
| AggKind::BoolOr
Expand Down Expand Up @@ -452,12 +454,7 @@ impl AggKind {
/// Get the total phase agg kind from the partial phase agg kind.
pub fn partial_to_total(self) -> Option<Self> {
match self {
AggKind::BitAnd
| AggKind::BitOr
| AggKind::BitXor
| AggKind::Min
| AggKind::Max
| AggKind::Sum => Some(self),
AggKind::BitXor | AggKind::Min | AggKind::Max | AggKind::Sum => Some(self),
AggKind::Sum0 | AggKind::Count => Some(AggKind::Sum0),
agg_kinds::simply_cannot_two_phase!() => None,
agg_kinds::rewritten!() => None,
Expand Down
191 changes: 191 additions & 0 deletions src/expr/impl/src/aggregate/bit_and.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::marker::PhantomData;
use std::ops::BitAnd;

use risingwave_common::types::{ListRef, ListValue, ScalarImpl};
use risingwave_expr::aggregate;

/// Computes the bitwise AND of all non-null input values.
///
/// # Example
///
/// ```slt
/// statement ok
/// create table t (a int2, b int4, c int8);
///
/// query III
/// select bit_and(a), bit_and(b), bit_and(c) from t;
/// ----
/// NULL NULL NULL
///
/// statement ok
/// insert into t values
/// (6, 6, 6),
/// (3, 3, 3),
/// (null, null, null);
///
/// query III
/// select bit_and(a), bit_and(b), bit_and(c) from t;
/// ----
/// 2 2 2
///
/// statement ok
/// drop table t;
/// ```
// XXX: state = "ref" is required so that
// for the first non-null value, the state is set to that value.
#[aggregate("bit_and(*int) -> auto", state = "ref")]
fn bit_and_append_only<T>(state: T, input: T) -> T
where
    T: BitAnd<Output = T>,
{
    // Fold the incoming value into the running result using the `&` operator,
    // which dispatches to `BitAnd::bitand`.
    state & input
}

/// Computes the bitwise AND of all non-null input values.
///
/// # Example
///
/// ```slt
/// statement ok
/// create table t (a int2, b int4, c int8);
///
/// statement ok
/// create materialized view mv as
/// select bit_and(a) a, bit_and(b) b, bit_and(c) c from t;
///
/// query III
/// select * from mv;
/// ----
/// NULL NULL NULL
///
/// statement ok
/// insert into t values
/// (6, 6, 6),
/// (3, 3, 3),
/// (null, null, null);
///
/// query III
/// select * from mv;
/// ----
/// 2 2 2
///
/// statement ok
/// delete from t where a = 3;
///
/// query III
/// select * from mv;
/// ----
/// 6 6 6
///
/// statement ok
/// drop materialized view mv;
///
/// statement ok
/// drop table t;
/// ```
#[derive(Debug, Default, Clone)]
struct BitAndUpdatable<T> {
    // Zero-sized marker: records the integer type `T` this aggregate operates on
    // without storing a value of that type.
    _phantom: PhantomData<T>,
}

#[aggregate("bit_and(int2) -> int2", state = "int8[]", generic = "i16")]
#[aggregate("bit_and(int4) -> int4", state = "int8[]", generic = "i32")]
#[aggregate("bit_and(int8) -> int8", state = "int8[]", generic = "i64")]
impl<T: Bits> BitAndUpdatable<T> {
    // The state keeps one counter per bit position: the number of accumulated
    // (and not yet retracted) inputs that have a 0 at that position.

    /// Builds the initial state: `T::BITS` counters, each starting at zero.
    fn create_state(&self) -> ListValue {
        let zero_counters = vec![Some(ScalarImpl::Int64(0)); T::BITS];
        ListValue::new(zero_counters)
    }

    /// Adds `input` to the state by incrementing the counter of every bit
    /// position at which `input` has a 0.
    ///
    /// # Panics
    ///
    /// Panics with "invalid state" if a touched counter is not a non-null `Int64`.
    fn accumulate(&self, mut state: ListValue, input: T) -> ListValue {
        for pos in (0..T::BITS).filter(|&pos| !input.get_bit(pos)) {
            match &mut state[pos] {
                Some(ScalarImpl::Int64(zeros)) => *zeros += 1,
                _ => panic!("invalid state"),
            }
        }
        state
    }

    /// Removes `input` from the state by decrementing the counter of every bit
    /// position at which `input` has a 0.
    ///
    /// # Panics
    ///
    /// Panics with "invalid state" if a touched counter is not a non-null `Int64`.
    fn retract(&self, mut state: ListValue, input: T) -> ListValue {
        for pos in (0..T::BITS).filter(|&pos| !input.get_bit(pos)) {
            match &mut state[pos] {
                Some(ScalarImpl::Int64(zeros)) => *zeros -= 1,
                _ => panic!("invalid state"),
            }
        }
        state
    }

    /// Produces the result from the counters: starting from `T::default()`,
    /// a bit is set exactly when its zero-counter is 0.
    fn finalize(&self, state: ListRef<'_>) -> T {
        let mut output = T::default();
        for pos in 0..T::BITS {
            let zeros = state.get(pos).unwrap().unwrap().into_int64();
            if zeros == 0 {
                output.set_bit(pos);
            }
        }
        output
    }
}

/// Bit-level access to a fixed-width integer type.
///
/// Implemented in this file for `i16`, `i32` and `i64`, and used as the bound
/// on `BitAndUpdatable<T>`.
pub trait Bits: Default {
    /// Number of bits in the implementing type.
    const BITS: usize;
    /// Returns `true` if bit `i` of `self` is 1.
    fn get_bit(&self, i: usize) -> bool;
    /// Sets bit `i` of `self` to 1.
    fn set_bit(&mut self, i: usize);
}

impl Bits for i16 {
    const BITS: usize = 16;

    /// Reports whether bit `pos` of `self` is 1, where position 0 is the
    /// least significant bit.
    fn get_bit(&self, pos: usize) -> bool {
        let shifted = *self >> pos;
        (shifted & 1) != 0
    }

    /// Turns on bit `pos` of `self`, leaving all other bits unchanged.
    fn set_bit(&mut self, pos: usize) {
        let mask: i16 = 1 << pos;
        *self |= mask;
    }
}

impl Bits for i32 {
    const BITS: usize = 32;

    /// Reports whether bit `pos` of `self` is 1, where position 0 is the
    /// least significant bit.
    fn get_bit(&self, pos: usize) -> bool {
        let shifted = *self >> pos;
        (shifted & 1) != 0
    }

    /// Turns on bit `pos` of `self`, leaving all other bits unchanged.
    fn set_bit(&mut self, pos: usize) {
        let mask: i32 = 1 << pos;
        *self |= mask;
    }
}

impl Bits for i64 {
    const BITS: usize = 64;

    /// Reports whether bit `pos` of `self` is 1, where position 0 is the
    /// least significant bit.
    fn get_bit(&self, pos: usize) -> bool {
        let shifted = *self >> pos;
        (shifted & 1) != 0
    }

    /// Turns on bit `pos` of `self`, leaving all other bits unchanged.
    fn set_bit(&mut self, pos: usize) {
        let mask: i64 = 1 << pos;
        *self |= mask;
    }
}
Loading

0 comments on commit 80f1d58

Please sign in to comment.