From ad387ceca84b9f1d2d117fd48d33068aa6b49777 Mon Sep 17 00:00:00 2001 From: Alan Cai Date: Fri, 7 Jul 2023 17:33:18 -0700 Subject: [PATCH] WIP add subquery lowering --- partiql-logical-planner/src/lower.rs | 113 ++++++++++++++++++--------- 1 file changed, 78 insertions(+), 35 deletions(-) diff --git a/partiql-logical-planner/src/lower.rs b/partiql-logical-planner/src/lower.rs index c041f8ed..5204b92d 100644 --- a/partiql-logical-planner/src/lower.rs +++ b/partiql-logical-planner/src/lower.rs @@ -17,7 +17,7 @@ use partiql_logical as logical; use partiql_logical::{ AggregateExpression, BagExpr, BagOp, BetweenExpr, BindingsOp, IsTypeExpr, LikeMatch, LikeNonStringNonLiteralMatch, ListExpr, LogicalPlan, OpId, PathComponent, Pattern, - PatternMatchExpr, SortSpecOrder, TupleExpr, ValueExpr, + PatternMatchExpr, SortSpecOrder, SubQueryExpr, TupleExpr, ValueExpr, }; use partiql_value::{BindingsName, Value}; @@ -167,7 +167,7 @@ pub struct AstToLogical<'a> { agg_id: IdGenerator, // output - plan: LogicalPlan, + plan_stack: Vec>, // catalog & data flow data key_registry: name_resolver::KeyRegistry, @@ -235,7 +235,7 @@ impl<'a> AstToLogical<'a> { agg_id: Default::default(), // output - plan: Default::default(), + plan_stack: Default::default(), key_registry: registry, fnsym_tab, @@ -255,7 +255,8 @@ impl<'a> AstToLogical<'a> { errors: self.errors, }); } - Ok(self.plan) + assert_eq!(self.plan_stack.len(), 1); + Ok(self.plan_stack.pop().expect("plan")) } #[inline] @@ -489,6 +490,34 @@ impl<'a> AstToLogical<'a> { fn push_sort_spec(&mut self, spec: logical::SortSpec) { self.sort_stack.last_mut().unwrap().push(spec); } + + #[inline] + fn plan_add_operator(&mut self, op: BindingsOp) -> OpId { + self.plan_stack.last_mut().unwrap().add_operator(op) + } + + #[inline] + fn plan_add_flow(&mut self, src: OpId, dst: OpId) { + self.plan_stack.last_mut().unwrap().add_flow(src, dst) + } + + #[inline] + fn plan_add_flow_with_branch_num(&mut self, src: OpId, dst: OpId, branch_num: u8) { + self.plan_stack + .last_mut() + .unwrap() + .add_flow_with_branch_num(src, dst, branch_num) + } + + #[inline] + fn plan_operator(&mut self, id: OpId) -> Option<&BindingsOp> { + self.plan_stack.last_mut().unwrap().operator(id) + } + + #[inline] + fn plan_operator_as_mut(&mut self, id: OpId) -> Option<&mut BindingsOp> { + self.plan_stack.last_mut().unwrap().operator_as_mut(id) + } } // SQL (and therefore PartiQL) text (and therefore AST) is not lexically-scoped as is the @@ -582,14 +611,15 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { fn enter_top_level_query(&mut self, _query: &'ast ast::TopLevelQuery) -> Traverse { self.enter_benv(); + self.plan_stack.push(LogicalPlan::default()); Traverse::Continue } fn exit_top_level_query(&mut self, _query: &'ast ast::TopLevelQuery) -> Traverse { let mut benv = self.exit_benv(); eq_or_fault!(self, benv.len(), 1, "Expect benv.len() == 1"); let out = benv.pop().unwrap(); - let sink_id = self.plan.add_operator(BindingsOp::Sink); - self.plan.add_flow(out, sink_id); + let sink_id = self.plan_add_operator(BindingsOp::Sink); + self.plan_add_flow(out, sink_id); Traverse::Continue } @@ -597,6 +627,10 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { self.enter_benv(); if let QuerySet::Select(_) = query.set.node { self.enter_q(); + if self.q_stack.len() > 1 { + // entering subquery + self.plan_stack.push(LogicalPlan::default()); + } } Traverse::Continue } @@ -609,10 +643,18 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { let mut clauses = clauses.evaluation_order().into_iter(); if let Some(mut src_id) = clauses.next() { for dst_id in clauses { - self.plan.add_flow(src_id, dst_id); + self.plan_add_flow(src_id, dst_id); src_id = dst_id; } - self.push_bexpr(src_id); + if !self.q_stack.is_empty() { + // this is a subquery + let plan = self.plan_stack.pop().unwrap(); + assert!(!self.vexpr_stack.is_empty()); + self.push_vexpr(logical::ValueExpr::SubQueryExpr(SubQueryExpr { plan })); + // don't need bexpr for sink; will use outer query's bexpr + } else { + self.push_bexpr(src_id); + } } } _ => { @@ -623,7 +665,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { ); let mut out = *benv.first().unwrap(); benv.into_iter().skip(1).for_each(|op| { - self.plan.add_flow(out, op); + self.plan_add_flow(out, op); out = op; }); self.push_bexpr(out); @@ -673,12 +715,12 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { SetQuantifier::Distinct => logical::SetQuantifier::Distinct, }; - let id = self.plan.add_operator(BindingsOp::BagOp(BagOp { + let id = self.plan_add_operator(BindingsOp::BagOp(BagOp { bag_op: bag_operator, setq, })); - self.plan.add_flow_with_branch_num(lid, id, 0); - self.plan.add_flow_with_branch_num(rid, id, 1); + self.plan_add_flow_with_branch_num(lid, id, 0); + self.plan_add_flow_with_branch_num(rid, id, 1); self.push_bexpr(id); } QuerySet::Select(_) => {} @@ -686,7 +728,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { eq_or_fault!(self, env.len(), 1, "env.len() != 1"); let expr = env.into_iter().next().unwrap(); let op = BindingsOp::ExprQuery(logical::ExprQuery { expr }); - let id = self.plan.add_operator(op); + let id = self.plan_add_operator(op); self.push_bexpr(id); } QuerySet::Values(_) => { @@ -729,7 +771,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { eq_or_fault!(self, env.len(), 0, "env.len() != 0"); if let Some(SetQuantifier::Distinct) = _projection.setq { - let id = self.plan.add_operator(BindingsOp::Distinct); + let id = self.plan_add_operator(BindingsOp::Distinct); self.current_clauses_mut().distinct.replace(id); } Traverse::Continue @@ -795,7 +837,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { logical::BindingsOp::ProjectValue(logical::ProjectValue { expr }) } }; - let id = self.plan.add_operator(select); + let id = self.plan_add_operator(select); self.current_clauses_mut().select_clause.replace(id); Traverse::Continue } @@ -1181,7 +1223,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { aggregate_exprs: self.aggregate_exprs.clone(), group_as_alias: None, }); - let id = self.plan.add_operator(group_by); + let id = self.plan_add_operator(group_by); self.current_clauses_mut().group_by_clause.replace(id); } Traverse::Continue @@ -1339,7 +1381,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { at_key, }), }; - let id = self.plan.add_operator(bexpr); + let id = self.plan_add_operator(bexpr); self.push_bexpr(id); Traverse::Continue } @@ -1375,17 +1417,17 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { let rid = benv.pop().unwrap(); let lid = benv.pop().unwrap(); - let left = Box::new(self.plan.operator(lid).unwrap().clone()); - let right = Box::new(self.plan.operator(rid).unwrap().clone()); + let left = Box::new(self.plan_operator(lid).unwrap().clone()); + let right = Box::new(self.plan_operator(rid).unwrap().clone()); let join = logical::BindingsOp::Join(logical::Join { kind, on, left, right, }); - let join = self.plan.add_operator(join); - self.plan.add_flow_with_branch_num(lid, join, 0); - self.plan.add_flow_with_branch_num(rid, join, 1); + let join = self.plan_add_operator(join); + self.plan_add_flow_with_branch_num(lid, join, 0); + self.plan_add_flow_with_branch_num(rid, join, 1); self.push_bexpr(join); Traverse::Continue } @@ -1417,7 +1459,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { let filter = logical::BindingsOp::Filter(logical::Filter { expr: env.pop().unwrap(), }); - let id = self.plan.add_operator(filter); + let id = self.plan_add_operator(filter); self.current_clauses_mut().where_clause.replace(id); Traverse::Continue @@ -1435,7 +1477,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { let having = BindingsOp::Having(logical::Having { expr: env.pop().unwrap(), }); - let id = self.plan.add_operator(having); + let id = self.plan_add_operator(having); self.current_clauses_mut().having_clause.replace(id); Traverse::Continue @@ -1485,8 +1527,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { return Traverse::Stop; } let select_clause = self - .plan - .operator_as_mut(select_clause_op_id.expect("select_clause_op_id not None")) + .plan_operator_as_mut(select_clause_op_id.expect("select_clause_op_id not None")) .unwrap(); let mut binding = HashMap::new(); let select_clause_exprs = match select_clause { @@ -1512,16 +1553,18 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { Value::String(s) => (*s).clone(), _ => { // Report error but allow visitor to continue - self.errors.push(AstTransformError::IllegalState( - "Unexpected literal type".to_string(), - )); + // self.errors.push(AstTransformError::IllegalState( + // "Unexpected literal type".to_string(), + // )); + //todo!("Add back error"); "".to_string() } }, _ => { - self.errors.push(AstTransformError::IllegalState( - "Unexpected alias type".to_string(), - )); + //todo!("Add back error"); + // self.errors.push(AstTransformError::IllegalState( + // "Unexpected alias type".to_string(), + // )); return Traverse::Stop; } }; @@ -1546,7 +1589,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { group_as_alias, }); - let id = self.plan.add_operator(group_by); + let id = self.plan_add_operator(group_by); self.current_clauses_mut().group_by_clause.replace(id); Traverse::Continue } @@ -1574,7 +1617,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { fn exit_order_by_expr(&mut self, _order_by_expr: &'ast OrderByExpr) -> Traverse { let specs = self.exit_sort(); let order_by = logical::BindingsOp::OrderBy(logical::OrderBy { specs }); - let id = self.plan.add_operator(order_by); + let id = self.plan_add_operator(order_by); if matches!(self.current_ctx(), Some(QueryContext::Query)) { self.current_clauses_mut().order_by_clause.replace(id); } else { @@ -1647,7 +1690,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { }; let limit_offset = logical::BindingsOp::LimitOffset(logical::LimitOffset { limit, offset }); - let id = self.plan.add_operator(limit_offset); + let id = self.plan_add_operator(limit_offset); if matches!(self.current_ctx(), Some(QueryContext::Query)) { self.current_clauses_mut().limit_offset_clause.replace(id); } else {