Skip to content

Commit

Permalink
introduce proto AggType message
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc committed Sep 23, 2024
1 parent fbf4f06 commit 65550e6
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 10 deletions.
17 changes: 16 additions & 1 deletion proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,20 @@ message AggCall {
ExprNode scalar = 9;
}

// The aggregation type.
//
// Ideally this should be used to encode the Rust `AggCall::agg_type` field, but historically we
// flattened it into multiple fields in proto `AggCall` - `kind` + `udf` + `scalar`. So this
// `AggType` proto type is only used by `WindowFunction` currently.
message AggType {
AggCall.Kind kind = 1;

// UDF metadata. Only present when the kind is `USER_DEFINED`.
optional UserDefinedFunctionMetadata udf_meta = 8;
// Wrapped scalar expression. Only present when the kind is `WRAP_SCALAR`.
optional ExprNode scalar_expr = 9;
}

message WindowFrame {
enum Type {
TYPE_UNSPECIFIED = 0;
Expand Down Expand Up @@ -562,7 +576,8 @@ message WindowFunction {

oneof type {
GeneralType general = 1;
AggCall.Kind aggregate = 2;
AggCall.Kind aggregate_simple = 2 [deprecated = true]; // Deprecated since we have a new `aggregate` variant.
AggType aggregate = 103;
}
repeated InputRef args = 3;
data.DataType return_type = 4;
Expand Down
39 changes: 35 additions & 4 deletions src/expr/core/src/aggregate/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::util::value_encoding::DatumFromProtoExt;
pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
use risingwave_pb::expr::{PbAggCall, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata};
use risingwave_pb::expr::{
PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata,
};

use crate::expr::{
build_from_prost, BoxedExpression, ExpectExt, Expression, LiteralExpression, Token,
Expand Down Expand Up @@ -65,7 +67,7 @@ pub struct AggCall {

impl AggCall {
pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
let agg_type = AggType::from_protobuf(
let agg_type = AggType::from_protobuf_flatten(
agg_call.get_kind()?,
agg_call.udf.as_ref(),
agg_call.scalar.as_ref(),
Expand Down Expand Up @@ -160,7 +162,7 @@ impl<Iter: Iterator<Item = Token>> Parser<Iter> {
self.tokens.next(); // Consume the RParen

AggCall {
agg_type: AggType::from_protobuf(func, None, None).unwrap(),
agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(),
args: AggArgs {
data_types: children.iter().map(|(_, ty)| ty.clone()).collect(),
val_indices: children.iter().map(|(idx, _)| *idx).collect(),
Expand Down Expand Up @@ -260,7 +262,7 @@ impl From<PbAggKind> for AggType {
}

impl AggType {
pub fn from_protobuf(
pub fn from_protobuf_flatten(
pb_kind: PbAggKind,
user_defined: Option<&PbUserDefinedFunctionMetadata>,
scalar: Option<&PbExprNode>,
Expand All @@ -286,6 +288,35 @@ impl AggType {
Self::WrapScalar(_) => PbAggKind::WrapScalar,
}
}

pub fn from_protobuf(pb_type: &PbAggType) -> Result<Self> {
match pb_type.kind() {
PbAggKind::Unspecified => bail!("Unrecognized agg."),
PbAggKind::UserDefined => Ok(AggType::UserDefined(pb_type.get_udf_meta()?.clone())),
PbAggKind::WrapScalar => Ok(AggType::WrapScalar(pb_type.get_scalar_expr()?.clone())),
kind => Ok(AggType::Builtin(kind)),
}
}

pub fn to_protobuf(&self) -> PbAggType {
match self {
Self::Builtin(kind) => PbAggType {
kind: *kind as _,
udf_meta: None,
scalar_expr: None,
},
Self::UserDefined(udf_meta) => PbAggType {
kind: PbAggKind::UserDefined as _,
udf_meta: Some(udf_meta.clone()),
scalar_expr: None,
},
Self::WrapScalar(scalar_expr) => PbAggType {
kind: PbAggKind::WrapScalar as _,
udf_meta: None,
scalar_expr: Some(scalar_expr.clone()),
},
}
}
}

/// Macros to generate match arms for `AggType`.
Expand Down
7 changes: 5 additions & 2 deletions src/expr/core/src/window_function/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,14 @@ impl WindowFuncKind {
Ok(PbGeneralType::Lead) => Self::Lead,
Err(_) => bail!("no such window function type"),
},
PbType::Aggregate(agg_type) => match PbAggKind::try_from(*agg_type) {
PbType::AggregateSimple(agg_type) => match PbAggKind::try_from(*agg_type) {
// TODO(runji): support UDAF and wrapped scalar functions
Ok(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type, None, None)?),
Ok(agg_type) => {
Self::Aggregate(AggType::from_protobuf_flatten(agg_type, None, None)?)
}
Err(_) => bail!("no such aggregate function type"),
},
PbType::Aggregate(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type)?),
};
Ok(kind)
}
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/binder/expr/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ impl Binder {
None
};

let agg_type = if let Some(wrapped_agg_type) = wrapped_agg_type {
Some(wrapped_agg_type)
let agg_type = if wrapped_agg_type.is_some() {
wrapped_agg_type
} else if let Some(ref udf) = udf
&& udf.kind.is_aggregate()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl PlanWindowFunction {
DenseRank => PbType::General(PbGeneralType::DenseRank as _),
Lag => PbType::General(PbGeneralType::Lag as _),
Lead => PbType::General(PbGeneralType::Lead as _),
Aggregate(agg_type) => PbType::Aggregate(agg_type.to_protobuf_simple() as _),
Aggregate(agg_type) => PbType::Aggregate(agg_type.to_protobuf()),
};

PbWindowFunction {
Expand Down

0 comments on commit 65550e6

Please sign in to comment.