From ce0765100208987663f4b17ec48e15740ad0612b Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Fri, 20 Sep 2024 17:33:49 +0800 Subject: [PATCH] store udf meta in protobuf Signed-off-by: Richard Chien --- proto/expr.proto | 17 +++++++- src/expr/core/src/aggregate/def.rs | 41 ++++++++++++++++--- src/expr/core/src/window_function/kind.rs | 8 ++-- src/frontend/src/binder/expr/function/mod.rs | 4 +- .../src/optimizer/plan_node/generic/agg.rs | 2 +- .../plan_node/generic/over_window.rs | 2 +- 6 files changed, 61 insertions(+), 13 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index 53bba96cc587b..7ddd6bdbd5f88 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -471,6 +471,20 @@ message AggCall { ExprNode scalar = 9; } +// The aggregation kind. +// +// Ideally this should be used to encode the Rust `AggCall::kind` field, but historically we +// flattened it into multiple fields in proto `AggCall`. So this `AggKind` proto type is only +// used by `WindowFunction`. +message AggKind { + AggCall.Type type = 1; + + // UDF metadata. Only present when the type is `USER_DEFINED`. + optional UserDefinedFunctionMetadata udf_meta = 8; + // Wrapped scalar expression. Only present when the type is `WRAP_SCALAR`. + optional ExprNode scalar_expr = 9; +} + message WindowFrame { enum Type { TYPE_UNSPECIFIED = 0; @@ -562,7 +576,8 @@ message WindowFunction { oneof type { GeneralType general = 1; - AggCall.Type aggregate = 2; + AggCall.Type aggregate_simple = 2 [deprecated = true]; // Deprecated since we have a new `aggregate` variant. + AggKind aggregate = 103; } repeated InputRef args = 3; data.DataType return_type = 4; diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index c32175594153d..e9d8b10b24da3 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -27,7 +27,9 @@ use risingwave_common::types::{DataType, Datum}; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::util::value_encoding::DatumFromProtoExt; pub use risingwave_pb::expr::agg_call::PbType as PbAggType; -use risingwave_pb::expr::{PbAggCall, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata}; +use risingwave_pb::expr::{ + PbAggCall, PbAggKind, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata, +}; use crate::expr::{ build_from_prost, BoxedExpression, ExpectExt, Expression, LiteralExpression, Token, @@ -65,7 +67,7 @@ pub struct AggCall { impl AggCall { pub fn from_protobuf(agg_call: &PbAggCall) -> Result { - let agg_kind = AggKind::from_protobuf( + let agg_kind = AggKind::from_protobuf_flatten( agg_call.get_type()?, agg_call.udf.as_ref(), agg_call.scalar.as_ref(), @@ -160,7 +162,7 @@ impl> Parser { self.tokens.next(); // Consume the RParen AggCall { - kind: AggKind::from_protobuf(func, None, None).unwrap(), + kind: AggKind::from_protobuf_flatten(func, None, None).unwrap(), args: AggArgs { data_types: children.iter().map(|(_, ty)| ty.clone()).collect(), val_indices: children.iter().map(|(idx, _)| *idx).collect(), @@ -260,7 +262,7 @@ impl From for AggKind { } impl AggKind { - pub fn from_protobuf( + pub fn from_protobuf_flatten( pb_type: PbAggType, user_defined: Option<&PbUserDefinedFunctionMetadata>, scalar: Option<&PbExprNode>, @@ -279,13 +281,42 @@ impl AggKind { } } - pub fn to_protobuf(&self) -> PbAggType { + pub fn to_protobuf_simple(&self) -> PbAggType { match self { Self::Builtin(pb) => *pb, Self::UserDefined(_) => PbAggType::UserDefined, Self::WrapScalar(_) => PbAggType::WrapScalar, } } + + pub fn from_protobuf(pb_kind: &PbAggKind) -> Result { + match pb_kind.r#type() { + PbAggType::Unspecified => bail!("Unrecognized agg."), + PbAggType::UserDefined => Ok(AggKind::UserDefined(pb_kind.get_udf_meta()?.clone())), + PbAggType::WrapScalar => Ok(AggKind::WrapScalar(pb_kind.get_scalar_expr()?.clone())), + pb_type => Ok(AggKind::Builtin(pb_type)), + } + } + + pub fn to_protobuf(&self) -> PbAggKind { + match self { + Self::Builtin(ty) => PbAggKind { + r#type: *ty as _, + udf_meta: None, + scalar_expr: None, + }, + Self::UserDefined(udf_meta) => PbAggKind { + r#type: PbAggType::UserDefined as _, + udf_meta: Some(udf_meta.clone()), + scalar_expr: None, + }, + Self::WrapScalar(scalar_expr) => PbAggKind { + r#type: PbAggType::WrapScalar as _, + udf_meta: None, + scalar_expr: Some(scalar_expr.clone()), + }, + } + } } /// Macros to generate match arms for [`AggKind`](AggKind). diff --git a/src/expr/core/src/window_function/kind.rs b/src/expr/core/src/window_function/kind.rs index 3042facb5cffc..ebb956d273c23 100644 --- a/src/expr/core/src/window_function/kind.rs +++ b/src/expr/core/src/window_function/kind.rs @@ -51,11 +51,13 @@ impl WindowFuncKind { Ok(PbGeneralType::Lead) => Self::Lead, Err(_) => bail!("no such window function type"), }, - PbType::Aggregate(agg_type) => match PbAggType::try_from(*agg_type) { - // TODO(runji): support UDAF and wrapped scalar functions - Ok(agg_type) => Self::Aggregate(AggKind::from_protobuf(agg_type, None, None)?), + PbType::AggregateSimple(agg_type) => match PbAggType::try_from(*agg_type) { + Ok(agg_type) => { + Self::Aggregate(AggKind::from_protobuf_flatten(agg_type, None, None)?) + } Err(_) => bail!("no such aggregate function type"), }, + PbType::Aggregate(agg_kind) => Self::Aggregate(AggKind::from_protobuf(agg_kind)?), }; Ok(kind) } diff --git a/src/frontend/src/binder/expr/function/mod.rs b/src/frontend/src/binder/expr/function/mod.rs index 3505f14936c7d..d1b51c2215e6a 100644 --- a/src/frontend/src/binder/expr/function/mod.rs +++ b/src/frontend/src/binder/expr/function/mod.rs @@ -227,8 +227,8 @@ impl Binder { None }; - let agg_kind = if let Some(wrapped_agg_kind) = wrapped_agg_kind { - Some(wrapped_agg_kind) + let agg_kind = if wrapped_agg_kind.is_some() { + wrapped_agg_kind } else if let Some(ref udf) = udf && udf.kind.is_aggregate() { diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index 02e1793b8db36..a7ade27de1361 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -780,7 +780,7 @@ impl PlanAggCall { pub fn to_protobuf(&self) -> PbAggCall { PbAggCall { - r#type: self.agg_kind.to_protobuf().into(), + r#type: self.agg_kind.to_protobuf_simple().into(), return_type: Some(self.return_type.to_protobuf()), args: self.inputs.iter().map(InputRef::to_proto).collect(), distinct: self.distinct, diff --git a/src/frontend/src/optimizer/plan_node/generic/over_window.rs b/src/frontend/src/optimizer/plan_node/generic/over_window.rs index c39fd99be9895..322d81bd691ab 100644 --- a/src/frontend/src/optimizer/plan_node/generic/over_window.rs +++ b/src/frontend/src/optimizer/plan_node/generic/over_window.rs @@ -121,7 +121,7 @@ impl PlanWindowFunction { DenseRank => PbType::General(PbGeneralType::DenseRank as _), Lag => PbType::General(PbGeneralType::Lag as _), Lead => PbType::General(PbGeneralType::Lead as _), - Aggregate(agg_kind) => PbType::Aggregate(agg_kind.to_protobuf() as _), + Aggregate(agg_kind) => PbType::Aggregate(agg_kind.to_protobuf()), }; PbWindowFunction {