Skip to content

Commit

Permalink
store udf meta in protobuf
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc committed Sep 23, 2024
1 parent e84f109 commit ce07651
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 13 deletions.
17 changes: 16 additions & 1 deletion proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,20 @@ message AggCall {
ExprNode scalar = 9;
}

// The aggregation kind.
//
// Ideally this should be used to encode the Rust `AggCall::kind` field, but historically we
// flattened it into multiple fields in proto `AggCall`. So this `AggKind` proto type is only
// used by `WindowFunction`.
message AggKind {
// The aggregate function type. Determines which of the optional fields below is populated.
AggCall.Type type = 1;

// UDF metadata. Only present when the type is `USER_DEFINED`.
optional UserDefinedFunctionMetadata udf_meta = 8;
// Wrapped scalar expression. Only present when the type is `WRAP_SCALAR`.
optional ExprNode scalar_expr = 9;
}

message WindowFrame {
enum Type {
TYPE_UNSPECIFIED = 0;
Expand Down Expand Up @@ -562,7 +576,8 @@ message WindowFunction {

oneof type {
GeneralType general = 1;
AggCall.Type aggregate = 2;
AggCall.Type aggregate_simple = 2 [deprecated = true]; // Deprecated since we have a new `aggregate` variant.
AggKind aggregate = 103;
}
repeated InputRef args = 3;
data.DataType return_type = 4;
Expand Down
41 changes: 36 additions & 5 deletions src/expr/core/src/aggregate/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::util::value_encoding::DatumFromProtoExt;
pub use risingwave_pb::expr::agg_call::PbType as PbAggType;
use risingwave_pb::expr::{PbAggCall, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata};
use risingwave_pb::expr::{
PbAggCall, PbAggKind, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata,
};

use crate::expr::{
build_from_prost, BoxedExpression, ExpectExt, Expression, LiteralExpression, Token,
Expand Down Expand Up @@ -65,7 +67,7 @@ pub struct AggCall {

impl AggCall {
pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
let agg_kind = AggKind::from_protobuf(
let agg_kind = AggKind::from_protobuf_flatten(
agg_call.get_type()?,
agg_call.udf.as_ref(),
agg_call.scalar.as_ref(),
Expand Down Expand Up @@ -160,7 +162,7 @@ impl<Iter: Iterator<Item = Token>> Parser<Iter> {
self.tokens.next(); // Consume the RParen

AggCall {
kind: AggKind::from_protobuf(func, None, None).unwrap(),
kind: AggKind::from_protobuf_flatten(func, None, None).unwrap(),
args: AggArgs {
data_types: children.iter().map(|(_, ty)| ty.clone()).collect(),
val_indices: children.iter().map(|(idx, _)| *idx).collect(),
Expand Down Expand Up @@ -260,7 +262,7 @@ impl From<PbAggType> for AggKind {
}

impl AggKind {
pub fn from_protobuf(
pub fn from_protobuf_flatten(
pb_type: PbAggType,
user_defined: Option<&PbUserDefinedFunctionMetadata>,
scalar: Option<&PbExprNode>,
Expand All @@ -279,13 +281,42 @@ impl AggKind {
}
}

pub fn to_protobuf(&self) -> PbAggType {
/// Returns only the protobuf aggregate *type* of this kind.
///
/// Any payload carried by the `UserDefined` or `WrapScalar` variants is
/// dropped; those variants map to the bare `PbAggType::UserDefined` and
/// `PbAggType::WrapScalar` tags respectively.
pub fn to_protobuf_simple(&self) -> PbAggType {
    match *self {
        Self::Builtin(builtin_type) => builtin_type,
        Self::UserDefined(_) => PbAggType::UserDefined,
        Self::WrapScalar(_) => PbAggType::WrapScalar,
    }
}

pub fn from_protobuf(pb_kind: &PbAggKind) -> Result<Self> {
match pb_kind.r#type() {
PbAggType::Unspecified => bail!("Unrecognized agg."),
PbAggType::UserDefined => Ok(AggKind::UserDefined(pb_kind.get_udf_meta()?.clone())),
PbAggType::WrapScalar => Ok(AggKind::WrapScalar(pb_kind.get_scalar_expr()?.clone())),
pb_type => Ok(AggKind::Builtin(pb_type)),
}
}

/// Encodes this kind as a structured protobuf `PbAggKind` message.
///
/// At most one of `udf_meta` / `scalar_expr` is set, matching the variant:
/// `UserDefined` fills `udf_meta`, `WrapScalar` fills `scalar_expr`, and
/// `Builtin` leaves both empty.
pub fn to_protobuf(&self) -> PbAggKind {
    // Work out the per-variant pieces first, then assemble the message once.
    let (pb_type, udf_meta, scalar_expr) = match self {
        Self::Builtin(builtin_type) => (*builtin_type, None, None),
        Self::UserDefined(meta) => (PbAggType::UserDefined, Some(meta.clone()), None),
        Self::WrapScalar(expr) => (PbAggType::WrapScalar, None, Some(expr.clone())),
    };
    PbAggKind {
        r#type: pb_type as _,
        udf_meta,
        scalar_expr,
    }
}
}

/// Macros to generate match arms for [`AggKind`](AggKind).
Expand Down
8 changes: 5 additions & 3 deletions src/expr/core/src/window_function/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ impl WindowFuncKind {
Ok(PbGeneralType::Lead) => Self::Lead,
Err(_) => bail!("no such window function type"),
},
PbType::Aggregate(agg_type) => match PbAggType::try_from(*agg_type) {
// TODO(runji): support UDAF and wrapped scalar functions
Ok(agg_type) => Self::Aggregate(AggKind::from_protobuf(agg_type, None, None)?),
PbType::AggregateSimple(agg_type) => match PbAggType::try_from(*agg_type) {
Ok(agg_type) => {
Self::Aggregate(AggKind::from_protobuf_flatten(agg_type, None, None)?)
}
Err(_) => bail!("no such aggregate function type"),
},
PbType::Aggregate(agg_kind) => Self::Aggregate(AggKind::from_protobuf(agg_kind)?),
};
Ok(kind)
}
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/binder/expr/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ impl Binder {
None
};

let agg_kind = if let Some(wrapped_agg_kind) = wrapped_agg_kind {
Some(wrapped_agg_kind)
let agg_kind = if wrapped_agg_kind.is_some() {
wrapped_agg_kind
} else if let Some(ref udf) = udf
&& udf.kind.is_aggregate()
{
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/optimizer/plan_node/generic/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ impl PlanAggCall {

pub fn to_protobuf(&self) -> PbAggCall {
PbAggCall {
r#type: self.agg_kind.to_protobuf().into(),
r#type: self.agg_kind.to_protobuf_simple().into(),
return_type: Some(self.return_type.to_protobuf()),
args: self.inputs.iter().map(InputRef::to_proto).collect(),
distinct: self.distinct,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl PlanWindowFunction {
DenseRank => PbType::General(PbGeneralType::DenseRank as _),
Lag => PbType::General(PbGeneralType::Lag as _),
Lead => PbType::General(PbGeneralType::Lead as _),
Aggregate(agg_kind) => PbType::Aggregate(agg_kind.to_protobuf() as _),
Aggregate(agg_kind) => PbType::Aggregate(agg_kind.to_protobuf()),
};

PbWindowFunction {
Expand Down

0 comments on commit ce07651

Please sign in to comment.