Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agg): fix embedded UDAF as window function #18632

Merged
merged 6 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions e2e_test/udf/bug_fixes/17560_udaf_as_win_func.slt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,69 @@ select t.value, sum00(weight) OVER (PARTITION BY value) from (values (1, 1), (nu
----
1 1
3 3

statement ok
drop aggregate sum00;

# https://github.com/risingwavelabs/risingwave/issues/18436

statement ok
CREATE TABLE exam_scores (
score_id int,
exam_id int,
student_id int,
score real,
exam_date timestamp
);

statement ok
INSERT INTO exam_scores (score_id, exam_id, student_id, score, exam_date)
VALUES
(1, 101, 1001, 85.5, '2022-01-10'),
(2, 101, 1002, 92.0, '2022-01-10'),
(3, 101, 1003, 78.5, '2022-01-10'),
(4, 102, 1001, 91.2, '2022-02-15'),
(5, 102, 1003, 88.9, '2022-02-15');

statement ok
create aggregate weighted_avg(value float, weight float) returns float language python as $$
def create_state():
    # State is a (weighted sum, total weight) pair.
    return (0, 0)
def accumulate(state, value, weight):
    # Rows with a NULL value or weight leave the state untouched.
    if value is None or weight is None:
        return state
    (s, w) = state
    s += value * weight
    w += weight
    return (s, w)
def retract(state, value, weight):
    # Exact inverse of `accumulate`, including the NULL handling.
    if value is None or weight is None:
        return state
    (s, w) = state
    s -= value * weight
    w -= weight
    return (s, w)
def finish(state):
    (sum, weight) = state
    # A zero total weight yields NULL instead of dividing by zero.
    if weight == 0:
        return None
    else:
        return sum / weight
$$;

query
SELECT
*,
weighted_avg(score, 1) OVER (
PARTITION BY "student_id"
ORDER BY "exam_date"
ROWS 2 PRECEDING
) AS "weighted_avg"
FROM exam_scores
ORDER BY "student_id", "exam_date";
----
1 101 1001 85.5 2022-01-10 00:00:00 85.5
4 102 1001 91.2 2022-02-15 00:00:00 88.3499984741211
2 101 1002 92 2022-01-10 00:00:00 92
3 101 1003 78.5 2022-01-10 00:00:00 78.5
5 102 1003 88.9 2022-02-15 00:00:00 83.70000076293945
17 changes: 16 additions & 1 deletion proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,20 @@ message AggCall {
ExprNode scalar = 9;
}

// The aggregation type.
//
// Ideally this should be used to encode the Rust `AggCall::agg_type` field, but historically we
// flattened it into multiple fields in proto `AggCall` - `kind` + `udf` + `scalar`. So this
// `AggType` proto type is only used by `WindowFunction` currently.
message AggType {
// The aggregate kind, reusing the `AggCall.Kind` enum.
AggCall.Kind kind = 1;

// NOTE(review): field number 9 below matches `scalar = 9` in the flattened `AggCall` encoding;
// 8 presumably mirrors `AggCall.udf` in the same way - confirm before renumbering.

// UDF metadata. Only present when the kind is `USER_DEFINED`.
optional UserDefinedFunctionMetadata udf_meta = 8;
// Wrapped scalar expression. Only present when the kind is `WRAP_SCALAR`.
optional ExprNode scalar_expr = 9;
}

message WindowFrame {
enum Type {
TYPE_UNSPECIFIED = 0;
Expand Down Expand Up @@ -562,7 +576,8 @@ message WindowFunction {

oneof type {
GeneralType general = 1;
AggCall.Kind aggregate = 2;
AggCall.Kind aggregate = 2 [deprecated = true]; // Deprecated since we have a new `aggregate2` variant.
AggType aggregate2 = 103;
}
repeated InputRef args = 3;
data.DataType return_type = 4;
Expand Down
39 changes: 35 additions & 4 deletions src/expr/core/src/aggregate/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::util::value_encoding::DatumFromProtoExt;
pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
use risingwave_pb::expr::{PbAggCall, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata};
use risingwave_pb::expr::{
PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata,
};

use crate::expr::{
build_from_prost, BoxedExpression, ExpectExt, Expression, LiteralExpression, Token,
Expand Down Expand Up @@ -65,7 +67,7 @@ pub struct AggCall {

impl AggCall {
pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
let agg_type = AggType::from_protobuf(
let agg_type = AggType::from_protobuf_flatten(
agg_call.get_kind()?,
agg_call.udf.as_ref(),
agg_call.scalar.as_ref(),
Expand Down Expand Up @@ -160,7 +162,7 @@ impl<Iter: Iterator<Item = Token>> Parser<Iter> {
self.tokens.next(); // Consume the RParen

AggCall {
agg_type: AggType::from_protobuf(func, None, None).unwrap(),
agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(),
args: AggArgs {
data_types: children.iter().map(|(_, ty)| ty.clone()).collect(),
val_indices: children.iter().map(|(idx, _)| *idx).collect(),
Expand Down Expand Up @@ -260,7 +262,7 @@ impl From<PbAggKind> for AggType {
}

impl AggType {
pub fn from_protobuf(
pub fn from_protobuf_flatten(
pb_kind: PbAggKind,
user_defined: Option<&PbUserDefinedFunctionMetadata>,
scalar: Option<&PbExprNode>,
Expand All @@ -286,6 +288,35 @@ impl AggType {
Self::WrapScalar(_) => PbAggKind::WrapScalar,
}
}

pub fn from_protobuf(pb_type: &PbAggType) -> Result<Self> {
match PbAggKind::try_from(pb_type.kind).context("no such aggregate function type")? {
PbAggKind::Unspecified => bail!("Unrecognized agg."),
PbAggKind::UserDefined => Ok(AggType::UserDefined(pb_type.get_udf_meta()?.clone())),
PbAggKind::WrapScalar => Ok(AggType::WrapScalar(pb_type.get_scalar_expr()?.clone())),
kind => Ok(AggType::Builtin(kind)),
}
}

/// Encodes this [`AggType`] into its structured protobuf form ([`PbAggType`]).
///
/// `udf_meta` is populated only for `UserDefined` and `scalar_expr` only for
/// `WrapScalar`; builtin aggregates carry just their kind.
pub fn to_protobuf(&self) -> PbAggType {
    // Work out the three payload pieces first, then assemble the message once.
    let (pb_kind, udf_meta, scalar_expr) = match self {
        Self::Builtin(kind) => (*kind, None, None),
        Self::UserDefined(meta) => (PbAggKind::UserDefined, Some(meta.clone()), None),
        Self::WrapScalar(expr) => (PbAggKind::WrapScalar, None, Some(expr.clone())),
    };

    PbAggType {
        kind: pb_kind as _,
        udf_meta,
        scalar_expr,
    }
}
}

/// Macros to generate match arms for `AggType`.
Expand Down
12 changes: 7 additions & 5 deletions src/expr/core/src/window_function/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::Context;
use parse_display::{Display, FromStr};
use risingwave_common::bail;

Expand Down Expand Up @@ -51,11 +52,12 @@ impl WindowFuncKind {
Ok(PbGeneralType::Lead) => Self::Lead,
Err(_) => bail!("no such window function type"),
},
PbType::Aggregate(agg_type) => match PbAggKind::try_from(*agg_type) {
// TODO(runji): support UDAF and wrapped scalar functions
Ok(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type, None, None)?),
Err(_) => bail!("no such aggregate function type"),
},
PbType::Aggregate(kind) => Self::Aggregate(AggType::from_protobuf_flatten(
PbAggKind::try_from(*kind).context("no such aggregate function type")?,
None,
None,
)?),
PbType::Aggregate2(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type)?),
};
Ok(kind)
}
Expand Down
44 changes: 30 additions & 14 deletions src/expr/impl/src/window_function/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::{bail, must_match};
use risingwave_common_estimate_size::{EstimateSize, KvSize};
use risingwave_expr::aggregate::{
AggCall, AggType, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction,
build_append_only, AggCall, AggType, AggregateFunction, AggregateState as AggImplState,
BoxedAggregateFunction,
};
use risingwave_expr::sig::FUNCTION_REGISTRY;
use risingwave_expr::window_function::{
Expand Down Expand Up @@ -63,19 +64,34 @@ pub(super) fn new(call: &WindowFuncCall) -> Result<BoxedWindowState> {
distinct: false,
direct_args: vec![],
};
// TODO(runji): support UDAF and wrapped scalar function
let agg_kind = must_match!(agg_type, AggType::Builtin(agg_kind) => agg_kind);
let agg_func_sig = FUNCTION_REGISTRY
.get(*agg_kind, &arg_data_types, &call.return_type)
.expect("the agg func must exist");
let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
let (agg_impl, enable_delta) =
if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
let init_state = agg_func.create_state()?;
(AggImpl::Incremental(init_state), true)
} else {
(AggImpl::Full, false)
};

let (agg_func, agg_impl, enable_delta) = match agg_type {
AggType::Builtin(kind) => {
let agg_func_sig = FUNCTION_REGISTRY
.get(*kind, &arg_data_types, &call.return_type)
.expect("the agg func must exist");
let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
let (agg_impl, enable_delta) =
if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
let init_state = agg_func.create_state()?;
(AggImpl::Incremental(init_state), true)
} else {
(AggImpl::Full, false)
};
(agg_func, agg_impl, enable_delta)
}
AggType::UserDefined(_) => {
// TODO(rc): utilize `retract` method of embedded UDAF to do incremental aggregation
let agg_func = build_append_only(&agg_call)?;
(agg_func, AggImpl::Full, false)
}
AggType::WrapScalar(_) => {
// we have to feed the wrapped scalar function with all the rows in the window,
// instead of doing incremental aggregation
let agg_func = build_append_only(&agg_call)?;
(agg_func, AggImpl::Full, false)
}
};

let this = match &call.frame.bounds {
FrameBounds::Rows(frame_bounds) => Box::new(AggregateState {
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/binder/expr/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ impl Binder {
None
};

let agg_type = if let Some(wrapped_agg_type) = wrapped_agg_type {
Some(wrapped_agg_type)
let agg_type = if wrapped_agg_type.is_some() {
wrapped_agg_type
} else if let Some(ref udf) = udf
&& udf.kind.is_aggregate()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl PlanWindowFunction {
DenseRank => PbType::General(PbGeneralType::DenseRank as _),
Lag => PbType::General(PbGeneralType::Lag as _),
Lead => PbType::General(PbGeneralType::Lead as _),
Aggregate(agg_type) => PbType::Aggregate(agg_type.to_protobuf_simple() as _),
Aggregate(agg_type) => PbType::Aggregate2(agg_type.to_protobuf()),
};

PbWindowFunction {
Expand Down
Loading