From b6585e358105a716e2853179ca52254aa4b24ae6 Mon Sep 17 00:00:00 2001 From: discord9 <55937128+discord9@users.noreply.github.com> Date: Fri, 28 Jun 2024 01:17:46 +0800 Subject: [PATCH] refactor(flow): make `from_substrait_*` async& worker handle refactor (#4210) * refactor: use oneshot to receive result * refactor: make from_substrait_* async * refactor: remove serde for plan&expr --- Cargo.lock | 1 + src/flow/Cargo.toml | 1 + src/flow/src/adapter/worker.rs | 152 ++++++++++-------------------- src/flow/src/expr/func.rs | 2 +- src/flow/src/expr/linear.rs | 4 +- src/flow/src/expr/relation.rs | 3 +- src/flow/src/expr/scalar.rs | 39 ++------ src/flow/src/plan.rs | 4 +- src/flow/src/plan/join.rs | 8 +- src/flow/src/plan/reduce.rs | 8 +- src/flow/src/transform.rs | 2 +- src/flow/src/transform/aggr.rs | 72 +++++++++----- src/flow/src/transform/expr.rs | 122 ++++++++++++++---------- src/flow/src/transform/literal.rs | 2 +- src/flow/src/transform/plan.rs | 40 ++++---- src/flow/src/utils.rs | 4 +- 16 files changed, 223 insertions(+), 241 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f740010071dc..d6fca8241d77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3949,6 +3949,7 @@ version = "0.8.2" dependencies = [ "api", "arrow-schema", + "async-recursion", "async-trait", "bytes", "catalog", diff --git a/src/flow/Cargo.toml b/src/flow/Cargo.toml index 285f8dbeec41..fcf33e45fe44 100644 --- a/src/flow/Cargo.toml +++ b/src/flow/Cargo.toml @@ -10,6 +10,7 @@ workspace = true [dependencies] api.workspace = true arrow-schema.workspace = true +async-recursion = "1.0" async-trait.workspace = true bytes.workspace = true catalog.workspace = true diff --git a/src/flow/src/adapter/worker.rs b/src/flow/src/adapter/worker.rs index 4d9ad2f52447..f69a396cda27 100644 --- a/src/flow/src/adapter/worker.rs +++ b/src/flow/src/adapter/worker.rs @@ -15,14 +15,14 @@ //! 
For single-thread flow worker use std::collections::{BTreeMap, VecDeque}; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use common_telemetry::info; use enum_as_inner::EnumAsInner; use hydroflow::scheduled::graph::Hydroflow; -use snafu::{ensure, OptionExt}; -use tokio::sync::{broadcast, mpsc, Mutex}; +use snafu::ensure; +use tokio::sync::{broadcast, mpsc, oneshot, Mutex}; use crate::adapter::error::{Error, FlowAlreadyExistSnafu, InternalSnafu, UnexpectedSnafu}; use crate::adapter::FlowId; @@ -39,7 +39,7 @@ type ReqId = usize; pub fn create_worker<'a>() -> (WorkerHandle, Worker<'a>) { let (itc_client, itc_server) = create_inter_thread_call(); let worker_handle = WorkerHandle { - itc_client: Mutex::new(itc_client), + itc_client, shutdown: AtomicBool::new(false), }; let worker = Worker { @@ -106,7 +106,7 @@ impl<'subgraph> ActiveDataflowState<'subgraph> { #[derive(Debug)] pub struct WorkerHandle { - itc_client: Mutex, + itc_client: InterThreadCallClient, shutdown: AtomicBool, } @@ -122,12 +122,7 @@ impl WorkerHandle { } ); - let ret = self - .itc_client - .lock() - .await - .call_with_resp(create_reqs) - .await?; + let ret = self.itc_client.call_with_resp(create_reqs).await?; ret.into_create().map_err(|ret| { InternalSnafu { reason: format!( @@ -141,7 +136,8 @@ impl WorkerHandle { /// remove task, return task id pub async fn remove_flow(&self, flow_id: FlowId) -> Result { let req = Request::Remove { flow_id }; - let ret = self.itc_client.lock().await.call_with_resp(req).await?; + + let ret = self.itc_client.call_with_resp(req).await?; ret.into_remove().map_err(|ret| { InternalSnafu { @@ -157,15 +153,12 @@ impl WorkerHandle { /// /// the returned error is unrecoverable, and the worker should be shutdown/rebooted pub async fn run_available(&self, now: repr::Timestamp) -> Result<(), Error> { - self.itc_client - .lock() - .await - .call_no_resp(Request::RunAvail { now }) + 
self.itc_client.call_no_resp(Request::RunAvail { now }) } pub async fn contains_flow(&self, flow_id: FlowId) -> Result { let req = Request::ContainTask { flow_id }; - let ret = self.itc_client.lock().await.call_with_resp(req).await?; + let ret = self.itc_client.call_with_resp(req).await?; ret.into_contain_task().map_err(|ret| { InternalSnafu { @@ -178,23 +171,9 @@ impl WorkerHandle { } /// shutdown the worker - pub async fn shutdown(&self) -> Result<(), Error> { + pub fn shutdown(&self) -> Result<(), Error> { if !self.shutdown.fetch_or(true, Ordering::SeqCst) { - self.itc_client.lock().await.call_no_resp(Request::Shutdown) - } else { - UnexpectedSnafu { - reason: "Worker already shutdown", - } - .fail() - } - } - - /// shutdown the worker - pub fn shutdown_blocking(&self) -> Result<(), Error> { - if !self.shutdown.fetch_or(true, Ordering::SeqCst) { - self.itc_client - .blocking_lock() - .call_no_resp(Request::Shutdown) + self.itc_client.call_no_resp(Request::Shutdown) } else { UnexpectedSnafu { reason: "Worker already shutdown", @@ -206,8 +185,7 @@ impl WorkerHandle { impl Drop for WorkerHandle { fn drop(&mut self) { - let ret = futures::executor::block_on(async { self.shutdown().await }); - if let Err(ret) = ret { + if let Err(ret) = self.shutdown() { common_telemetry::error!( ret; "While dropping Worker Handle, failed to shutdown worker, worker might be in inconsistent state." 
@@ -276,7 +254,7 @@ impl<'s> Worker<'s> { /// Run the worker, blocking, until shutdown signal is received pub fn run(&mut self) { loop { - let (req_id, req) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() { + let (req, ret_tx) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() { ret } else { common_telemetry::error!( @@ -285,19 +263,26 @@ impl<'s> Worker<'s> { break; }; - let ret = self.handle_req(req_id, req); - match ret { - Ok(Some((id, resp))) => { - if let Err(err) = self.itc_server.blocking_lock().resp(id, resp) { + let ret = self.handle_req(req); + match (ret, ret_tx) { + (Ok(Some(resp)), Some(ret_tx)) => { + if let Err(err) = ret_tx.send(resp) { common_telemetry::error!( err; - "Worker's itc server has been closed unexpectedly, shutting down worker" + "Result receiver is dropped, can't send result" ); - break; }; } - Ok(None) => continue, - Err(()) => { + (Ok(None), None) => continue, + (Ok(Some(resp)), None) => { + common_telemetry::error!( + "Expect no result for current request, but found {resp:?}" + ) + } + (Ok(None), Some(_)) => { + common_telemetry::error!("Expect result for current request, but found nothing") + } + (Err(()), _) => { break; } } @@ -315,7 +300,7 @@ impl<'s> Worker<'s> { /// handle request, return response if any, Err if receive shutdown signal /// /// return `Err(())` if receive shutdown request - fn handle_req(&mut self, req_id: ReqId, req: Request) -> Result, ()> { + fn handle_req(&mut self, req: Request) -> Result, ()> { let ret = match req { Request::Create { flow_id, @@ -339,16 +324,13 @@ impl<'s> Worker<'s> { create_if_not_exists, err_collector, ); - Some(( - req_id, - Response::Create { - result: task_create_result, - }, - )) + Some(Response::Create { + result: task_create_result, + }) } Request::Remove { flow_id } => { let ret = self.remove_flow(flow_id); - Some((req_id, Response::Remove { result: ret })) + Some(Response::Remove { result: ret }) } Request::RunAvail { now } => { 
self.run_tick(now); @@ -356,7 +338,7 @@ impl<'s> Worker<'s> { } Request::ContainTask { flow_id } => { let ret = self.task_states.contains_key(&flow_id); - Some((req_id, Response::ContainTask { result: ret })) + Some(Response::ContainTask { result: ret }) } Request::Shutdown => return Err(()), }; @@ -406,83 +388,50 @@ enum Response { fn create_inter_thread_call() -> (InterThreadCallClient, InterThreadCallServer) { let (arg_send, arg_recv) = mpsc::unbounded_channel(); - let (ret_send, ret_recv) = mpsc::unbounded_channel(); let client = InterThreadCallClient { - call_id: AtomicUsize::new(0), arg_sender: arg_send, - ret_recv, - }; - let server = InterThreadCallServer { - arg_recv, - ret_sender: ret_send, }; + let server = InterThreadCallServer { arg_recv }; (client, server) } #[derive(Debug)] struct InterThreadCallClient { - call_id: AtomicUsize, - arg_sender: mpsc::UnboundedSender<(ReqId, Request)>, - ret_recv: mpsc::UnboundedReceiver<(ReqId, Response)>, + arg_sender: mpsc::UnboundedSender<(Request, Option>)>, } impl InterThreadCallClient { - /// call without expecting responses or blocking fn call_no_resp(&self, req: Request) -> Result<(), Error> { - // TODO(discord9): relax memory order later - let call_id = self.call_id.fetch_add(1, Ordering::SeqCst); - self.arg_sender - .send((call_id, req)) - .map_err(from_send_error) + self.arg_sender.send((req, None)).map_err(from_send_error) } - /// call blocking, and return the result - async fn call_with_resp(&mut self, req: Request) -> Result { - // TODO(discord9): relax memory order later - let call_id = self.call_id.fetch_add(1, Ordering::SeqCst); + async fn call_with_resp(&self, req: Request) -> Result { + let (tx, rx) = oneshot::channel(); self.arg_sender - .send((call_id, req)) + .send((req, Some(tx))) .map_err(from_send_error)?; - - // TODO(discord9): better inter thread call impl, i.e. 
support multiple client(also consider if it's necessary) - // since one node manger might manage multiple worker, but one worker should only belong to one node manager - let (ret_call_id, ret) = self - .ret_recv - .recv() - .await - .context(InternalSnafu { reason: "InterThreadCallClient call_blocking failed, ret_recv has been closed and there are no remaining messages in the channel's buffer" })?; - - ensure!( - ret_call_id == call_id, + rx.await.map_err(|_| { InternalSnafu { - reason: "call id mismatch, worker/worker handler should be in sync", + reason: "Sender is dropped", } - ); - Ok(ret) + .build() + }) } } #[derive(Debug)] struct InterThreadCallServer { - pub arg_recv: mpsc::UnboundedReceiver<(ReqId, Request)>, - pub ret_sender: mpsc::UnboundedSender<(ReqId, Response)>, + pub arg_recv: mpsc::UnboundedReceiver<(Request, Option>)>, } impl InterThreadCallServer { - pub async fn recv(&mut self) -> Option<(usize, Request)> { + pub async fn recv(&mut self) -> Option<(Request, Option>)> { self.arg_recv.recv().await } - pub fn blocking_recv(&mut self) -> Option<(usize, Request)> { + pub fn blocking_recv(&mut self) -> Option<(Request, Option>)> { self.arg_recv.blocking_recv() } - - /// Send response back to the client - pub fn resp(&self, call_id: ReqId, resp: Response) -> Result<(), Error> { - self.ret_sender - .send((call_id, resp)) - .map_err(from_send_error) - } } fn from_send_error(err: mpsc::error::SendError) -> Error { @@ -546,7 +495,10 @@ mod test { create_if_not_exists: true, err_collector: ErrCollector::default(), }; - handle.create_flow(create_reqs).await.unwrap(); + assert_eq!( + handle.create_flow(create_reqs).await.unwrap(), + Some(flow_id) + ); tx.send((Row::empty(), 0, 0)).unwrap(); handle.run_available(0).await.unwrap(); assert_eq!(sink_rx.recv().await.unwrap().0, Row::empty()); diff --git a/src/flow/src/expr/func.rs b/src/flow/src/expr/func.rs index 2109356ad621..c30b67dbffa4 100644 --- a/src/flow/src/expr/func.rs +++ b/src/flow/src/expr/func.rs @@ 
-43,7 +43,7 @@ use crate::repr::{self, value_to_internal_ts, Row}; /// UnmaterializableFunc is a function that can't be eval independently, /// and require special handling -#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)] +#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)] pub enum UnmaterializableFunc { Now, CurrentSchema, diff --git a/src/flow/src/expr/linear.rs b/src/flow/src/expr/linear.rs index dcfed4eb0d28..b0e32c94d87b 100644 --- a/src/flow/src/expr/linear.rs +++ b/src/flow/src/expr/linear.rs @@ -49,7 +49,7 @@ use crate::repr::{self, value_to_internal_ts, Diff, Row}; /// expressions in `self.expressions`, even though this is not something /// we can directly evaluate. The plan creation methods will defensively /// ensure that the right thing happens. -#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct MapFilterProject { /// A sequence of expressions that should be appended to the row. /// @@ -462,7 +462,7 @@ impl MapFilterProject { } /// A wrapper type which indicates it is safe to simply evaluate all expressions. -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct SafeMfpPlan { /// the inner `MapFilterProject` that is safe to evaluate. pub(crate) mfp: MapFilterProject, diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs index a873c267b1a5..661f716dcd29 100644 --- a/src/flow/src/expr/relation.rs +++ b/src/flow/src/expr/relation.rs @@ -23,7 +23,7 @@ mod accum; mod func; /// Describes an aggregation expression. -#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct AggregateExpr { /// Names the aggregation function. 
pub func: AggregateFunc, @@ -32,6 +32,5 @@ pub struct AggregateExpr { /// so it only used in generate KeyValPlan from AggregateExpr pub expr: ScalarExpr, /// Should the aggregation be applied only to distinct results in each group. - #[serde(default)] pub distinct: bool, } diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index 7335511be0f0..591d2c246fc1 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -43,7 +43,7 @@ use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFun use crate::repr::{ColumnType, RelationDesc, RelationType}; use crate::transform::{from_scalar_fn_to_df_fn_impl, FunctionExtensions}; /// A scalar expression with a known type. -#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)] +#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)] pub struct TypedExpr { /// The expression. pub expr: ScalarExpr, @@ -129,7 +129,7 @@ impl TypedExpr { } /// A scalar expression, which can be evaluated to a value. 
-#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ScalarExpr { /// A column of the input row Column(usize), @@ -191,9 +191,9 @@ impl DfScalarFunction { }) } - pub fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result { + pub async fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result { Ok(Self { - fn_impl: raw_fn.get_fn_impl()?, + fn_impl: raw_fn.get_fn_impl().await?, df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?), raw_fn, }) @@ -264,27 +264,7 @@ impl DfScalarFunction { } } -// simply serialize the raw_fn instead of derive to avoid complex deserialize of struct -impl Serialize for DfScalarFunction { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.raw_fn.serialize(serializer) - } -} - -impl<'de> serde::de::Deserialize<'de> for DfScalarFunction { - fn deserialize(deserializer: D) -> Result - where - D: serde::de::Deserializer<'de>, - { - let raw_fn = RawDfScalarFn::deserialize(deserializer)?; - DfScalarFunction::try_from_raw_fn(raw_fn).map_err(serde::de::Error::custom) - } -} - -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct RawDfScalarFn { /// The raw bytes encoded datafusion scalar function pub(crate) f: bytes::BytesMut, @@ -311,7 +291,7 @@ impl RawDfScalarFn { extensions, }) } - fn get_fn_impl(&self) -> Result, Error> { + async fn get_fn_impl(&self) -> Result, Error> { let f = ScalarFunction::decode(&mut self.f.as_ref()) .context(DecodeRelSnafu) .map_err(BoxedError::new) @@ -320,7 +300,7 @@ impl RawDfScalarFn { let input_schema = &self.input_schema; let extensions = &self.extensions; - from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions) + from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions).await } } @@ -894,10 +874,7 @@ mod test { .unwrap(); let extensions = 
FunctionExtensions::from_iter(vec![(0, "abs")]); let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap(); - let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).unwrap(); - let as_str = serde_json::to_string(&df_func).unwrap(); - let from_str: DfScalarFunction = serde_json::from_str(&as_str).unwrap(); - assert_eq!(df_func, from_str); + let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await.unwrap(); assert_eq!( df_func .eval(&[Value::Null], &[ScalarExpr::Column(0)]) diff --git a/src/flow/src/plan.rs b/src/flow/src/plan.rs index 6e4b13673302..95816b17cb03 100644 --- a/src/flow/src/plan.rs +++ b/src/flow/src/plan.rs @@ -33,7 +33,7 @@ pub(crate) use crate::plan::reduce::{AccumulablePlan, AggrWithIndex, KeyValPlan, use crate::repr::{ColumnType, DiffRow, RelationDesc, RelationType}; /// A plan for a dataflow component. But with type to indicate the output type of the relation. -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub struct TypedPlan { /// output type of the relation pub schema: RelationDesc, @@ -121,7 +121,7 @@ impl TypedPlan { /// TODO(discord9): support `TableFunc`(by define FlatMap that map 1 to n) /// Plan describe how to transform data in dataflow -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub enum Plan { /// A constant collection of rows. 
Constant { rows: Vec }, diff --git a/src/flow/src/plan/join.rs b/src/flow/src/plan/join.rs index 13bb95f51159..4acf0db2342e 100644 --- a/src/flow/src/plan/join.rs +++ b/src/flow/src/plan/join.rs @@ -18,13 +18,13 @@ use crate::expr::ScalarExpr; use crate::plan::SafeMfpPlan; /// TODO(discord9): consider impl more join strategies -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub enum JoinPlan { Linear(LinearJoinPlan), } /// Determine if a given row should stay in the output. And apply a map filter project before output the row -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct JoinFilter { /// each element in the outer vector will check if each expr in itself can be eval to same value /// if not, the row will be filtered out. Useful for equi-join(join based on equality of some columns) @@ -37,7 +37,7 @@ pub struct JoinFilter { /// /// A linear join is a sequence of stages, each of which introduces /// a new collection. Each stage is represented by a [LinearStagePlan]. -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct LinearJoinPlan { /// The source relation from which we start the join. pub source_relation: usize, @@ -60,7 +60,7 @@ pub struct LinearJoinPlan { /// Each stage is a binary join between the current accumulated /// join results, and a new collection. The former is referred to /// as the "stream" and the latter the "lookup". -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct LinearStagePlan { /// The index of the relation into which we will look up. 
pub lookup_relation: usize, diff --git a/src/flow/src/plan/reduce.rs b/src/flow/src/plan/reduce.rs index 85c84a42f342..3d0d8b356a37 100644 --- a/src/flow/src/plan/reduce.rs +++ b/src/flow/src/plan/reduce.rs @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; use crate::expr::{AggregateExpr, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr}; /// Describe how to extract key-value pair from a `Row` -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub struct KeyValPlan { /// Extract key from row pub key_plan: SafeMfpPlan, @@ -27,7 +27,7 @@ pub struct KeyValPlan { /// TODO(discord9): def&impl of Hierarchical aggregates(for min/max with support to deletion) and /// basic aggregates(for other aggregate functions) and mixed aggregate -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub enum ReducePlan { /// Plan for not computing any aggregations, just determining the set of /// distinct keys. @@ -38,7 +38,7 @@ pub enum ReducePlan { } /// Accumulable plan for the execution of a reduction. -#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct AccumulablePlan { /// All of the aggregations we were asked to compute, stored /// in order. 
@@ -57,7 +57,7 @@ pub struct AccumulablePlan { /// Invariant: the output index is the index of the aggregation in `full_aggrs` /// which means output index is always smaller than the length of `full_aggrs` -#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct AggrWithIndex { /// aggregation expression pub expr: AggregateExpr, diff --git a/src/flow/src/transform.rs b/src/flow/src/transform.rs index 35d811a03732..e86dac85fb7f 100644 --- a/src/flow/src/transform.rs +++ b/src/flow/src/transform.rs @@ -140,7 +140,7 @@ pub async fn sql_to_flow_plan( .map_err(BoxedError::new) .context(ExternalSnafu)?; - let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan)?; + let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?; Ok(flow_plan) } diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index 3cc1512692d5..6456f00a5c75 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -58,7 +58,7 @@ use crate::repr::{self, ColumnType, RelationDesc, RelationType}; use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions}; impl TypedExpr { - fn from_substrait_agg_grouping( + async fn from_substrait_agg_grouping( ctx: &mut FlownodeContext, groupings: &[Grouping], typ: &RelationDesc, @@ -69,7 +69,7 @@ impl TypedExpr { match groupings.len() { 1 => { for e in &groupings[0].grouping_expressions { - let x = TypedExpr::from_substrait_rex(e, typ, extensions)?; + let x = TypedExpr::from_substrait_rex(e, typ, extensions).await?; group_expr.push(x); } } @@ -87,7 +87,7 @@ impl AggregateExpr { /// Convert list of `Measure` into Flow's AggregateExpr /// /// Return both the AggregateExpr and a MapFilterProject that is the final output of the aggregate function - fn from_substrait_agg_measures( + async fn from_substrait_agg_measures( ctx: &mut FlownodeContext, measures: &[Measure], typ: &RelationDesc, @@ -98,11 +98,15 
@@ impl AggregateExpr { let mut post_maps = vec![]; for m in measures { - let filter = &m + let filter = match m .filter .as_ref() .map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions)) - .transpose()?; + { + Some(fut) => Some(fut.await), + None => None, + } + .transpose()?; let (aggr_expr, post_mfp) = match &m.measure { Some(f) => { @@ -112,9 +116,10 @@ impl AggregateExpr { _ => false, }; AggregateExpr::from_substrait_agg_func( - f, typ, extensions, filter, // TODO(discord9): impl order_by + f, typ, extensions, &filter, // TODO(discord9): impl order_by &None, distinct, ) + .await } None => not_impl_err!("Aggregate without aggregate function is not supported"), }?; @@ -142,7 +147,7 @@ impl AggregateExpr { /// /// the returned value is a tuple of AggregateExpr and a optional ScalarExpr that if exist is the final output of the aggregate function /// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)` - pub fn from_substrait_agg_func( + pub async fn from_substrait_agg_func( f: &proto::AggregateFunction, input_schema: &RelationDesc, extensions: &FunctionExtensions, @@ -157,7 +162,7 @@ impl AggregateExpr { for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - TypedExpr::from_substrait_rex(e, input_schema, extensions) + TypedExpr::from_substrait_rex(e, input_schema, extensions).await } _ => not_impl_err!("Aggregated function argument non-Value type not supported"), }?; @@ -306,13 +311,14 @@ impl TypedPlan { /// The output of aggr plan is: /// /// .. - pub fn from_substrait_agg_rel( + #[async_recursion::async_recursion] + pub async fn from_substrait_agg_rel( ctx: &mut FlownodeContext, agg: &proto::AggregateRel, extensions: &FunctionExtensions, ) -> Result { let input = if let Some(input) = agg.input.as_ref() { - TypedPlan::from_substrait_rel(ctx, input, extensions)? + TypedPlan::from_substrait_rel(ctx, input, extensions).await? 
} else { return not_impl_err!("Aggregate without an input is not supported"); }; @@ -323,7 +329,8 @@ impl TypedPlan { &agg.groupings, &input.schema, extensions, - )?; + ) + .await?; TypedExpr::expand_multi_value(&input.schema.typ, &group_exprs)? }; @@ -335,7 +342,8 @@ impl TypedPlan { &agg.measures, &input.schema, extensions, - )?; + ) + .await?; let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan( &mut aggr_exprs, @@ -479,7 +487,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - assert!(TypedPlan::from_substrait_plan(&mut ctx, &plan).is_err()); + assert!(TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .is_err()); } #[tokio::test] @@ -489,7 +499,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, @@ -578,6 +590,7 @@ mod test { }, }, ) + .await .unwrap(), exprs: vec![ScalarExpr::Column(0)], }]) @@ -630,7 +643,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, @@ -743,6 +758,7 @@ mod test { ]), }, }) + .await .unwrap(), exprs: vec![ScalarExpr::Column(3)], }, @@ -766,7 +782,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_exprs = vec![ AggregateExpr { @@ -913,7 +931,9 @@ mod test { let plan = 
sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, @@ -1029,7 +1049,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, @@ -1145,7 +1167,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let aggr_exprs = vec![ AggregateExpr { @@ -1250,7 +1272,9 @@ mod test { let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_exprs = vec![ AggregateExpr { @@ -1341,7 +1365,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let typ = RelationType::new(vec![ColumnType::new( ConcreteDataType::uint64_datatype(), true, @@ -1404,7 +1428,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, @@ -1482,7 +1508,7 @@ mod test { let plan = 
sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, diff --git a/src/flow/src/transform/expr.rs b/src/flow/src/transform/expr.rs index a6b312504d28..a10e9b121f8c 100644 --- a/src/flow/src/transform/expr.rs +++ b/src/flow/src/transform/expr.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use datafusion_physical_expr::PhysicalExpr; use datatypes::data_type::ConcreteDataType as CDT; -use itertools::Itertools; use snafu::{OptionExt, ResultExt}; use substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference; use substrait_proto::proto::expression::reference_segment::ReferenceType::StructField; @@ -60,7 +59,7 @@ fn typename_to_cdt(name: &str) -> CDT { } /// Convert [`ScalarFunction`] to corresponding Datafusion's [`PhysicalExpr`] -pub(crate) fn from_scalar_fn_to_df_fn_impl( +pub(crate) async fn from_scalar_fn_to_df_fn_impl( f: &ScalarFunction, input_schema: &RelationDesc, extensions: &FunctionExtensions, @@ -70,7 +69,7 @@ pub(crate) fn from_scalar_fn_to_df_fn_impl( }; let schema = input_schema.to_df_schema()?; - let df_expr = futures::executor::block_on(async { + let df_expr = // TODO(discord9): consider coloring everything async.... 
substrait::df_logical_plan::consumer::from_substrait_rex( &datafusion::prelude::SessionContext::new(), @@ -79,7 +78,7 @@ pub(crate) fn from_scalar_fn_to_df_fn_impl( &extensions.inner_ref(), ) .await - }); + ; let expr = df_expr.map_err(|err| { DatafusionSnafu { raw: err, @@ -138,7 +137,7 @@ fn rewrite_scalar_function(f: &ScalarFunction) -> ScalarFunction { } impl TypedExpr { - pub fn from_substrait_to_datafusion_scalar_func( + pub async fn from_substrait_to_datafusion_scalar_func( f: &ScalarFunction, arg_exprs_typed: Vec, extensions: &FunctionExtensions, @@ -152,7 +151,7 @@ impl TypedExpr { let raw_fn = RawDfScalarFn::from_proto(&f_rewrite, input_schema.clone(), extensions.clone())?; - let df_func = DfScalarFunction::try_from_raw_fn(raw_fn)?; + let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await?; let expr = ScalarExpr::CallDf { df_scalar_fn: df_func, exprs: arg_exprs, @@ -163,7 +162,7 @@ impl TypedExpr { } /// Convert ScalarFunction into Flow's ScalarExpr - pub fn from_substrait_scalar_func( + pub async fn from_substrait_scalar_func( f: &ScalarFunction, input_schema: &RelationDesc, extensions: &FunctionExtensions, @@ -178,16 +177,19 @@ impl TypedExpr { ), })?; let arg_len = f.arguments.len(); - let arg_typed_exprs: Vec = f - .arguments - .iter() - .map(|arg| match &arg.arg_type { - Some(ArgType::Value(e)) => { - TypedExpr::from_substrait_rex(e, input_schema, extensions) - } - _ => not_impl_err!("Aggregated function argument non-Value type not supported"), - }) - .try_collect()?; + let arg_typed_exprs: Vec = { + let mut rets = Vec::new(); + for arg in f.arguments.iter() { + let ret = match &arg.arg_type { + Some(ArgType::Value(e)) => { + TypedExpr::from_substrait_rex(e, input_schema, extensions).await + } + _ => not_impl_err!("Aggregated function argument non-Value type not supported"), + }?; + rets.push(ret); + } + rets + }; // literal's type is determined by the function and type of other args let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = 
arg_typed_exprs @@ -293,7 +295,8 @@ impl TypedExpr { f, arg_typed_exprs, extensions, - )?; + ) + .await?; Ok(try_as_df) } } @@ -301,38 +304,44 @@ impl TypedExpr { } /// Convert IfThen into Flow's ScalarExpr - pub fn from_substrait_ifthen_rex( + pub async fn from_substrait_ifthen_rex( if_then: &IfThen, input_schema: &RelationDesc, extensions: &FunctionExtensions, ) -> Result { - let ifs: Vec<_> = if_then - .ifs - .iter() - .map(|if_clause| { + let ifs: Vec<_> = { + let mut ifs = Vec::new(); + for if_clause in if_then.ifs.iter() { let proto_if = if_clause.r#if.as_ref().with_context(|| InvalidQuerySnafu { reason: "IfThen clause without if", })?; let proto_then = if_clause.then.as_ref().with_context(|| InvalidQuerySnafu { reason: "IfThen clause without then", })?; - let cond = TypedExpr::from_substrait_rex(proto_if, input_schema, extensions)?; - let then = TypedExpr::from_substrait_rex(proto_then, input_schema, extensions)?; - Ok((cond, then)) - }) - .try_collect()?; + let cond = + TypedExpr::from_substrait_rex(proto_if, input_schema, extensions).await?; + let then = + TypedExpr::from_substrait_rex(proto_then, input_schema, extensions).await?; + ifs.push((cond, then)); + } + ifs + }; // if no else is presented - let els = if_then + let els = match if_then .r#else .as_ref() .map(|e| TypedExpr::from_substrait_rex(e, input_schema, extensions)) - .transpose()? - .unwrap_or_else(|| { - TypedExpr::new( - ScalarExpr::literal_null(), - ColumnType::new_nullable(CDT::null_datatype()), - ) - }); + { + Some(fut) => Some(fut.await), + None => None, + } + .transpose()? 
+ .unwrap_or_else(|| { + TypedExpr::new( + ScalarExpr::literal_null(), + ColumnType::new_nullable(CDT::null_datatype()), + ) + }); fn build_if_then_recur( mut next_if_then: impl Iterator, @@ -356,7 +365,8 @@ impl TypedExpr { Ok(expr_if) } /// Convert Substrait Rex into Flow's ScalarExpr - pub fn from_substrait_rex( + #[async_recursion::async_recursion] + pub async fn from_substrait_rex( e: &Expression, input_schema: &RelationDesc, extensions: &FunctionExtensions, @@ -377,7 +387,7 @@ impl TypedExpr { if !s.options.is_empty() { return not_impl_err!("In list expression is not supported"); } - TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions) + TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions).await } Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { @@ -400,16 +410,16 @@ impl TypedExpr { _ => not_impl_err!("unsupported field ref type"), }, Some(RexType::ScalarFunction(f)) => { - TypedExpr::from_substrait_scalar_func(f, input_schema, extensions) + TypedExpr::from_substrait_scalar_func(f, input_schema, extensions).await } Some(RexType::IfThen(if_then)) => { - TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions) + TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions).await } Some(RexType::Cast(cast)) => { let input = cast.input.as_ref().with_context(|| InvalidQuerySnafu { reason: "Cast expression without input", })?; - let input = TypedExpr::from_substrait_rex(input, input_schema, extensions)?; + let input = TypedExpr::from_substrait_rex(input, input_schema, extensions).await?; let cast_type = from_substrait_type(cast.r#type.as_ref().with_context(|| { InvalidQuerySnafu { reason: "Cast expression without type", @@ -453,7 +463,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + 
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; // optimize binary and to variadic and let filter = ScalarExpr::CallVariadic { @@ -509,7 +519,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)]) @@ -534,7 +544,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]) @@ -572,7 +582,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)]) @@ -611,7 +621,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]) @@ -641,8 +651,8 @@ mod test { assert_eq!(flow_plan.unwrap(), expected); } - #[test] - fn test_func_sig() { + #[tokio::test] + async fn test_func_sig() { fn lit(v: impl ToString) -> substrait_proto::proto::FunctionArgument { use substrait_proto::proto::expression; let expr = Expression { @@ -669,7 +679,9 @@ mod test { let input_schema = 
RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]).into_unnamed(); let extensions = FunctionExtensions::from_iter([(0, "is_null".to_string())]); - let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); + let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions) + .await + .unwrap(); assert_eq!( res, @@ -695,7 +707,9 @@ mod test { ]) .into_unnamed(); let extensions = FunctionExtensions::from_iter([(0, "add".to_string())]); - let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); + let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions) + .await + .unwrap(); assert_eq!( res, @@ -722,7 +736,9 @@ mod test { ]) .into_unnamed(); let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]); - let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); + let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions) + .await + .unwrap(); assert_eq!( res, @@ -750,7 +766,9 @@ mod test { ]) .into_unnamed(); let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]); - let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); + let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions) + .await + .unwrap(); assert_eq!( res, diff --git a/src/flow/src/transform/literal.rs b/src/flow/src/transform/literal.rs index 9dc93d17c549..1fa5bc86a81c 100644 --- a/src/flow/src/transform/literal.rs +++ b/src/flow/src/transform/literal.rs @@ -172,7 +172,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)]) diff --git 
a/src/flow/src/transform/plan.rs b/src/flow/src/transform/plan.rs index a9d9e29310e9..f1f6ba53dd35 100644 --- a/src/flow/src/transform/plan.rs +++ b/src/flow/src/transform/plan.rs @@ -32,7 +32,7 @@ use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions}; impl TypedPlan { /// Convert Substrait Plan into Flow's TypedPlan - pub fn from_substrait_plan( + pub async fn from_substrait_plan( ctx: &mut FlownodeContext, plan: &SubPlan, ) -> Result { @@ -45,13 +45,13 @@ impl TypedPlan { match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(TypedPlan::from_substrait_rel(ctx, rel, &function_extension)?) + Ok(TypedPlan::from_substrait_rel(ctx, rel, &function_extension).await?) }, plan_rel::RelType::Root(root) => { let input = root.input.as_ref().with_context(|| InvalidQuerySnafu { reason: "Root relation without input", })?; - Ok(TypedPlan::from_substrait_rel(ctx, input, &function_extension)?) + Ok(TypedPlan::from_substrait_rel(ctx, input, &function_extension).await?) } }, None => plan_err!("Cannot parse plan relation: None") @@ -64,13 +64,14 @@ impl TypedPlan { } } - pub fn from_substrait_project( + #[async_recursion::async_recursion] + pub async fn from_substrait_project( ctx: &mut FlownodeContext, p: &ProjectRel, extensions: &FunctionExtensions, ) -> Result { let input = if let Some(input) = p.input.as_ref() { - TypedPlan::from_substrait_rel(ctx, input, extensions)? + TypedPlan::from_substrait_rel(ctx, input, extensions).await? 
} else { return not_impl_err!("Projection without an input is not supported"); }; @@ -93,7 +94,7 @@ impl TypedPlan { let mut exprs: Vec = Vec::with_capacity(p.expressions.len()); for e in &p.expressions { - let expr = TypedExpr::from_substrait_rex(e, &schema_before_expand, extensions)?; + let expr = TypedExpr::from_substrait_rex(e, &schema_before_expand, extensions).await?; exprs.push(expr); } let is_literal = exprs.iter().all(|expr| expr.expr.is_literal()); @@ -131,26 +132,27 @@ impl TypedPlan { } } - pub fn from_substrait_filter( + #[async_recursion::async_recursion] + pub async fn from_substrait_filter( ctx: &mut FlownodeContext, filter: &FilterRel, extensions: &FunctionExtensions, ) -> Result { let input = if let Some(input) = filter.input.as_ref() { - TypedPlan::from_substrait_rel(ctx, input, extensions)? + TypedPlan::from_substrait_rel(ctx, input, extensions).await? } else { return not_impl_err!("Filter without an input is not supported"); }; let expr = if let Some(condition) = filter.condition.as_ref() { - TypedExpr::from_substrait_rex(condition, &input.schema, extensions)? + TypedExpr::from_substrait_rex(condition, &input.schema, extensions).await? } else { return not_impl_err!("Filter without an condition is not valid"); }; input.filter(expr) } - pub fn from_substrait_read( + pub async fn from_substrait_read( ctx: &mut FlownodeContext, read: &ReadRel, _extensions: &FunctionExtensions, @@ -212,16 +214,22 @@ impl TypedPlan { /// Convert Substrait Rel into Flow's TypedPlan /// TODO(discord9): SELECT DISTINCT(does it get compile with something else?) 
- pub fn from_substrait_rel( + pub async fn from_substrait_rel( ctx: &mut FlownodeContext, rel: &Rel, extensions: &FunctionExtensions, ) -> Result { match &rel.rel_type { - Some(RelType::Project(p)) => Self::from_substrait_project(ctx, p.as_ref(), extensions), - Some(RelType::Filter(filter)) => Self::from_substrait_filter(ctx, filter, extensions), - Some(RelType::Read(read)) => Self::from_substrait_read(ctx, read, extensions), - Some(RelType::Aggregate(agg)) => Self::from_substrait_agg_rel(ctx, agg, extensions), + Some(RelType::Project(p)) => { + Self::from_substrait_project(ctx, p.as_ref(), extensions).await + } + Some(RelType::Filter(filter)) => { + Self::from_substrait_filter(ctx, filter, extensions).await + } + Some(RelType::Read(read)) => Self::from_substrait_read(ctx, read, extensions).await, + Some(RelType::Aggregate(agg)) => { + Self::from_substrait_agg_rel(ctx, agg, extensions).await + } _ => not_impl_err!("Unsupported relation type: {:?}", rel.rel_type), } } @@ -353,7 +361,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]) diff --git a/src/flow/src/utils.rs b/src/flow/src/utils.rs index 69c300ab8f5c..69ff8fa2d248 100644 --- a/src/flow/src/utils.rs +++ b/src/flow/src/utils.rs @@ -40,7 +40,7 @@ pub type Spine = BTreeMap; /// If a key is expired, any future updates to it should be ignored. /// /// Note that key is expired by it's event timestamp (contained in the key), not by the time it's inserted (system timestamp). -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub struct KeyExpiryManager { /// A map from event timestamp to key, used for expire keys. 
event_ts_to_key: BTreeMap>, @@ -157,7 +157,7 @@ impl KeyExpiryManager { /// /// Note the two way arrow between reduce operator and arrange, it's because reduce operator need to query existing state /// and also need to update existing state. -#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd)] pub struct Arrangement { /// A name or identifier for the arrangement which can be used for debugging or logging purposes. /// This field is not critical to the functionality but aids in monitoring and management of arrangements.