From 731c8a45b8ab42807c4bfe79c03dafcae90b66cd Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 24 Oct 2023 16:02:56 +0800 Subject: [PATCH] optimize `jsonb_agg` and support any state in `#[aggregate]` Signed-off-by: Runji Wang --- Cargo.lock | 3 +- src/common/Cargo.toml | 2 +- src/common/src/types/jsonb.rs | 12 ++ src/expr/impl/Cargo.toml | 1 + src/expr/impl/src/aggregate/jsonb_agg.rs | 143 +++++++++++++++++++++-- src/expr/macro/src/gen.rs | 86 ++++++++------ src/expr/macro/src/lib.rs | 2 + src/expr/macro/src/parse.rs | 30 ++++- 8 files changed, 227 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 358078187355a..8ddbb5a7cbd91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4115,7 +4115,7 @@ dependencies = [ [[package]] name = "jsonbb" version = "0.1.0" -source = "git+https://github.com/risingwavelabs/jsonbb.git?rev=493fa2f#493fa2fe864c6b5789dc6014969f820679dfef09" +source = "git+https://github.com/risingwavelabs/jsonbb.git?rev=6ac9354#6ac93543309780c300fcecaec0e8269645cd3111" dependencies = [ "bytes", "serde", @@ -7433,6 +7433,7 @@ dependencies = [ "futures-util", "hex", "itertools 0.11.0", + "jsonbb", "madsim-tokio", "md5", "num-traits", diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index 87110a82f5c18..bed04289a11ec 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -49,7 +49,7 @@ hyper = "0.14" hytra = { workspace = true } itertools = "0.11" itoa = "1.0" -jsonbb = { git = "https://github.com/risingwavelabs/jsonbb.git", rev = "493fa2f" } +jsonbb = { git = "https://github.com/risingwavelabs/jsonbb.git", rev = "6ac9354" } lru = { git = "https://github.com/risingwavelabs/lru-rs.git", rev = "cb2d7c7" } memcomparable = { version = "0.2", features = ["decimal"] } num-integer = "0.1" diff --git a/src/common/src/types/jsonb.rs b/src/common/src/types/jsonb.rs index e64a7f804ec43..e7c6d461fb46f 100644 --- a/src/common/src/types/jsonb.rs +++ b/src/common/src/types/jsonb.rs @@ -297,6 +297,18 @@ impl From> for JsonbVal { } } +impl From for JsonbVal { + fn from(v: Value) -> Self { + Self(v) + } +} + +impl<'a> From> for ValueRef<'a> { + fn from(v: JsonbRef<'a>) -> Self { + v.0 + } +} + impl<'a> JsonbRef<'a> { pub fn memcmp_serialize( &self, diff --git a/src/expr/impl/Cargo.toml b/src/expr/impl/Cargo.toml index 81cd685c4dc27..6fb9e4509dda0 100644 --- a/src/expr/impl/Cargo.toml +++ b/src/expr/impl/Cargo.toml @@ -29,6 +29,7 @@ futures-async-stream = { workspace = true } futures-util = "0.3" hex = "0.4" itertools = "0.11" +jsonbb = { git = "https://github.com/risingwavelabs/jsonbb.git", rev = "6ac9354" } md5 = "0.7" num-traits = "0.2" regex = "1" diff --git a/src/expr/impl/src/aggregate/jsonb_agg.rs b/src/expr/impl/src/aggregate/jsonb_agg.rs index 524280a76467e..96f5e50da85e3 100644 --- a/src/expr/impl/src/aggregate/jsonb_agg.rs +++ b/src/expr/impl/src/aggregate/jsonb_agg.rs @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::types::JsonbVal; +use risingwave_common::estimate_size::EstimateSize; +use risingwave_common::types::{JsonbRef, JsonbVal, ScalarImpl, F32, F64}; +use risingwave_expr::aggregate::AggStateDyn; use risingwave_expr::{aggregate, ExprError, Result}; #[aggregate("jsonb_agg(boolean) -> jsonb")] @@ -20,10 +22,11 @@ use risingwave_expr::{aggregate, ExprError, Result}; #[aggregate("jsonb_agg(*float) -> jsonb")] #[aggregate("jsonb_agg(varchar) -> jsonb")] #[aggregate("jsonb_agg(jsonb) -> jsonb")] -fn jsonb_agg(state: Option, input: Option>) -> JsonbVal { - let mut jsonb = state.unwrap_or_else(JsonbVal::empty_array); - jsonb.array_push(input.map_or_else(JsonbVal::null, Into::into)); - jsonb +fn jsonb_agg(state: &mut JsonbArrayState, input: Option) { + match input { + Some(input) => input.add_to(&mut state.0), + None => state.0.add_null(), + } } #[aggregate("jsonb_object_agg(varchar, boolean) -> jsonb")] @@ -32,12 +35,130 @@ fn jsonb_agg(state: Option, input: Option>) -> Jso #[aggregate("jsonb_object_agg(varchar, varchar) -> jsonb")] #[aggregate("jsonb_object_agg(varchar, jsonb) -> jsonb")] fn jsonb_object_agg( - state: Option, + state: &mut JsonbObjectState, key: Option<&str>, - value: Option>, -) -> Result { + value: Option, +) -> Result<()> { let key = key.ok_or(ExprError::FieldNameNull)?; - let mut jsonb = state.unwrap_or_else(JsonbVal::empty_object); - jsonb.object_insert(key, value.map_or_else(JsonbVal::null, Into::into)); - Ok(jsonb) + state.0.add_string(key); + match value { + Some(value) => value.add_to(&mut state.0), + None => state.0.add_null(), + } + Ok(()) +} + +#[derive(Debug)] +struct JsonbArrayState(jsonbb::Builder); + +impl EstimateSize for JsonbArrayState { + fn estimated_heap_size(&self) -> usize { + self.0.capacity() + } +} + +impl AggStateDyn for JsonbArrayState {} + +/// Creates an initial state. +impl Default for JsonbArrayState { + fn default() -> Self { + let mut builder = jsonbb::Builder::default(); + builder.begin_array(); + Self(builder) + } +} + +/// Finishes aggregation and returns the result. +impl From<&JsonbArrayState> for ScalarImpl { + fn from(builder: &JsonbArrayState) -> Self { + // TODO: avoid clone + let mut builder = builder.0.clone(); + builder.end_array(); + let jsonb: JsonbVal = builder.finish().into(); + jsonb.into() + } +} + +#[derive(Debug)] +struct JsonbObjectState(jsonbb::Builder); + +impl EstimateSize for JsonbObjectState { + fn estimated_heap_size(&self) -> usize { + self.0.capacity() + } +} + +impl AggStateDyn for JsonbObjectState {} + +/// Creates an initial state. +impl Default for JsonbObjectState { + fn default() -> Self { + let mut builder = jsonbb::Builder::default(); + builder.begin_object(); + Self(builder) + } +} + +/// Finishes aggregation and returns the result. +impl From<&JsonbObjectState> for ScalarImpl { + fn from(builder: &JsonbObjectState) -> Self { + // TODO: avoid clone + let mut builder = builder.0.clone(); + builder.end_object(); + let jsonb: JsonbVal = builder.finish().into(); + jsonb.into() + } +} + +/// Values that can be converted to JSON. +trait ToJson { + fn add_to(self, builder: &mut jsonbb::Builder); +} + +impl ToJson for bool { + fn add_to(self, builder: &mut jsonbb::Builder) { + builder.add_bool(self); + } +} + +impl ToJson for i16 { + fn add_to(self, builder: &mut jsonbb::Builder) { + builder.add_i64(self as _); + } +} + +impl ToJson for i32 { + fn add_to(self, builder: &mut jsonbb::Builder) { + builder.add_i64(self as _); + } +} + +impl ToJson for i64 { + fn add_to(self, builder: &mut jsonbb::Builder) { + builder.add_i64(self); + } +} + +impl ToJson for F32 { + fn add_to(self, builder: &mut jsonbb::Builder) { + builder.add_f64(self.0 as f64); + } +} + +impl ToJson for F64 { + fn add_to(self, builder: &mut jsonbb::Builder) { + builder.add_f64(self.0); + } +} + +impl ToJson for &str { + fn add_to(self, builder: &mut jsonbb::Builder) { + builder.add_string(self); + } +} + +impl ToJson for JsonbRef<'_> { + fn add_to(self, builder: &mut jsonbb::Builder) { + builder.add_value(self.into()); + } } diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 9155853df5b7b..454d2a3169137 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -579,9 +579,13 @@ impl FunctionAttr { /// Generate build function for aggregate function. fn generate_agg_build_fn(&self, user_fn: &AggregateFnOrImpl) -> Result { - let state_type: TokenStream2 = match &self.state { - Some(state) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(), - Some(state) if state != "ref" => types::owned_type(state).parse().unwrap(), + // If the first argument of the aggregate function is of type `&mut T`, + // we assume it is a user defined state type. + let custom_state = user_fn.accumulate().first_mut_ref_arg.as_ref(); + let state_type: TokenStream2 = match (custom_state, &self.state) { + (Some(s), _) => s.parse().unwrap(), + (_, Some(state)) if state == "ref" => types::ref_type(&self.ret).parse().unwrap(), + (_, Some(state)) if state != "ref" => types::owned_type(state).parse().unwrap(), _ => types::owned_type(&self.ret).parse().unwrap(), }; let let_arrays = self @@ -603,24 +607,37 @@ impl FunctionAttr { quote! { let #v = unsafe { #a.value_at_unchecked(row_id) }; } }) .collect_vec(); - let let_state = match &self.state { - Some(s) if s == "ref" => { - quote! { state0.as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()) } - } - _ => quote! { state0.take().map(|s| s.try_into().unwrap()) }, + let downcast_state = if custom_state.is_some() { + quote! { let mut state: &mut #state_type = state0.downcast_mut(); } + } else if let Some(s) = &self.state && s == "ref" { + quote! { let mut state: Option<#state_type> = state0.as_datum_mut().as_ref().map(|x| x.as_scalar_ref_impl().try_into().unwrap()); } + } else { + quote! { let mut state: Option<#state_type> = state0.as_datum_mut().take().map(|s| s.try_into().unwrap()); } }; - let assign_state = match &self.state { - Some(s) if s == "ref" => quote! { state.map(|x| x.to_owned_scalar().into()) }, - _ => quote! { state.map(|s| s.into()) }, + let restore_state = if custom_state.is_some() { + quote! {} + } else if let Some(s) = &self.state && s == "ref" { + quote! { *state0.as_datum_mut() = state.map(|x| x.to_owned_scalar().into()); } + } else { + quote! { *state0.as_datum_mut() = state.map(|s| s.into()); } }; - let create_state = self.init_state.as_ref().map(|state| { + let create_state = if custom_state.is_some() { + quote! { + fn create_state(&self) -> AggregateState { + AggregateState::Any(Box::<#state_type>::default()) + } + } + } else if let Some(state) = &self.init_state { let state: TokenStream2 = state.parse().unwrap(); quote! { fn create_state(&self) -> AggregateState { AggregateState::Datum(Some(#state.into())) } } - }); + } else { + // by default: `AggregateState::Datum(None)` + quote! {} + }; let args = (0..self.args.len()).map(|i| format_ident!("v{i}")); let args = quote! { #(#args,)* }; let panic_on_retract = { @@ -703,17 +720,23 @@ impl FunctionAttr { _ => todo!("multiple arguments are not supported for non-option function"), } } - let get_result = match user_fn { - AggregateFnOrImpl::Impl(impl_) if impl_.finalize.is_some() => { - quote! { - let state = match state { - Some(s) => s.as_scalar_ref_impl().try_into().unwrap(), - None => return Ok(None), - }; - Ok(Some(self.function.finalize(state).into())) - } + let update_state = if custom_state.is_some() { + quote! { _ = #next_state; } + } else { + quote! { state = #next_state; } + }; + let get_result = if custom_state.is_some() { + quote! { Ok(Some(state.downcast_ref::<#state_type>().into())) } + } else if let AggregateFnOrImpl::Impl(impl_) = user_fn && impl_.finalize.is_some() { + quote! { + let state = match state.as_datum() { + Some(s) => s.as_scalar_ref_impl().try_into().unwrap(), + None => return Ok(None), + }; + Ok(Some(self.function.finalize(state).into())) } - _ => quote! { Ok(state.clone()) }, + } else { + quote! { Ok(state.as_datum().clone()) } }; let function_field = match user_fn { AggregateFnOrImpl::Fn(_) => quote! {}, @@ -768,27 +791,25 @@ impl FunctionAttr { async fn update(&self, state0: &mut AggregateState, input: &StreamChunk) -> Result<()> { #(#let_arrays)* - let state0 = state0.as_datum_mut(); - let mut state: Option<#state_type> = #let_state; + #downcast_state for row_id in input.visibility().iter_ones() { let op = unsafe { *input.ops().get_unchecked(row_id) }; #(#let_values)* - state = #next_state; + #update_state } - *state0 = #assign_state; + #restore_state Ok(()) } async fn update_range(&self, state0: &mut AggregateState, input: &StreamChunk, range: Range) -> Result<()> { assert!(range.end <= input.capacity()); #(#let_arrays)* - let state0 = state0.as_datum_mut(); - let mut state: Option<#state_type> = #let_state; + #downcast_state if input.is_compacted() { for row_id in range { let op = unsafe { *input.ops().get_unchecked(row_id) }; #(#let_values)* - state = #next_state; + #update_state } } else { for row_id in input.visibility().iter_ones() { @@ -799,15 +820,14 @@ impl FunctionAttr { } let op = unsafe { *input.ops().get_unchecked(row_id) }; #(#let_values)* - state = #next_state; + #update_state } } - *state0 = #assign_state; + #restore_state Ok(()) } async fn get_result(&self, state: &AggregateState) -> Result { - let state = state.as_datum(); #get_result } } diff --git a/src/expr/macro/src/lib.rs b/src/expr/macro/src/lib.rs index 363fc958b557d..50a99cf3fda22 100644 --- a/src/expr/macro/src/lib.rs +++ b/src/expr/macro/src/lib.rs @@ -522,6 +522,8 @@ struct UserFunctionAttr { retract: bool, /// The argument type are `Option`s. arg_option: bool, + /// If the first argument type is `&mut T`, then `Some(T)`. + first_mut_ref_arg: Option, /// The return type kind. return_type_kind: ReturnTypeKind, /// The kind of inner type `T` in `impl Iterator` diff --git a/src/expr/macro/src/parse.rs b/src/expr/macro/src/parse.rs index 24cc6942afcee..8e2e8c6d0b2f1 100644 --- a/src/expr/macro/src/parse.rs +++ b/src/expr/macro/src/parse.rs @@ -123,6 +123,7 @@ impl From<&syn::Signature> for UserFunctionAttr { context: sig.inputs.iter().any(arg_is_context), retract: last_arg_is_retract(sig), arg_option: args_contain_option(sig), + first_mut_ref_arg: first_mut_ref_arg(sig), return_type_kind, iterator_item_kind, core_return_type, @@ -223,18 +224,15 @@ fn last_arg_is_retract(sig: &syn::Signature) -> bool { /// Check if any argument is `Option`. fn args_contain_option(sig: &syn::Signature) -> bool { - if sig.inputs.is_empty() { - return false; - } for arg in &sig.inputs { let syn::FnArg::Typed(arg) = arg else { - return false; + continue; }; let syn::Type::Path(path) = arg.ty.as_ref() else { - return false; + continue; }; let Some(seg) = path.path.segments.last() else { - return false; + continue; }; if seg.ident == "Option" { return true; @@ -243,6 +241,26 @@ fn args_contain_option(sig: &syn::Signature) -> bool { false } +/// Returns `T` if the first argument (except `self`) is `&mut T`. +fn first_mut_ref_arg(sig: &syn::Signature) -> Option { + let arg = match sig.inputs.first()? { + syn::FnArg::Typed(arg) => arg, + syn::FnArg::Receiver(_) => match sig.inputs.iter().nth(1)? { + syn::FnArg::Typed(arg) => arg, + _ => return None, + }, + }; + let syn::Type::Reference(syn::TypeReference { + elem, + mutability: Some(_), + .. + }) = arg.ty.as_ref() + else { + return None; + }; + Some(elem.to_token_stream().to_string()) +} + /// Check the return type. fn check_type(ty: &syn::Type) -> (ReturnTypeKind, &syn::Type) { if let Some(inner) = strip_outer_type(ty, "Result") {