Skip to content

Commit

Permalink
optimize jsonb_agg and support any state in #[aggregate]
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 committed Oct 24, 2023
1 parent d1684bd commit 731c8a4
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 52 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ hyper = "0.14"
hytra = { workspace = true }
itertools = "0.11"
itoa = "1.0"
jsonbb = { git = "https://github.com/risingwavelabs/jsonbb.git", rev = "493fa2f" }
jsonbb = { git = "https://github.com/risingwavelabs/jsonbb.git", rev = "6ac9354" }
lru = { git = "https://github.com/risingwavelabs/lru-rs.git", rev = "cb2d7c7" }
memcomparable = { version = "0.2", features = ["decimal"] }
num-integer = "0.1"
Expand Down
12 changes: 12 additions & 0 deletions src/common/src/types/jsonb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,18 @@ impl From<JsonbRef<'_>> for JsonbVal {
}
}

impl From<Value> for JsonbVal {
fn from(v: Value) -> Self {
Self(v)
}
}

/// Unwraps a `JsonbRef` into the `ValueRef` it holds, keeping the same lifetime.
impl<'a> From<JsonbRef<'a>> for ValueRef<'a> {
    fn from(json: JsonbRef<'a>) -> Self {
        let inner = json.0;
        inner
    }
}

impl<'a> JsonbRef<'a> {
pub fn memcmp_serialize(
&self,
Expand Down
1 change: 1 addition & 0 deletions src/expr/impl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ futures-async-stream = { workspace = true }
futures-util = "0.3"
hex = "0.4"
itertools = "0.11"
jsonbb = { git = "https://github.com/risingwavelabs/jsonbb.git", rev = "6ac9354" }
md5 = "0.7"
num-traits = "0.2"
regex = "1"
Expand Down
143 changes: 132 additions & 11 deletions src/expr/impl/src/aggregate/jsonb_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::types::JsonbVal;
use risingwave_common::estimate_size::EstimateSize;
use risingwave_common::types::{JsonbRef, JsonbVal, ScalarImpl, F32, F64};
use risingwave_expr::aggregate::AggStateDyn;
use risingwave_expr::{aggregate, ExprError, Result};

#[aggregate("jsonb_agg(boolean) -> jsonb")]
#[aggregate("jsonb_agg(*int) -> jsonb")]
#[aggregate("jsonb_agg(*float) -> jsonb")]
#[aggregate("jsonb_agg(varchar) -> jsonb")]
#[aggregate("jsonb_agg(jsonb) -> jsonb")]
fn jsonb_agg(state: Option<JsonbVal>, input: Option<impl Into<JsonbVal>>) -> JsonbVal {
let mut jsonb = state.unwrap_or_else(JsonbVal::empty_array);
jsonb.array_push(input.map_or_else(JsonbVal::null, Into::into));
jsonb
fn jsonb_agg(state: &mut JsonbArrayState, input: Option<impl ToJson>) {
match input {
Some(input) => input.add_to(&mut state.0),
None => state.0.add_null(),
}
}

#[aggregate("jsonb_object_agg(varchar, boolean) -> jsonb")]
Expand All @@ -32,12 +35,130 @@ fn jsonb_agg(state: Option<JsonbVal>, input: Option<impl Into<JsonbVal>>) -> Jso
#[aggregate("jsonb_object_agg(varchar, varchar) -> jsonb")]
#[aggregate("jsonb_object_agg(varchar, jsonb) -> jsonb")]
fn jsonb_object_agg(
state: Option<JsonbVal>,
state: &mut JsonbObjectState,
key: Option<&str>,
value: Option<impl Into<JsonbVal>>,
) -> Result<JsonbVal> {
value: Option<impl ToJson>,
) -> Result<()> {
let key = key.ok_or(ExprError::FieldNameNull)?;
let mut jsonb = state.unwrap_or_else(JsonbVal::empty_object);
jsonb.object_insert(key, value.map_or_else(JsonbVal::null, Into::into));
Ok(jsonb)
state.0.add_string(key);
match value {
Some(value) => value.add_to(&mut state.0),
None => state.0.add_null(),
}
Ok(())
}

/// State used by `jsonb_agg`: a newtype around the `jsonbb::Builder` that
/// receives each aggregated value.
#[derive(Debug)]
struct JsonbArrayState(jsonbb::Builder);

/// Uses the wrapped builder's `capacity()` as the estimated heap size.
impl EstimateSize for JsonbArrayState {
    fn estimated_heap_size(&self) -> usize {
        let Self(builder) = self;
        builder.capacity()
    }
}

// Empty impl: nothing is overridden, so only the trait's provided items apply.
// NOTE(review): presumably this is what allows the state to be stored as a
// dynamically typed aggregate state — confirm against `AggStateDyn`.
impl AggStateDyn for JsonbArrayState {}

/// Creates an initial state.
impl Default for JsonbArrayState {
fn default() -> Self {
let mut builder = jsonbb::Builder::default();
builder.begin_array();
Self(builder)
}
}

/// Finishes aggregation and returns the result.
impl From<&JsonbArrayState> for ScalarImpl {
fn from(builder: &JsonbArrayState) -> Self {
// TODO: avoid clone
let mut builder = builder.0.clone();
builder.end_array();
let jsonb: JsonbVal = builder.finish().into();
jsonb.into()
}
}

/// State used by `jsonb_object_agg`: a newtype around the `jsonbb::Builder`
/// that receives each aggregated key and value.
#[derive(Debug)]
struct JsonbObjectState(jsonbb::Builder);

/// Uses the wrapped builder's `capacity()` as the estimated heap size.
impl EstimateSize for JsonbObjectState {
    fn estimated_heap_size(&self) -> usize {
        let Self(builder) = self;
        builder.capacity()
    }
}

// Empty impl: nothing is overridden, so only the trait's provided items apply.
// NOTE(review): presumably this is what allows the state to be stored as a
// dynamically typed aggregate state — confirm against `AggStateDyn`.
impl AggStateDyn for JsonbObjectState {}

/// Creates an initial state.
impl Default for JsonbObjectState {
fn default() -> Self {
let mut builder = jsonbb::Builder::default();
builder.begin_object();
Self(builder)
}
}

/// Finishes aggregation and returns the result.
impl From<&JsonbObjectState> for ScalarImpl {
fn from(builder: &JsonbObjectState) -> Self {
// TODO: avoid clone
let mut builder = builder.0.clone();
builder.end_object();
let jsonb: JsonbVal = builder.finish().into();
jsonb.into()
}
}

/// Values that can be converted to JSON.
trait ToJson {
    /// Consumes `self` and adds it to `builder`.
    fn add_to(self, builder: &mut jsonbb::Builder);
}

/// Adds the boolean through `add_bool`.
impl ToJson for bool {
    fn add_to(self, out: &mut jsonbb::Builder) {
        out.add_bool(self);
    }
}

/// Widens the `i16` losslessly to `i64` and adds it through `add_i64`.
impl ToJson for i16 {
    fn add_to(self, out: &mut jsonbb::Builder) {
        out.add_i64(i64::from(self));
    }
}

/// Widens the `i32` losslessly to `i64` and adds it through `add_i64`.
impl ToJson for i32 {
    fn add_to(self, out: &mut jsonbb::Builder) {
        out.add_i64(i64::from(self));
    }
}

/// Adds the `i64` through `add_i64` unchanged.
impl ToJson for i64 {
    fn add_to(self, out: &mut jsonbb::Builder) {
        out.add_i64(self);
    }
}

/// Widens the wrapped float to `f64` and adds it through `add_f64`.
impl ToJson for F32 {
    fn add_to(self, out: &mut jsonbb::Builder) {
        let widened = f64::from(self.0);
        out.add_f64(widened);
    }
}

/// Adds the wrapped float through `add_f64`.
impl ToJson for F64 {
    fn add_to(self, out: &mut jsonbb::Builder) {
        let raw = self.0;
        out.add_f64(raw);
    }
}

/// Adds the string slice through `add_string`.
impl ToJson for &str {
    fn add_to(self, out: &mut jsonbb::Builder) {
        out.add_string(self);
    }
}

/// Converts the `JsonbRef` with `Into` and adds the result through `add_value`.
impl ToJson for JsonbRef<'_> {
    fn add_to(self, out: &mut jsonbb::Builder) {
        let value = self.into();
        out.add_value(value);
    }
}
86 changes: 53 additions & 33 deletions src/expr/macro/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,9 +579,13 @@ impl FunctionAttr {

/// Generate build function for aggregate function.
fn generate_agg_build_fn(&self, user_fn: &AggregateFnOrImpl) -> Result<TokenStream2> {
let state_type: TokenStream2 = match &self.state {
Some(state) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(),
Some(state) if state != "ref" => types::owned_type(state).parse().unwrap(),
// If the first argument of the aggregate function is of type `&mut T`,
// we assume it is a user defined state type.
let custom_state = user_fn.accumulate().first_mut_ref_arg.as_ref();
let state_type: TokenStream2 = match (custom_state, &self.state) {
(Some(s), _) => s.parse().unwrap(),
(_, Some(state)) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(),
(_, Some(state)) if state != "ref" => types::owned_type(state).parse().unwrap(),
_ => types::owned_type(&self.ret).parse().unwrap(),
};
let let_arrays = self
Expand All @@ -603,24 +607,37 @@ impl FunctionAttr {
quote! { let #v = unsafe { #a.value_at_unchecked(row_id) }; }
})
.collect_vec();
let let_state = match &self.state {
Some(s) if s == "ref" => {
quote! { state0.as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()) }
}
_ => quote! { state0.take().map(|s| s.try_into().unwrap()) },
let downcast_state = if custom_state.is_some() {
quote! { let mut state: &mut #state_type = state0.downcast_mut(); }
} else if let Some(s) = &self.state && s == "ref" {
quote! { let mut state: Option<#state_type> = state0.as_datum_mut().as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()); }
} else {
quote! { let mut state: Option<#state_type> = state0.as_datum_mut().take().map(|s| s.try_into().unwrap()); }
};
let assign_state = match &self.state {
Some(s) if s == "ref" => quote! { state.map(|x| x.to_owned_scalar().into()) },
_ => quote! { state.map(|s| s.into()) },
let restore_state = if custom_state.is_some() {
quote! {}
} else if let Some(s) = &self.state && s == "ref" {
quote! { *state0.as_datum_mut() = state.map(|x| x.to_owned_scalar().into()); }
} else {
quote! { *state0.as_datum_mut() = state.map(|s| s.into()); }
};
let create_state = self.init_state.as_ref().map(|state| {
let create_state = if custom_state.is_some() {
quote! {
fn create_state(&self) -> AggregateState {
AggregateState::Any(Box::<#state_type>::default())
}
}
} else if let Some(state) = &self.init_state {
let state: TokenStream2 = state.parse().unwrap();
quote! {
fn create_state(&self) -> AggregateState {
AggregateState::Datum(Some(#state.into()))
}
}
});
} else {
// by default: `AggregateState::Datum(None)`
quote! {}
};
let args = (0..self.args.len()).map(|i| format_ident!("v{i}"));
let args = quote! { #(#args,)* };
let panic_on_retract = {
Expand Down Expand Up @@ -703,17 +720,23 @@ impl FunctionAttr {
_ => todo!("multiple arguments are not supported for non-option function"),
}
}
let get_result = match user_fn {
AggregateFnOrImpl::Impl(impl_) if impl_.finalize.is_some() => {
quote! {
let state = match state {
Some(s) => s.as_scalar_ref_impl().try_into().unwrap(),
None => return Ok(None),
};
Ok(Some(self.function.finalize(state).into()))
}
let update_state = if custom_state.is_some() {
quote! { _ = #next_state; }
} else {
quote! { state = #next_state; }
};
let get_result = if custom_state.is_some() {
quote! { Ok(Some(state.downcast_ref::<#state_type>().into())) }
} else if let AggregateFnOrImpl::Impl(impl_) = user_fn && impl_.finalize.is_some() {
quote! {
let state = match state.as_datum() {
Some(s) => s.as_scalar_ref_impl().try_into().unwrap(),
None => return Ok(None),
};
Ok(Some(self.function.finalize(state).into()))
}
_ => quote! { Ok(state.clone()) },
} else {
quote! { Ok(state.as_datum().clone()) }
};
let function_field = match user_fn {
AggregateFnOrImpl::Fn(_) => quote! {},
Expand Down Expand Up @@ -768,27 +791,25 @@ impl FunctionAttr {

async fn update(&self, state0: &mut AggregateState, input: &StreamChunk) -> Result<()> {
#(#let_arrays)*
let state0 = state0.as_datum_mut();
let mut state: Option<#state_type> = #let_state;
#downcast_state
for row_id in input.visibility().iter_ones() {
let op = unsafe { *input.ops().get_unchecked(row_id) };
#(#let_values)*
state = #next_state;
#update_state
}
*state0 = #assign_state;
#restore_state
Ok(())
}

async fn update_range(&self, state0: &mut AggregateState, input: &StreamChunk, range: Range<usize>) -> Result<()> {
assert!(range.end <= input.capacity());
#(#let_arrays)*
let state0 = state0.as_datum_mut();
let mut state: Option<#state_type> = #let_state;
#downcast_state
if input.is_compacted() {
for row_id in range {
let op = unsafe { *input.ops().get_unchecked(row_id) };
#(#let_values)*
state = #next_state;
#update_state
}
} else {
for row_id in input.visibility().iter_ones() {
Expand All @@ -799,15 +820,14 @@ impl FunctionAttr {
}
let op = unsafe { *input.ops().get_unchecked(row_id) };
#(#let_values)*
state = #next_state;
#update_state
}
}
*state0 = #assign_state;
#restore_state
Ok(())
}

async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
let state = state.as_datum();
#get_result
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/expr/macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,8 @@ struct UserFunctionAttr {
retract: bool,
/// The argument type are `Option`s.
arg_option: bool,
/// If the first argument type is `&mut T`, then `Some(T)`.
first_mut_ref_arg: Option<String>,
/// The return type kind.
return_type_kind: ReturnTypeKind,
/// The kind of inner type `T` in `impl Iterator<Item = T>`
Expand Down
Loading

0 comments on commit 731c8a4

Please sign in to comment.