Skip to content

Commit

Permalink
fix all build
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 committed Sep 13, 2023
1 parent ebcc575 commit 86f44c6
Show file tree
Hide file tree
Showing 18 changed files with 337 additions and 440 deletions.
60 changes: 34 additions & 26 deletions src/expr/benches/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ use risingwave_common::types::test_utils::IntervalTestExt;
use risingwave_common::types::*;
use risingwave_expr::agg::{build as build_agg, AggArgs, AggCall, AggKind};
use risingwave_expr::expr::*;
use risingwave_expr::sig::agg::agg_func_sigs;
use risingwave_expr::sig::func::func_sigs;
use risingwave_expr::sig::func_sigs;
use risingwave_expr::ExprError;
use risingwave_pb::expr::expr_node::PbType;

Expand Down Expand Up @@ -227,10 +226,10 @@ fn bench_expr(c: &mut Criterion) {
InputRefExpression::new(DataType::Jsonb, 26),
InputRefExpression::new(DataType::Int256, 27),
];
let input_index_for_type = |ty: DataType| {
let input_index_for_type = |ty: &DataType| {
inputrefs
.iter()
.find(|r| r.return_type() == ty)
.find(|r| &r.return_type() == ty)
.unwrap_or_else(|| panic!("expression not found for {ty:?}"))
.index()
};
Expand Down Expand Up @@ -265,19 +264,20 @@ fn bench_expr(c: &mut Criterion) {
c.bench_function("extract(constant)", |bencher| {
let extract = build_from_pretty(format!(
"(extract:decimal HOUR:varchar ${}:timestamp)",
input_index_for_type(DataType::Timestamp)
input_index_for_type(&DataType::Timestamp)
));
bencher
.to_async(FuturesExecutor)
.iter(|| extract.eval(&input))
});

let sigs = func_sigs();
let sigs = sigs.sorted_by_cached_key(|sig| format!("{sig:?}"));
let sigs = func_sigs()
.filter(|s| s.is_scalar())
.sorted_by_cached_key(|sig| format!("{sig:?}"));
'sig: for sig in sigs {
if (sig.inputs_type.iter())
.chain(&[sig.ret_type])
.any(|t| matches!(t, DataTypeName::Struct | DataTypeName::List))
.chain([&sig.ret_type])
.any(|t| !t.is_exact())
{
// TODO: support struct and list
println!("todo: {sig:?}");
Expand All @@ -300,8 +300,8 @@ fn bench_expr(c: &mut Criterion) {

let mut children = vec![];
for (i, t) in sig.inputs_type.iter().enumerate() {
use DataTypeName::*;
let idx = match (sig.func, i) {
use DataType::*;
let idx = match (sig.name.as_scalar(), i) {
(PbType::ToTimestamp1, 0) => TIMESTAMP_FORMATTED_STRING,
(PbType::ToChar | PbType::ToTimestamp1, 1) => {
children.push(string_literal("YYYY/MM/DD HH:MM:SS"));
Expand All @@ -315,7 +315,7 @@ fn bench_expr(c: &mut Criterion) {
children.push(string_literal("VALUE"));
continue;
}
(PbType::Cast, 0) if *t == DataTypeName::Varchar => match sig.ret_type {
(PbType::Cast, 0) if t.as_exact() == &Varchar => match sig.ret_type.as_exact() {
Boolean => BOOL_STRING,
Int16 | Int32 | Int64 | Float32 | Float64 | Decimal => NUMBER_STRING,
Date => DATE_STRING,
Expand All @@ -332,46 +332,54 @@ fn bench_expr(c: &mut Criterion) {
(PbType::AtTimeZone, 1) => TIMEZONE,
(PbType::DateTrunc, 0) => TIME_FIELD,
(PbType::DateTrunc, 2) => TIMEZONE,
(PbType::Extract, 0) => match sig.inputs_type[1] {
(PbType::Extract, 0) => match sig.inputs_type[1].as_exact() {
Date => EXTRACT_FIELD_DATE,
Time => EXTRACT_FIELD_TIME,
Timestamp => EXTRACT_FIELD_TIMESTAMP,
Timestamptz => EXTRACT_FIELD_TIMESTAMPTZ,
Interval => EXTRACT_FIELD_INTERVAL,
t => panic!("unexpected type: {t:?}"),
},
_ => input_index_for_type((*t).into()),
_ => input_index_for_type(t.as_exact()),
};
children.push(InputRefExpression::new(DataType::from(*t), idx).boxed());
children.push(InputRefExpression::new(t.as_exact().clone(), idx).boxed());
}
let expr = build_func(sig.func, sig.ret_type.into(), children).unwrap();
let expr = build_func(
sig.name.as_scalar(),
sig.ret_type.as_exact().clone(),
children,
)
.unwrap();
c.bench_function(&format!("{sig:?}"), |bencher| {
bencher.to_async(FuturesExecutor).iter(|| expr.eval(&input))
});
}

let sigs = agg_func_sigs();
let sigs = sigs.sorted_by_cached_key(|sig| format!("{sig:?}"));
let sigs = func_sigs()
.filter(|s| s.is_aggregate())
.sorted_by_cached_key(|sig| format!("{sig:?}"));
for sig in sigs {
if matches!(sig.func, AggKind::PercentileDisc | AggKind::PercentileCont)
|| (sig.inputs_type.iter())
.chain(&[sig.ret_type])
.any(|t| matches!(t, DataTypeName::Struct | DataTypeName::List))
if matches!(
sig.name.as_aggregate(),
AggKind::PercentileDisc | AggKind::PercentileCont
) || (sig.inputs_type.iter())
.chain([&sig.ret_type])
.any(|t| !t.is_exact())
{
println!("todo: {sig:?}");
continue;
}
let agg = match build_agg(&AggCall {
kind: sig.func,
args: match sig.inputs_type {
kind: sig.name.as_aggregate(),
args: match sig.inputs_type.as_slice() {
[] => AggArgs::None,
[t] => AggArgs::Unary((*t).into(), input_index_for_type((*t).into())),
[t] => AggArgs::Unary(t.as_exact().clone(), input_index_for_type(t.as_exact())),
_ => {
println!("todo: {sig:?}");
continue;
}
},
return_type: sig.ret_type.into(),
return_type: sig.ret_type.as_exact().clone(),
column_orders: vec![],
filter: None,
distinct: false,
Expand Down
30 changes: 15 additions & 15 deletions src/expr/macro/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ impl FunctionAttr {
if ty == "..." {
break;
}
args.push(match_type(ty));
args.push(sig_data_type(ty));
}
let variadic = matches!(self.args.last(), Some(t) if t == "...");
let ret = match_type(&self.ret);
let variadic = matches!(self.args.last(), Some(s) if s == "...");
let ret = sig_data_type(&self.ret);

let pb_type = format_ident!("{}", utils::to_camel_case(&name));
let ctor_name = format_ident!("{}", self.ident_name());
Expand All @@ -101,7 +101,7 @@ impl FunctionAttr {
#[ctor::ctor]
fn #ctor_name() {
use risingwave_common::types::{DataType, DataTypeName};
use crate::sig::{_register, FuncSign, MatchType, FuncBuilder};
use crate::sig::{_register, FuncSign, SigDataType, FuncBuilder};

unsafe { _register(FuncSign {
name: risingwave_pb::expr::expr_node::Type::#pb_type.into(),
Expand Down Expand Up @@ -491,9 +491,9 @@ impl FunctionAttr {

let mut args = Vec::with_capacity(self.args.len());
for ty in &self.args {
args.push(match_type(ty));
args.push(sig_data_type(ty));
}
let ret = match_type(&self.ret);
let ret = sig_data_type(&self.ret);

let pb_type = format_ident!("{}", utils::to_camel_case(&name));
let ctor_name = format_ident!("{}", self.ident_name());
Expand All @@ -510,7 +510,7 @@ impl FunctionAttr {
#[ctor::ctor]
fn #ctor_name() {
use risingwave_common::types::{DataType, DataTypeName};
use crate::sig::{_register, FuncSign, MatchType, FuncBuilder};
use crate::sig::{_register, FuncSign, SigDataType, FuncBuilder};

unsafe { _register(FuncSign {
name: crate::agg::AggKind::#pb_type.into(),
Expand Down Expand Up @@ -724,9 +724,9 @@ impl FunctionAttr {
let name = self.name.clone();
let mut args = Vec::with_capacity(self.args.len());
for ty in &self.args {
args.push(match_type(ty));
args.push(sig_data_type(ty));
}
let ret = match_type(&self.ret);
let ret = sig_data_type(&self.ret);

let pb_type = format_ident!("{}", utils::to_camel_case(&name));
let ctor_name = format_ident!("{}", self.ident_name());
Expand All @@ -743,7 +743,7 @@ impl FunctionAttr {
#[ctor::ctor]
fn #ctor_name() {
use risingwave_common::types::{DataType, DataTypeName};
use crate::sig::{_register, FuncSign, MatchType, FuncBuilder};
use crate::sig::{_register, FuncSign, SigDataType, FuncBuilder};

unsafe { _register(FuncSign {
name: risingwave_pb::expr::table_function::Type::#pb_type.into(),
Expand Down Expand Up @@ -934,14 +934,14 @@ impl FunctionAttr {
}
}

fn match_type(ty: &str) -> TokenStream2 {
fn sig_data_type(ty: &str) -> TokenStream2 {
match ty {
"any" => quote! { MatchType::Any },
"anyarray" => quote! { MatchType::AnyArray },
"struct" => quote! { MatchType::AnyStruct },
"any" => quote! { SigDataType::Any },
"anyarray" => quote! { SigDataType::AnyArray },
"struct" => quote! { SigDataType::AnyStruct },
_ => {
let datatype = data_type(ty);
quote! { MatchType::Exact(#datatype) }
quote! { SigDataType::Exact(#datatype) }
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/expr/src/agg/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ use risingwave_common::array::ListValue;
use risingwave_common::types::{Datum, ScalarRefImpl, ToOwnedDatum};
use risingwave_expr_macro::aggregate;

#[aggregate("array_agg(any) -> anyarray")]
#[aggregate(
"array_agg(any) -> anyarray",
type_infer = "|args| Ok(DataType::List(Box::new(args[0].clone())))"
)]
fn array_agg(state: Option<ListValue>, value: Option<ScalarRefImpl<'_>>) -> ListValue {
let mut state: Vec<Datum> = state.unwrap_or_default().into();
state.push(value.to_owned_datum());
Expand Down
13 changes: 9 additions & 4 deletions src/expr/src/agg/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::convert::From;
use std::ops::{BitAnd, BitOr, BitXor};

use num_traits::{CheckedAdd, CheckedSub};
use risingwave_common::types::DataType;
use risingwave_expr_macro::aggregate;

use crate::{ExprError, Result};
Expand Down Expand Up @@ -50,12 +51,12 @@ where
Ok(Some(result))
}

#[aggregate("min(*) -> auto", state = "ref")]
#[aggregate("min(*) -> auto", state = "ref", type_infer = "same_as_arg0")]
fn min<T: Ord>(state: T, input: T) -> T {
state.min(input)
}

#[aggregate("max(*) -> auto", state = "ref")]
#[aggregate("max(*) -> auto", state = "ref", type_infer = "same_as_arg0")]
fn max<T: Ord>(state: T, input: T) -> T {
state.max(input)
}
Expand Down Expand Up @@ -84,16 +85,20 @@ where
state.bitxor(input)
}

#[aggregate("first_value(*) -> auto", state = "ref")]
#[aggregate("first_value(*) -> auto", state = "ref", type_infer = "same_as_arg0")]
fn first_value<T>(state: T, _: T) -> T {
state
}

#[aggregate("last_value(*) -> auto", state = "ref")]
#[aggregate("last_value(*) -> auto", state = "ref", type_infer = "same_as_arg0")]
fn last_value<T>(_: T, input: T) -> T {
input
}

fn same_as_arg0(args: &[DataType]) -> Result<DataType> {
Ok(args[0].clone())
}

/// Note the following corner cases:
///
/// ```slt
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/agg/mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use super::{AggStateDyn, AggregateFunction, AggregateState, BoxedAggregateFuncti
use crate::agg::AggCall;
use crate::Result;

#[build_aggregate("mode(any) -> any")]
#[build_aggregate("mode(any) -> any", type_infer = "|args| Ok(args[0].clone())")]
fn build(agg: &AggCall) -> Result<BoxedAggregateFunction> {
Ok(Box::new(Mode {
return_type: agg.return_type.clone(),
Expand Down
5 changes: 4 additions & 1 deletion src/expr/src/agg/percentile_disc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ use crate::Result;
/// statement ok
/// drop table t;
/// ```
#[build_aggregate("percentile_disc(any) -> any")]
#[build_aggregate(
"percentile_disc(any) -> any",
type_infer = "|args| Ok(args[0].clone())"
)]
fn build(agg: &AggCall) -> Result<BoxedAggregateFunction> {
let fractions = agg.direct_args[0]
.literal()
Expand Down
Loading

0 comments on commit 86f44c6

Please sign in to comment.