feat(flow): avg func rewrite to sum/count (#3955)
* feat(WIP): parse avg

* feat: RelationType::apply_mfp no need expr typs

* feat: avg&tests

* fix(WIP): avg eval

* fix: sum ret correct type

* chore: typos
discord9 authored May 16, 2024
1 parent 9f4a6c6 commit 93f178f
Showing 8 changed files with 495 additions and 64 deletions.
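At a high level, this commit stops evaluating `avg` directly: the reduce stage keeps a `sum` and a `count` accumulator over the same column, and a final map expression divides them, returning NULL when the count is zero (see `test_avg_eval` below). The following is a minimal standalone sketch of those semantics, assuming nothing from the flow crate; the names (`inputs`, `sum`, `count`, `avg`) are illustrative only:

fn main() {
    // Same input as the test below: six unsigned values feeding avg(number).
    let inputs: [u32; 6] = [1, 2, 3, 1, 2, 3];

    // Reduce stage: two accumulable aggregates over the same column.
    let sum: u64 = inputs.iter().map(|&v| u64::from(v)).sum(); // running sum
    let count: i64 = inputs.len() as i64; // running count

    // Map/filter/project stage: if count != 0 then sum / cast(count as u64) else NULL.
    let avg: Option<u64> = if count != 0 { Some(sum / count as u64) } else { None };

    assert_eq!(avg, Some(2)); // matches the expected Row::new(vec![2u64.into()])
    println!("sum={sum}, count={count}, avg={avg:?}");
}

Guarding the division with `count != 0` is what lets the rewritten plan return NULL for an empty group instead of dividing by zero, which matches SQL's `avg` over zero rows.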
6 changes: 6 additions & 0 deletions src/flow/src/compute/render.rs
@@ -238,6 +238,12 @@ mod test {
for now in time_range {
state.set_current_ts(now);
state.run_available_with_schedule(df);
if !state.get_err_collector().is_empty() {
panic!(
"Errors occur: {:?}",
state.get_err_collector().get_all_blocking()
)
}
assert!(state.get_err_collector().is_empty());
if let Some(expected) = expected.get(&now) {
assert_eq!(*output.borrow(), *expected, "at ts={}", now);
102 changes: 100 additions & 2 deletions src/flow/src/compute/render/reduce.rs
@@ -729,15 +729,113 @@ mod test {
use std::cell::RefCell;
use std::rc::Rc;

use datatypes::data_type::ConcreteDataType;
use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT};
use hydroflow::scheduled::graph::Hydroflow;

use super::*;
use crate::compute::render::test::{get_output_handle, harness_test_ctx, run_and_check};
use crate::compute::state::DataflowState;
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject};
use crate::expr::{self, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc};
use crate::repr::{ColumnType, RelationType};

/// select avg(number) from number;
#[test]
fn test_avg_eval() {
let mut df = Hydroflow::new();
let mut state = DataflowState::default();
let mut ctx = harness_test_ctx(&mut df, &mut state);

let rows = vec![
(Row::new(vec![1u32.into()]), 1, 1),
(Row::new(vec![2u32.into()]), 1, 1),
(Row::new(vec![3u32.into()]), 1, 1),
(Row::new(vec![1u32.into()]), 1, 1),
(Row::new(vec![2u32.into()]), 1, 1),
(Row::new(vec![3u32.into()]), 1, 1),
];
let collection = ctx.render_constant(rows.clone());
ctx.insert_global(GlobalId::User(1), collection);

let aggr_exprs = vec![
AggregateExpr {
func: AggregateFunc::SumUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
},
AggregateExpr {
func: AggregateFunc::Count,
expr: ScalarExpr::Column(0),
distinct: false,
},
];
let avg_expr = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(1).call_binary(
ScalarExpr::Literal(Value::from(0u32), CDT::int64_datatype()),
BinaryFunc::NotEq,
)),
then: Box::new(ScalarExpr::Column(0).call_binary(
ScalarExpr::Column(1).call_unary(UnaryFunc::Cast(CDT::uint64_datatype())),
BinaryFunc::DivUInt64,
)),
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::uint64_datatype())),
};
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(1)),
}
.with_types(RelationType::new(vec![
ColumnType::new(ConcreteDataType::int64_datatype(), false),
])),
),
key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(1)
.project(vec![])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(1)
.project(vec![0])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: aggr_exprs.clone(),
simple_aggrs: vec![
AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1),
],
distinct_aggrs: vec![],
}),
}
.with_types(RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), true),
ColumnType::new(ConcreteDataType::int64_datatype(), true),
])),
),
mfp: MapFilterProject::new(2)
.map(vec![
avg_expr,
// TODO(discord9): optimize mfp so to remove indirect ref
ScalarExpr::Column(2),
])
.unwrap()
.project(vec![3])
.unwrap(),
},
};

let bundle = ctx.render_plan(expected).unwrap();

let output = get_output_handle(&mut ctx, bundle);
drop(ctx);
let expected = BTreeMap::from([(1, vec![(Row::new(vec![2u64.into()]), 1, 1)])]);
run_and_check(&mut state, &mut df, 1..2, expected, output);
}

/// SELECT DISTINCT col FROM table
///
/// table schema:
3 changes: 3 additions & 0 deletions src/flow/src/compute/types.rs
@@ -153,6 +153,9 @@ pub struct ErrCollector {
}

impl ErrCollector {
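/// Drain and return all collected errors, blocking the current thread on the inner lock
/// (a synchronous counterpart of [`Self::get_all`] for use outside async contexts).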
pub fn get_all_blocking(&self) -> Vec<EvalError> {
self.inner.blocking_lock().drain(..).collect_vec()
}
pub async fn get_all(&self) -> Vec<EvalError> {
self.inner.lock().await.drain(..).collect_vec()
}
16 changes: 16 additions & 0 deletions src/flow/src/expr/func.rs
@@ -375,6 +375,22 @@ impl BinaryFunc {
)
}

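/// Shorthand for [`Self::specialization`] with [`GenericFn::Add`].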
pub fn add(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Add, input_type)
}

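/// Shorthand for [`Self::specialization`] with [`GenericFn::Sub`].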
pub fn sub(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Sub, input_type)
}

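/// Shorthand for [`Self::specialization`] with [`GenericFn::Mul`].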
pub fn mul(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Mul, input_type)
}

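/// Shorthand for [`Self::specialization`] with [`GenericFn::Div`].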
pub fn div(input_type: ConcreteDataType) -> Result<Self, Error> {
Self::specialization(GenericFn::Div, input_type)
}

/// Get the specialization of the binary function based on the generic function and the input type
pub fn specialization(generic: GenericFn, input_type: ConcreteDataType) -> Result<Self, Error> {
let rule = SPECIALIZATION.get_or_init(|| {
57 changes: 38 additions & 19 deletions src/flow/src/expr/relation/func.rs
@@ -136,27 +136,44 @@ impl AggregateFunc {

/// Generate signature for each aggregate function
macro_rules! generate_signature {
($value:ident, { $($user_arm:tt)* },
[ $(
$auto_arm:ident=>($con_type:ident,$generic:ident)
),*
]) => {
($value:ident,
{ $($user_arm:tt)* },
[ $(
$auto_arm:ident=>($($arg:ident),*)
),*
]
) => {
match $value {
$($user_arm)*,
$(
Self::$auto_arm => Signature {
input: smallvec![
ConcreteDataType::$con_type(),
ConcreteDataType::$con_type(),
],
output: ConcreteDataType::$con_type(),
generic_fn: GenericFn::$generic,
},
Self::$auto_arm => gen_one_siginature!($($arg),*),
)*
}
};
}

/// Generate one [`Signature`]: with two arguments the input and output types are the same,
/// with three arguments the input and output types differ, e.g.
/// `SumUInt32 => (uint32_datatype, uint64_datatype, Sum)` gives `uint32 -> uint64`.
macro_rules! gen_one_siginature {
(
$con_type:ident, $generic:ident
) => {
Signature {
input: smallvec![ConcreteDataType::$con_type(), ConcreteDataType::$con_type(),],
output: ConcreteDataType::$con_type(),
generic_fn: GenericFn::$generic,
}
};
(
$in_type:ident, $out_type:ident, $generic:ident
) => {
Signature {
input: smallvec![ConcreteDataType::$in_type()],
output: ConcreteDataType::$out_type(),
generic_fn: GenericFn::$generic,
}
};
}

static SPECIALIZATION: OnceLock<HashMap<(GenericFn, ConcreteDataType), AggregateFunc>> =
OnceLock::new();

@@ -223,6 +240,8 @@ impl AggregateFunc {

/// all concrete datatypes with precision types will be returned as the largest possible variant;
/// as an exception, count has a signature of `null -> i64`, but it's actually `anytype -> i64`
///
/// TODO(discord9): fix signature for sum: unsigned -> u64, signed -> i64
pub fn signature(&self) -> Signature {
generate_signature!(self, {
AggregateFunc::Count => Signature {
@@ -263,12 +282,12 @@ impl AggregateFunc {
MinTime => (time_second_datatype, Min),
MinDuration => (duration_second_datatype, Min),
MinInterval => (interval_year_month_datatype, Min),
SumInt16 => (int16_datatype, Sum),
SumInt32 => (int32_datatype, Sum),
SumInt64 => (int64_datatype, Sum),
SumUInt16 => (uint16_datatype, Sum),
SumUInt32 => (uint32_datatype, Sum),
SumUInt64 => (uint64_datatype, Sum),
SumInt16 => (int16_datatype, int64_datatype, Sum),
SumInt32 => (int32_datatype, int64_datatype, Sum),
SumInt64 => (int64_datatype, int64_datatype, Sum),
SumUInt16 => (uint16_datatype, uint64_datatype, Sum),
SumUInt32 => (uint32_datatype, uint64_datatype, Sum),
SumUInt64 => (uint64_datatype, uint64_datatype, Sum),
SumFloat32 => (float32_datatype, Sum),
SumFloat64 => (float64_datatype, Sum),
Any => (boolean_datatype, Any),
6 changes: 3 additions & 3 deletions src/flow/src/plan.rs
@@ -44,7 +44,7 @@ pub struct TypedPlan {
impl TypedPlan {
/// directly apply an mfp to the plan
pub fn mfp(self, mfp: MapFilterProject) -> Result<Self, Error> {
let new_type = self.typ.apply_mfp(&mfp, &[])?;
let new_type = self.typ.apply_mfp(&mfp)?;
let plan = match self.plan {
Plan::Mfp {
input,
@@ -68,14 +68,14 @@ impl TypedPlan {
pub fn projection(self, exprs: Vec<TypedExpr>) -> Result<Self, Error> {
let input_arity = self.typ.column_types.len();
let output_arity = exprs.len();
let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs
let (exprs, _expr_typs): (Vec<_>, Vec<_>) = exprs
.into_iter()
.map(|TypedExpr { expr, typ }| (expr, typ))
.unzip();
let mfp = MapFilterProject::new(input_arity)
.map(exprs)?
.project(input_arity..input_arity + output_arity)?;
let out_typ = self.typ.apply_mfp(&mfp, &expr_typs)?;
let out_typ = self.typ.apply_mfp(&mfp)?;
// special case for mfp to compose when the plan is already mfp
let plan = match self.plan {
Plan::Mfp {
15 changes: 8 additions & 7 deletions src/flow/src/repr/relation.rs
@@ -111,13 +111,13 @@ impl RelationType {
/// then new key=`[1]`, new time index=`[0]`
///
/// note that this function will remove empty keys, i.e. a key of `[]` will be removed
pub fn apply_mfp(&self, mfp: &MapFilterProject, expr_typs: &[ColumnType]) -> Result<Self> {
let all_types = self
.column_types
.iter()
.chain(expr_typs.iter())
.cloned()
.collect_vec();
pub fn apply_mfp(&self, mfp: &MapFilterProject) -> Result<Self> {
let mut all_types = self.column_types.clone();
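// derive the output type of each map expression from the input column types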
for expr in &mfp.expressions {
let expr_typ = expr.typ(&self.column_types)?;
all_types.push(expr_typ);
}
let all_types = all_types;
let mfp_out_types = mfp
.projection
.iter()
@@ -131,6 +131,7 @@
})
})
.try_collect()?;

let old_to_new_col = BTreeMap::from_iter(
mfp.projection
.clone()