From dbb1ce1a9bd8954c92feb4c6b1aad05d67f3bdf2 Mon Sep 17 00:00:00 2001 From: discord9 <55937128+discord9@users.noreply.github.com> Date: Tue, 27 Feb 2024 16:13:55 +0800 Subject: [PATCH] feat(flow): impl for MapFilterProject (#3359) * feat: mfp impls * fix: after rebase * test: temporal filter mfp * refactor: more comments&test * test: permute * fix: check input len when eval * refactor: err handle&docs: more explain graph * docs: better flowchart map,filter,project * refactor: visit_* falliable * chore: better temp lint allow * fix: permute partially * chore: remove duplicated checks * docs: more explain&tests for clarity * refactor: use ensure! instead --- src/flow/src/expr/linear.rs | 833 +++++++++++++++++++++++++++++++++++- src/flow/src/expr/scalar.rs | 117 +++-- src/flow/src/lib.rs | 3 +- 3 files changed, 915 insertions(+), 38 deletions(-) diff --git a/src/flow/src/expr/linear.rs b/src/flow/src/expr/linear.rs index eddb98c49e1b..830195f5b2ec 100644 --- a/src/flow/src/expr/linear.rs +++ b/src/flow/src/expr/linear.rs @@ -16,18 +16,28 @@ use std::collections::{BTreeMap, BTreeSet}; use datatypes::value::Value; use serde::{Deserialize, Serialize}; +use snafu::{ensure, OptionExt}; use crate::expr::error::EvalError; -use crate::expr::{Id, LocalId, ScalarExpr}; +use crate::expr::{Id, InvalidArgumentSnafu, LocalId, ScalarExpr}; use crate::repr::{self, value_to_internal_ts, Diff, Row}; /// A compound operator that can be applied row-by-row. /// +/// In practice, this operator is a sequence of map, filter, and project in arbitrary order, +/// which can and is stored by reordering the sequence's +/// apply order to a `map` first, `filter` second and `project` third order. +/// +/// input is a row(a sequence of values), which is also being used for store intermediate results, +/// like `map` operator can append new columns to the row according to it's expressions, +/// `filter` operator decide whether this entire row can even be output by decide whether the row satisfy the predicates, +/// `project` operator decide which columns of the row should be output. +/// /// This operator integrates the map, filter, and project operators. /// It applies a sequences of map expressions, which are allowed to /// refer to previous expressions, interleaved with predicates which /// must be satisfied for an output to be produced. If all predicates -/// evaluate to `Datum::True` the data at the identified columns are +/// evaluate to `Value::Boolean(True)` the data at the identified columns are /// collected and produced as output in a packed `Row`. /// /// This operator is a "builder" and its contents may contain expressions @@ -48,8 +58,10 @@ pub struct MapFilterProject { /// Each entry is prepended with a column identifier indicating /// the column *before* which the predicate should first be applied. /// Most commonly this would be one plus the largest column identifier - /// in the predicate's support, but it could be larger to implement + /// in the predicate's referred columns, but it could be larger to implement /// guarded evaluation of predicates. + /// Put it in another word, the first element of the tuple means + /// the predicates can't be evaluated until that number of columns is formed. /// /// This list should be sorted by the first field. pub predicates: Vec<(usize, ScalarExpr)>, @@ -62,12 +74,447 @@ pub struct MapFilterProject { pub input_arity: usize, } +impl MapFilterProject { + /// Create a no-op operator for an input of a supplied arity. + pub fn new(input_arity: usize) -> Self { + Self { + expressions: Vec::new(), + predicates: Vec::new(), + projection: (0..input_arity).collect(), + input_arity, + } + } + + /// Given two mfps, return an mfp that applies one + /// followed by the other. + /// Note that the arguments are in the opposite order + /// from how function composition is usually written in mathematics. + pub fn compose(before: Self, after: Self) -> Result { + let (m, f, p) = after.into_map_filter_project(); + before.map(m)?.filter(f)?.project(p) + } + + /// True if the operator describes the identity transformation. + pub fn is_identity(&self) -> bool { + self.expressions.is_empty() + && self.predicates.is_empty() + // identity if projection is the identity permutation + && self.projection.len() == self.input_arity + && self.projection.iter().enumerate().all(|(i, p)| i == *p) + } + + /// Retain only the indicated columns in the presented order. + /// + /// i.e. before: `self.projection = [1, 2, 0], columns = [1, 0]` + /// ```mermaid + /// flowchart TD + /// col-0 + /// col-1 + /// col-2 + /// projection --> |0|col-1 + /// projection --> |1|col-2 + /// projection --> |2|col-0 + /// ``` + /// + /// after: `self.projection = [2, 1]` + /// ```mermaid + /// flowchart TD + /// col-0 + /// col-1 + /// col-2 + /// project("project:[1,2,0]") + /// project + /// project -->|0| col-1 + /// project -->|1| col-2 + /// project -->|2| col-0 + /// new_project("apply new project:[1,0]") + /// new_project -->|0| col-2 + /// new_project -->|1| col-1 + /// ``` + pub fn project(mut self, columns: I) -> Result + where + I: IntoIterator + std::fmt::Debug, + { + self.projection = columns + .into_iter() + .map(|c| self.projection.get(c).cloned().ok_or(c)) + .collect::, _>>() + .map_err(|c| { + InvalidArgumentSnafu { + reason: format!( + "column index {} out of range, expected at most {} columns", + c, + self.projection.len() + ), + } + .build() + })?; + Ok(self) + } + + /// Retain only rows satisfying these predicates. + /// + /// This method introduces predicates as eagerly as they can be evaluated, + /// which may not be desired for predicates that may cause exceptions. + /// If fine manipulation is required, the predicates can be added manually. + /// + /// simply added to the end of the predicates list + /// + /// while paying attention to column references maintained by `self.projection` + /// + /// so `self.projection = [1, 2, 0], filter = [0]+[1]>0`: + /// becomes: + /// ```mermaid + /// flowchart TD + /// col-0 + /// col-1 + /// col-2 + /// project("first project:[1,2,0]") + /// project + /// project -->|0| col-1 + /// project -->|1| col-2 + /// project -->|2| col-0 + /// filter("then filter:[0]+[1]>0") + /// filter -->|0| col-1 + /// filter --> |1| col-2 + /// ``` + pub fn filter(mut self, predicates: I) -> Result + where + I: IntoIterator, + { + for mut predicate in predicates { + // Correct column references. + predicate.permute(&self.projection[..])?; + + // Validate column references. + let referred_columns = predicate.get_all_ref_columns(); + for c in referred_columns.iter() { + // current row len include input columns and previous number of expressions + let cur_row_len = self.input_arity + self.expressions.len(); + ensure!( + *c < cur_row_len, + InvalidArgumentSnafu { + reason: format!( + "column index {} out of range, expected at most {} columns", + c, cur_row_len + ) + } + ); + } + + // Insert predicate as eagerly as it can be evaluated: + // just after the largest column in its support is formed. + let max_support = referred_columns + .into_iter() + .max() + .map(|c| c + 1) + .unwrap_or(0); + self.predicates.push((max_support, predicate)) + } + // Stable sort predicates by position at which they take effect. + self.predicates + .sort_by_key(|(position, _predicate)| *position); + Ok(self) + } + + /// Append the result of evaluating expressions to each row. + /// + /// simply append `expressions` to `self.expressions` + /// + /// while paying attention to column references maintained by `self.projection` + /// + /// hence, before apply map with a previously non-trivial projection would be like: + /// before: + /// ```mermaid + /// flowchart TD + /// col-0 + /// col-1 + /// col-2 + /// projection --> |0|col-1 + /// projection --> |1|col-2 + /// projection --> |2|col-0 + /// ``` + /// after apply map: + /// ```mermaid + /// flowchart TD + /// col-0 + /// col-1 + /// col-2 + /// project("project:[1,2,0]") + /// project + /// project -->|0| col-1 + /// project -->|1| col-2 + /// project -->|2| col-0 + /// map("map:[0]/[1]/[2]") + /// map -->|0|col-1 + /// map -->|1|col-2 + /// map -->|2|col-0 + /// ``` + pub fn map(mut self, expressions: I) -> Result + where + I: IntoIterator, + { + for mut expression in expressions { + // Correct column references. + expression.permute(&self.projection[..])?; + + // Validate column references. + for c in expression.get_all_ref_columns().into_iter() { + // current row len include input columns and previous number of expressions + let current_row_len = self.input_arity + self.expressions.len(); + ensure!( + c < current_row_len, + InvalidArgumentSnafu { + reason: format!( + "column index {} out of range, expected at most {} columns", + c, current_row_len + ) + } + ); + } + + // Introduce expression and produce as output. + self.expressions.push(expression); + // Expression by default is projected to output. + let cur_expr_col_num = self.input_arity + self.expressions.len() - 1; + self.projection.push(cur_expr_col_num); + } + + Ok(self) + } + + /// Like [`MapFilterProject::as_map_filter_project`], but consumes `self` rather than cloning. + pub fn into_map_filter_project(self) -> (Vec, Vec, Vec) { + let predicates = self + .predicates + .into_iter() + .map(|(_pos, predicate)| predicate) + .collect(); + (self.expressions, predicates, self.projection) + } + + /// As the arguments to `Map`, `Filter`, and `Project` operators. + /// + /// In principle, this operator can be implemented as a sequence of + /// more elemental operators, likely less efficiently. + pub fn as_map_filter_project(&self) -> (Vec, Vec, Vec) { + self.clone().into_map_filter_project() + } +} + +impl MapFilterProject { + pub fn optimize(&mut self) { + // TODO(discord9): optimize + } + + /// Convert the `MapFilterProject` into a staged evaluation plan. + /// + /// The main behavior is extract temporal predicates, which cannot be evaluated + /// using the standard machinery. + pub fn into_plan(self) -> Result { + MfpPlan::create_from(self) + } + + /// Lists input columns whose values are used in outputs. + /// + /// It is entirely appropriate to determine the demand of an instance + /// and then both apply a projection to the subject of the instance and + /// `self.permute` this instance. + pub fn demand(&self) -> BTreeSet { + let mut demanded = BTreeSet::new(); + // first, get all columns referenced by predicates + for (_index, pred) in self.predicates.iter() { + demanded.extend(pred.get_all_ref_columns()); + } + // then, get columns referenced by projection which is direct output + demanded.extend(self.projection.iter().cloned()); + + // check every expressions, if a expression is contained in demanded, then all columns it referenced should be added to demanded + for index in (0..self.expressions.len()).rev() { + if demanded.contains(&(self.input_arity + index)) { + demanded.extend(self.expressions[index].get_all_ref_columns()); + } + } + + // only keep demanded columns that are in input + demanded.retain(|col| col < &self.input_arity); + demanded + } + + /// Update input column references, due to an input projection or permutation. + /// + /// The `shuffle` argument remaps expected column identifiers to new locations, + /// with the expectation that `shuffle` describes all input columns, and so the + /// intermediate results will be able to start at position `shuffle.len()`. + /// + /// The supplied `shuffle` may not list columns that are not "demanded" by the + /// instance, and so we should ensure that `self` is optimized to not reference + /// columns that are not demanded. + pub fn permute( + &mut self, + mut shuffle: BTreeMap, + new_input_arity: usize, + ) -> Result<(), EvalError> { + // check shuffle is valid + let demand = self.demand(); + for d in demand { + ensure!( + shuffle.contains_key(&d), + InvalidArgumentSnafu { + reason: format!( + "Demanded column {} is not in shuffle's keys: {:?}", + d, + shuffle.keys() + ) + } + ); + } + ensure!( + shuffle.len() <= new_input_arity, + InvalidArgumentSnafu { + reason: format!( + "shuffle's length {} is greater than new_input_arity {}", + shuffle.len(), + self.input_arity + ) + } + ); + + // decompose self into map, filter, project for ease of manipulation + let (mut map, mut filter, mut project) = self.as_map_filter_project(); + for index in 0..map.len() { + // Intermediate columns are just shifted. + shuffle.insert(self.input_arity + index, new_input_arity + index); + } + + for expr in map.iter_mut() { + expr.permute_map(&shuffle)?; + } + for pred in filter.iter_mut() { + pred.permute_map(&shuffle)?; + } + let new_row_len = new_input_arity + map.len(); + for proj in project.iter_mut() { + ensure!( + shuffle[proj] < new_row_len, + InvalidArgumentSnafu { + reason: format!( + "shuffled column index {} out of range, expected at most {} columns", + shuffle[proj], new_row_len + ) + } + ); + *proj = shuffle[proj]; + } + *self = Self::new(new_input_arity) + .map(map)? + .filter(filter)? + .project(project)?; + Ok(()) + } +} + /// A wrapper type which indicates it is safe to simply evaluate all expressions. #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] pub struct SafeMfpPlan { pub(crate) mfp: MapFilterProject, } +impl SafeMfpPlan { + /// See [`MapFilterProject::permute`]. + pub fn permute( + &mut self, + map: BTreeMap, + new_arity: usize, + ) -> Result<(), EvalError> { + self.mfp.permute(map, new_arity) + } + + /// Evaluates the linear operator on a supplied list of datums. + /// + /// The arguments are the initial datums associated with the row, + /// and an appropriately lifetimed arena for temporary allocations + /// needed by scalar evaluation. + /// + /// An `Ok` result will either be `None` if any predicate did not + /// evaluate to `Value::Boolean(true)`, or the values of the columns listed + /// by `self.projection` if all predicates passed. If an error + /// occurs in the evaluation it is returned as an `Err` variant. + /// As the evaluation exits early with failed predicates, it may + /// miss some errors that would occur later in evaluation. + /// + /// The `row` is not cleared first, but emptied if the function + /// returns `Ok(Some(row)). + #[inline(always)] + pub fn evaluate_into( + &self, + values: &mut Vec, + row_buf: &mut Row, + ) -> Result, EvalError> { + ensure!( + values.len() == self.mfp.input_arity, + InvalidArgumentSnafu { + reason: format!( + "values length {} is not equal to input_arity {}", + values.len(), + self.mfp.input_arity + ), + } + ); + let passed_predicates = self.evaluate_inner(values)?; + + if !passed_predicates { + Ok(None) + } else { + row_buf.clear(); + row_buf.extend(self.mfp.projection.iter().map(|c| values[*c].clone())); + Ok(Some(row_buf.clone())) + } + } + + /// A version of `evaluate` which produces an iterator over `Datum` + /// as output. + /// + /// This version can be useful when one wants to capture the resulting + /// datums without packing and then unpacking a row. + #[inline(always)] + pub fn evaluate_iter<'a>( + &'a self, + datums: &'a mut Vec, + ) -> Result + 'a>, EvalError> { + let passed_predicates = self.evaluate_inner(datums)?; + if !passed_predicates { + Ok(None) + } else { + Ok(Some( + self.mfp.projection.iter().map(move |i| datums[*i].clone()), + )) + } + } + + /// Populates `values` with `self.expressions` and tests `self.predicates`. + /// + /// This does not apply `self.projection`, which is up to the calling method. + pub fn evaluate_inner(&self, values: &mut Vec) -> Result { + let mut expression = 0; + for (support, predicate) in self.mfp.predicates.iter() { + while self.mfp.input_arity + expression < *support { + values.push(self.mfp.expressions[expression].eval(&values[..])?); + expression += 1; + } + if predicate.eval(&values[..])? != Value::Boolean(true) { + return Ok(false); + } + } + // while evaluated expressions are less than total expressions, keep evaluating + while expression < self.mfp.expressions.len() { + values.push(self.mfp.expressions[expression].eval(&values[..])?); + expression += 1; + } + Ok(true) + } +} + impl std::ops::Deref for SafeMfpPlan { type Target = MapFilterProject; fn deref(&self) -> &Self::Target { @@ -94,3 +541,383 @@ pub struct MfpPlan { /// Expressions that when evaluated upper-bound `MzNow`. pub(crate) upper_bounds: Vec, } + +impl MfpPlan { + /// find `now` in `predicates` and put them into lower/upper temporal bounds for temporal filter to use + pub fn create_from(mut mfp: MapFilterProject) -> Result { + let mut lower_bounds = Vec::new(); + let mut upper_bounds = Vec::new(); + + let mut temporal = Vec::new(); + + // Optimize, to ensure that temporal predicates are move in to `mfp.predicates`. + mfp.optimize(); + + mfp.predicates.retain(|(_position, predicate)| { + if predicate.contains_temporal() { + temporal.push(predicate.clone()); + false + } else { + true + } + }); + + for predicate in temporal { + let (lower, upper) = predicate.extract_bound()?; + lower_bounds.extend(lower); + upper_bounds.extend(upper); + } + Ok(Self { + mfp: SafeMfpPlan { mfp }, + lower_bounds, + upper_bounds, + }) + } + + /// Indicates if the planned `MapFilterProject` emits exactly its inputs as outputs. + pub fn is_identity(&self) -> bool { + self.mfp.mfp.is_identity() && self.lower_bounds.is_empty() && self.upper_bounds.is_empty() + } + + /// if `lower_bound <= sys_time < upper_bound`, return `[(data, sys_time, +1), (data, min_upper_bound, -1)]` + /// + /// else if `sys_time < lower_bound`, return `[(data, lower_bound, +1), (data, min_upper_bound, -1)]` + /// + /// else if `sys_time >= upper_bound`, return `[None, None]` + /// + /// if eval error appeal in any of those process, corresponding result will be `Err` + pub fn evaluate>( + &self, + values: &mut Vec, + sys_time: repr::Timestamp, + diff: Diff, + ) -> impl Iterator> + { + match self.mfp.evaluate_inner(values) { + Err(e) => { + return Some(Err((e.into(), sys_time, diff))) + .into_iter() + .chain(None); + } + Ok(true) => {} + Ok(false) => { + return None.into_iter().chain(None); + } + } + + let mut lower_bound = sys_time; + let mut upper_bound = None; + + // Track whether we have seen a null in either bound, as this should + // prevent the record from being produced at any time. + let mut null_eval = false; + let ret_err = |e: EvalError| { + Some(Err((e.into(), sys_time, diff))) + .into_iter() + .chain(None) + }; + for l in self.lower_bounds.iter() { + match l.eval(values) { + Ok(v) => { + if v.is_null() { + null_eval = true; + continue; + } + match value_to_internal_ts(v) { + Ok(ts) => lower_bound = lower_bound.max(ts), + Err(e) => return ret_err(e), + } + } + Err(e) => return ret_err(e), + }; + } + + for u in self.upper_bounds.iter() { + if upper_bound != Some(lower_bound) { + match u.eval(values) { + Err(e) => return ret_err(e), + Ok(val) => { + if val.is_null() { + null_eval = true; + continue; + } + let ts = match value_to_internal_ts(val) { + Ok(ts) => ts, + Err(e) => return ret_err(e), + }; + if let Some(upper) = upper_bound { + upper_bound = Some(upper.min(ts)); + } else { + upper_bound = Some(ts); + } + // Force the upper bound to be at least the lower + // bound. + if upper_bound.is_some() && upper_bound < Some(lower_bound) { + upper_bound = Some(lower_bound); + } + } + } + } + } + + if Some(lower_bound) != upper_bound && !null_eval { + let res_row = Row::pack(self.mfp.mfp.projection.iter().map(|c| values[*c].clone())); + let upper_opt = + upper_bound.map(|upper_bound| Ok((res_row.clone(), upper_bound, -diff))); + // if diff==-1, the `upper_opt` will cancel the future `-1` inserted before by previous diff==1 row + let lower = Some(Ok((res_row, lower_bound, diff))); + + lower.into_iter().chain(upper_opt) + } else { + None.into_iter().chain(None) + } + } +} + +#[cfg(test)] +mod test { + use datatypes::data_type::ConcreteDataType; + use itertools::Itertools; + + use super::*; + use crate::expr::{BinaryFunc, UnaryFunc, UnmaterializableFunc}; + #[test] + fn test_mfp_with_time() { + use crate::expr::func::BinaryFunc; + let lte_now = ScalarExpr::Column(0).call_binary( + ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now), + BinaryFunc::Lte, + ); + assert!(lte_now.contains_temporal()); + + let gt_now_minus_two = ScalarExpr::Column(0) + .call_binary( + ScalarExpr::Literal(Value::from(2i64), ConcreteDataType::int64_datatype()), + BinaryFunc::AddInt64, + ) + .call_binary( + ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now), + BinaryFunc::Gt, + ); + assert!(gt_now_minus_two.contains_temporal()); + + let mfp = MapFilterProject::new(3) + .filter(vec![ + // col(0) <= now() + lte_now, + // col(0) + 2 > now() + gt_now_minus_two, + ]) + .unwrap() + .project(vec![0]) + .unwrap(); + + let mfp = MfpPlan::create_from(mfp).unwrap(); + let expected = vec![ + ( + 0, + vec![ + (Row::new(vec![Value::from(4i64)]), 4, 1), + (Row::new(vec![Value::from(4i64)]), 6, -1), + ], + ), + ( + 5, + vec![ + (Row::new(vec![Value::from(4i64)]), 5, 1), + (Row::new(vec![Value::from(4i64)]), 6, -1), + ], + ), + (10, vec![]), + ]; + for (sys_time, expected) in expected { + let mut values = vec![Value::from(4i64), Value::from(2i64), Value::from(3i64)]; + let ret = mfp + .evaluate::(&mut values, sys_time, 1) + .collect::, _>>() + .unwrap(); + assert_eq!(ret, expected); + } + } + + #[test] + fn test_mfp() { + use crate::expr::func::BinaryFunc; + let mfp = MapFilterProject::new(3) + .map(vec![ + ScalarExpr::Column(0).call_binary(ScalarExpr::Column(1), BinaryFunc::Lt), + ScalarExpr::Column(1).call_binary(ScalarExpr::Column(2), BinaryFunc::Lt), + ]) + .unwrap() + .project(vec![3, 4]) + .unwrap(); + assert!(!mfp.is_identity()); + let mfp = MapFilterProject::compose(mfp, MapFilterProject::new(2)).unwrap(); + { + let mfp_0 = mfp.as_map_filter_project(); + let same = MapFilterProject::new(3) + .map(mfp_0.0) + .unwrap() + .filter(mfp_0.1) + .unwrap() + .project(mfp_0.2) + .unwrap(); + assert_eq!(mfp, same); + } + assert_eq!(mfp.demand().len(), 3); + let mut mfp = mfp; + mfp.permute(BTreeMap::from([(0, 2), (2, 0), (1, 1)]), 3) + .unwrap(); + assert_eq!( + mfp, + MapFilterProject::new(3) + .map(vec![ + ScalarExpr::Column(2).call_binary(ScalarExpr::Column(1), BinaryFunc::Lt), + ScalarExpr::Column(1).call_binary(ScalarExpr::Column(0), BinaryFunc::Lt), + ]) + .unwrap() + .project(vec![3, 4]) + .unwrap() + ); + let safe_mfp = SafeMfpPlan { mfp }; + let mut values = vec![Value::from(4), Value::from(2), Value::from(3)]; + let ret = safe_mfp + .evaluate_into(&mut values, &mut Row::empty()) + .unwrap() + .unwrap(); + assert_eq!(ret, Row::pack(vec![Value::from(false), Value::from(true)])); + } + + #[test] + fn manipulation_mfp() { + // give a input of 4 columns + let mfp = MapFilterProject::new(4); + // append a expression to the mfp'input row that get the sum of the first 3 columns + let mfp = mfp + .map(vec![ScalarExpr::Column(0) + .call_binary(ScalarExpr::Column(1), BinaryFunc::AddInt32) + .call_binary(ScalarExpr::Column(2), BinaryFunc::AddInt32)]) + .unwrap(); + // only retain sum result + let mfp = mfp.project(vec![4]).unwrap(); + // accept only if if the sum is greater than 10 + let mfp = mfp + .filter(vec![ScalarExpr::Column(0).call_binary( + ScalarExpr::Literal(Value::from(10i32), ConcreteDataType::int32_datatype()), + BinaryFunc::Gt, + )]) + .unwrap(); + let mut input1 = vec![ + Value::from(4), + Value::from(2), + Value::from(3), + Value::from("abc"), + ]; + let safe_mfp = SafeMfpPlan { mfp }; + let ret = safe_mfp + .evaluate_into(&mut input1, &mut Row::empty()) + .unwrap(); + assert_eq!(ret, None); + let mut input2 = vec![ + Value::from(5), + Value::from(2), + Value::from(4), + Value::from("abc"), + ]; + let ret = safe_mfp + .evaluate_into(&mut input2, &mut Row::empty()) + .unwrap(); + assert_eq!(ret, Some(Row::pack(vec![Value::from(11)]))); + } + + #[test] + fn test_permute() { + let mfp = MapFilterProject::new(3) + .map(vec![ + ScalarExpr::Column(0).call_binary(ScalarExpr::Column(1), BinaryFunc::Lt) + ]) + .unwrap() + .filter(vec![ + ScalarExpr::Column(0).call_binary(ScalarExpr::Column(1), BinaryFunc::Gt) + ]) + .unwrap() + .project(vec![0, 1]) + .unwrap(); + assert_eq!(mfp.demand(), BTreeSet::from([0, 1])); + let mut less = mfp.clone(); + less.permute(BTreeMap::from([(1, 0), (0, 1)]), 2).unwrap(); + + let mut more = mfp.clone(); + more.permute(BTreeMap::from([(0, 1), (1, 2), (2, 0)]), 4) + .unwrap(); + } + + #[test] + fn mfp_test_cast_and_filter() { + let mfp = MapFilterProject::new(3) + .map(vec![ScalarExpr::Column(0).call_unary(UnaryFunc::Cast( + ConcreteDataType::int32_datatype(), + ))]) + .unwrap() + .filter(vec![ + ScalarExpr::Column(3).call_binary(ScalarExpr::Column(1), BinaryFunc::Gt) + ]) + .unwrap() + .project([0, 1, 2]) + .unwrap(); + let mut input1 = vec![ + Value::from(4i64), + Value::from(2), + Value::from(3), + Value::from(53), + ]; + let safe_mfp = SafeMfpPlan { mfp }; + let ret = safe_mfp.evaluate_into(&mut input1, &mut Row::empty()); + assert!(matches!(ret, Err(EvalError::InvalidArgument { .. }))); + + let input2 = vec![Value::from(4i64), Value::from(2), Value::from(3)]; + let ret = safe_mfp + .evaluate_into(&mut input2.clone(), &mut Row::empty()) + .unwrap(); + assert_eq!(ret, Some(Row::new(input2))); + + let mut input3 = vec![Value::from(4i64), Value::from(5), Value::from(2)]; + let ret = safe_mfp + .evaluate_into(&mut input3, &mut Row::empty()) + .unwrap(); + assert_eq!(ret, None); + } + + #[test] + fn test_mfp_out_of_order() { + let mfp = MapFilterProject::new(3) + .project(vec![2, 1, 0]) + .unwrap() + .filter(vec![ + ScalarExpr::Column(0).call_binary(ScalarExpr::Column(1), BinaryFunc::Gt) + ]) + .unwrap() + .map(vec![ + ScalarExpr::Column(0).call_binary(ScalarExpr::Column(1), BinaryFunc::Lt) + ]) + .unwrap() + .project(vec![3]) + .unwrap(); + let mut input1 = vec![Value::from(2), Value::from(3), Value::from(4)]; + let safe_mfp = SafeMfpPlan { mfp }; + let ret = safe_mfp.evaluate_into(&mut input1, &mut Row::empty()); + assert_eq!(ret.unwrap(), Some(Row::new(vec![Value::from(false)]))); + } + #[test] + fn test_mfp_chore() { + // project keeps permute columns until it becomes the identity permutation + let mfp = MapFilterProject::new(3) + .project([1, 2, 0]) + .unwrap() + .project([1, 2, 0]) + .unwrap() + .project([1, 2, 0]) + .unwrap(); + assert_eq!(mfp, MapFilterProject::new(3)); + } +} diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index fa03bb9f1912..1bffdebd71f2 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -17,6 +17,7 @@ use std::collections::{BTreeMap, BTreeSet}; use datatypes::prelude::ConcreteDataType; use datatypes::value::Value; use serde::{Deserialize, Serialize}; +use snafu::ensure; use crate::expr::error::{ EvalError, InvalidArgumentSnafu, OptimizeSnafu, UnsupportedTemporalFilterSnafu, @@ -82,7 +83,7 @@ impl ScalarExpr { match self { ScalarExpr::Column(index) => Ok(values[*index].clone()), ScalarExpr::Literal(row_res, _ty) => Ok(row_res.clone()), - ScalarExpr::CallUnmaterializable(f) => OptimizeSnafu { + ScalarExpr::CallUnmaterializable(_) => OptimizeSnafu { reason: "Can't eval unmaterializable function".to_string(), } .fail(), @@ -105,12 +106,27 @@ impl ScalarExpr { /// This method is applicable even when `permutation` is not a /// strict permutation, and it only needs to have entries for /// each column referenced in `self`. - pub fn permute(&mut self, permutation: &[usize]) { + pub fn permute(&mut self, permutation: &[usize]) -> Result<(), EvalError> { + // check first so that we don't end up with a partially permuted expression + ensure!( + self.get_all_ref_columns() + .into_iter() + .all(|i| i < permutation.len()), + InvalidArgumentSnafu { + reason: format!( + "permutation {:?} is not a valid permutation for expression {:?}", + permutation, self + ), + } + ); + self.visit_mut_post_nolimit(&mut |e| { if let ScalarExpr::Column(old_i) = e { *old_i = permutation[*old_i]; } - }); + Ok(()) + })?; + Ok(()) } /// Rewrites column indices with their value in `permutation`. @@ -118,12 +134,25 @@ impl ScalarExpr { /// This method is applicable even when `permutation` is not a /// strict permutation, and it only needs to have entries for /// each column referenced in `self`. - pub fn permute_map(&mut self, permutation: &BTreeMap) { + pub fn permute_map(&mut self, permutation: &BTreeMap) -> Result<(), EvalError> { + // check first so that we don't end up with a partially permuted expression + ensure!( + self.get_all_ref_columns() + .is_subset(&permutation.keys().cloned().collect()), + InvalidArgumentSnafu { + reason: format!( + "permutation {:?} is not a valid permutation for expression {:?}", + permutation, self + ), + } + ); + self.visit_mut_post_nolimit(&mut |e| { if let ScalarExpr::Column(old_i) = e { *old_i = permutation[old_i]; } - }); + Ok(()) + }) } /// Returns the set of columns that are referenced by `self`. @@ -133,7 +162,9 @@ impl ScalarExpr { if let ScalarExpr::Column(i) = e { support.insert(*i); } - }); + Ok(()) + }) + .unwrap(); support } @@ -180,70 +211,72 @@ impl ScalarExpr { impl ScalarExpr { /// visit post-order without stack call limit, but may cause stack overflow - fn visit_post_nolimit(&self, f: &mut F) + fn visit_post_nolimit(&self, f: &mut F) -> Result<(), EvalError> where - F: FnMut(&Self), + F: FnMut(&Self) -> Result<(), EvalError>, { - self.visit_children(|e| e.visit_post_nolimit(f)); - f(self); + self.visit_children(|e| e.visit_post_nolimit(f))?; + f(self) } - fn visit_children(&self, mut f: F) + fn visit_children(&self, mut f: F) -> Result<(), EvalError> where - F: FnMut(&Self), + F: FnMut(&Self) -> Result<(), EvalError>, { match self { ScalarExpr::Column(_) | ScalarExpr::Literal(_, _) - | ScalarExpr::CallUnmaterializable(_) => (), + | ScalarExpr::CallUnmaterializable(_) => Ok(()), ScalarExpr::CallUnary { expr, .. } => f(expr), ScalarExpr::CallBinary { expr1, expr2, .. } => { - f(expr1); - f(expr2); + f(expr1)?; + f(expr2) } ScalarExpr::CallVariadic { exprs, .. } => { for expr in exprs { - f(expr); + f(expr)?; } + Ok(()) } ScalarExpr::If { cond, then, els } => { - f(cond); - f(then); - f(els); + f(cond)?; + f(then)?; + f(els) } } } - fn visit_mut_post_nolimit(&mut self, f: &mut F) + fn visit_mut_post_nolimit(&mut self, f: &mut F) -> Result<(), EvalError> where - F: FnMut(&mut Self), + F: FnMut(&mut Self) -> Result<(), EvalError>, { - self.visit_mut_children(|e: &mut Self| e.visit_mut_post_nolimit(f)); - f(self); + self.visit_mut_children(|e: &mut Self| e.visit_mut_post_nolimit(f))?; + f(self) } - fn visit_mut_children(&mut self, mut f: F) + fn visit_mut_children(&mut self, mut f: F) -> Result<(), EvalError> where - F: FnMut(&mut Self), + F: FnMut(&mut Self) -> Result<(), EvalError>, { match self { ScalarExpr::Column(_) | ScalarExpr::Literal(_, _) - | ScalarExpr::CallUnmaterializable(_) => (), + | ScalarExpr::CallUnmaterializable(_) => Ok(()), ScalarExpr::CallUnary { expr, .. } => f(expr), ScalarExpr::CallBinary { expr1, expr2, .. } => { - f(expr1); - f(expr2); + f(expr1)?; + f(expr2) } ScalarExpr::CallVariadic { exprs, .. } => { for expr in exprs { - f(expr); + f(expr)?; } + Ok(()) } ScalarExpr::If { cond, then, els } => { - f(cond); - f(then); - f(els); + f(cond)?; + f(then)?; + f(els) } } } @@ -257,7 +290,9 @@ impl ScalarExpr { if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now) = e { contains = true; } - }); + Ok(()) + }) + .unwrap(); contains } @@ -317,6 +352,8 @@ impl ScalarExpr { #[cfg(test)] mod test { + use datatypes::arrow::array::Scalar; + use super::*; #[test] fn test_extract_bound() { @@ -390,9 +427,21 @@ mod test { // EvalError is not Eq, so we need to compare the error message match (actual, expected) { (Ok(l), Ok(r)) => assert_eq!(l, r), - (Err(l), Err(r)) => assert!(matches!(l, r)), (l, r) => panic!("expected: {:?}, actual: {:?}", r, l), } } } + + #[test] + fn test_bad_permute() { + let mut expr = ScalarExpr::Column(4); + let permutation = vec![1, 2, 3]; + let res = expr.permute(&permutation); + assert!(matches!(res, Err(EvalError::InvalidArgument { .. }))); + + let mut expr = ScalarExpr::Column(0); + let permute_map = BTreeMap::from([(1, 2), (3, 4)]); + let res = expr.permute_map(&permute_map); + assert!(matches!(res, Err(EvalError::InvalidArgument { .. }))); + } } diff --git a/src/flow/src/lib.rs b/src/flow/src/lib.rs index a60310504764..c144f8ab50be 100644 --- a/src/flow/src/lib.rs +++ b/src/flow/src/lib.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![allow(unused)] +#![allow(dead_code)] +#![allow(unused_imports)] // allow unused for now because it should be use later mod adapter; mod expr;