From 0773ea04d419f88c0ed852f042e2a12cf84f8419 Mon Sep 17 00:00:00 2001 From: Alan Cai Date: Wed, 12 Jul 2023 11:16:35 -0700 Subject: [PATCH] Implement subquery lowering --- CHANGELOG.md | 1 + partiql-eval/src/eval/evaluable.rs | 4 +- partiql-logical-planner/src/lower.rs | 154 ++++++++++++++++----------- 3 files changed, 95 insertions(+), 64 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1aaecd18..ced80228 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `partiql_ast_passes::static_typer` for type annotating the AST. - Add ability to parse `ORDER BY`, `LIMIT`, `OFFSET` in children of set operators - Add `OUTER` bag operator (`OUTER UNION`, `OUTER INTERSECT`, `OUTER EXCEPT`) implementation +- Add ast to logical plan lowering for subqueries ### Fixes - Fixes parsing of multiple consecutive path wildcards (e.g. `a[*][*][*]`), unpivot (e.g. `a.*.*.*`), and path expressions (e.g. `a[1 + 2][3 + 4][5 + 6]`)—previously these would not parse correctly. 
diff --git a/partiql-eval/src/eval/evaluable.rs b/partiql-eval/src/eval/evaluable.rs index a4babb83..7131a58a 100644 --- a/partiql-eval/src/eval/evaluable.rs +++ b/partiql-eval/src/eval/evaluable.rs @@ -1422,7 +1422,7 @@ impl Evaluable for EvalOuterIntersect { let bag: Bag = match self.setq { SetQuantifier::All => { let mut lhs = lhs.counts(); - Bag::from_iter(rhs.filter(|elem| match lhs.get_mut(&elem) { + Bag::from_iter(rhs.filter(|elem| match lhs.get_mut(elem) { Some(count) if *count > 0 => { *count -= 1; true @@ -1477,7 +1477,7 @@ impl Evaluable for EvalOuterExcept { let rhs = bagop_iter(self.r_input.take().unwrap_or(Missing)); let mut exclude = rhs.counts(); - let excepted = lhs.filter(|elem| match exclude.get_mut(&elem) { + let excepted = lhs.filter(|elem| match exclude.get_mut(elem) { Some(count) if *count > 0 => { *count -= 1; false diff --git a/partiql-logical-planner/src/lower.rs b/partiql-logical-planner/src/lower.rs index 014d6395..43cf799a 100644 --- a/partiql-logical-planner/src/lower.rs +++ b/partiql-logical-planner/src/lower.rs @@ -17,7 +17,7 @@ use partiql_logical as logical; use partiql_logical::{ AggregateExpression, BagExpr, BagOp, BetweenExpr, BindingsOp, IsTypeExpr, LikeMatch, LikeNonStringNonLiteralMatch, ListExpr, LogicalPlan, OpId, PathComponent, Pattern, - PatternMatchExpr, SortSpecOrder, TupleExpr, ValueExpr, + PatternMatchExpr, SortSpecOrder, SubQueryExpr, TupleExpr, ValueExpr, }; use partiql_value::{BindingsName, Value}; @@ -80,6 +80,20 @@ macro_rules! not_yet_implemented_err { }; } +#[macro_export] +macro_rules! 
cur_plan_or_fault { + ($self:ident) => {{ + let plan = $self.plan_stack.last_mut(); + if plan.is_none() { + $self.errors.push(AstTransformError::IllegalState( + "current plan is None".to_string(), + )); + return partiql_ast::visit::Traverse::Stop; + } + plan.unwrap() + }}; +} + #[derive(Copy, Clone, Debug)] enum QueryContext { FromLet, @@ -167,7 +181,7 @@ pub struct AstToLogical<'a> { agg_id: IdGenerator, // output - plan: LogicalPlan<BindingsOp>, + plan_stack: Vec<LogicalPlan<BindingsOp>>, // catalog & data flow data key_registry: name_resolver::KeyRegistry, @@ -235,7 +249,7 @@ impl<'a> AstToLogical<'a> { agg_id: Default::default(), // output - plan: Default::default(), + plan_stack: Default::default(), key_registry: registry, fnsym_tab, @@ -255,7 +269,14 @@ impl<'a> AstToLogical<'a> { errors: self.errors, }); } - Ok(self.plan) + if self.plan_stack.len() != 1 { + return Err(AstTransformationError { + errors: vec![AstTransformError::IllegalState( + "plan_stack.len() != 1".to_string(), + )], + }); + } + Ok(self.plan_stack.pop().expect("plan")) } #[inline] @@ -582,14 +603,16 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { fn enter_top_level_query(&mut self, _query: &'ast ast::TopLevelQuery) -> Traverse { self.enter_benv(); + self.plan_stack.push(LogicalPlan::default()); Traverse::Continue } fn exit_top_level_query(&mut self, _query: &'ast ast::TopLevelQuery) -> Traverse { let mut benv = self.exit_benv(); eq_or_fault!(self, benv.len(), 1, "Expect benv.len() == 1"); let out = benv.pop().unwrap(); - let sink_id = self.plan.add_operator(BindingsOp::Sink); - self.plan.add_flow(out, sink_id); + let plan = cur_plan_or_fault!(self); + let sink_id = plan.add_operator(BindingsOp::Sink); + plan.add_flow(out, sink_id); Traverse::Continue } @@ -597,6 +620,13 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { self.enter_benv(); if let QuerySet::Select(_) = query.set.node { self.enter_q(); + if self.q_stack.len() > 1 { + // entering subquery + // TODO: currently there's no enforcement for unique `OpId` 
between the outer query + // plan and the subquery plan. To decide if we want the `OpId`s to be unique + // between any subqueries. + self.plan_stack.push(LogicalPlan::default()); + } } Traverse::Continue } @@ -606,13 +636,23 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { match query.set.node { QuerySet::Select(_) => { let clauses = self.exit_q(); + let plan = cur_plan_or_fault!(self); let mut clauses = clauses.evaluation_order().into_iter(); if let Some(mut src_id) = clauses.next() { for dst_id in clauses { - self.plan.add_flow(src_id, dst_id); + plan.add_flow(src_id, dst_id); src_id = dst_id; } - self.push_bexpr(src_id); + if !self.q_stack.is_empty() { + // this is a subquery + true_or_fault!(self, !self.plan_stack.is_empty(), "plan_stack is empty"); + let plan = self.plan_stack.pop().unwrap(); + true_or_fault!(self, !self.vexpr_stack.is_empty(), "vexpr_stack is empty"); + self.push_vexpr(logical::ValueExpr::SubQueryExpr(SubQueryExpr { plan })); + // don't need bexpr for sink; will use outer query's bexpr + } else { + self.push_bexpr(src_id); + } } } _ => { @@ -622,8 +662,9 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { "benv.len() is not between 1 and 3" ); let mut out = *benv.first().unwrap(); + let plan = cur_plan_or_fault!(self); benv.into_iter().skip(1).for_each(|op| { - self.plan.add_flow(out, op); + plan.add_flow(out, op); out = op; }); self.push_bexpr(out); @@ -653,6 +694,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { fn exit_query_set(&mut self, query_set: &'ast QuerySet) -> Traverse { let env = self.exit_env(); let mut benv = self.exit_benv(); + let plan = cur_plan_or_fault!(self); match query_set { QuerySet::BagOp(bag_op) => { @@ -673,12 +715,12 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { SetQuantifier::Distinct => logical::SetQuantifier::Distinct, }; - let id = self.plan.add_operator(BindingsOp::BagOp(BagOp { + let id = plan.add_operator(BindingsOp::BagOp(BagOp { bag_op: bag_operator, setq, })); - 
self.plan.add_flow_with_branch_num(lid, id, 0); - self.plan.add_flow_with_branch_num(rid, id, 1); + plan.add_flow_with_branch_num(lid, id, 0); + plan.add_flow_with_branch_num(rid, id, 1); self.push_bexpr(id); } QuerySet::Select(_) => {} @@ -686,7 +728,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { eq_or_fault!(self, env.len(), 1, "env.len() != 1"); let expr = env.into_iter().next().unwrap(); let op = BindingsOp::ExprQuery(logical::ExprQuery { expr }); - let id = self.plan.add_operator(op); + let id = plan.add_operator(op); self.push_bexpr(id); } QuerySet::Values(_) => { @@ -728,8 +770,10 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { let env = self.exit_env(); eq_or_fault!(self, env.len(), 0, "env.len() != 0"); + let plan = cur_plan_or_fault!(self); + if let Some(SetQuantifier::Distinct) = _projection.setq { - let id = self.plan.add_operator(BindingsOp::Distinct); + let id = plan.add_operator(BindingsOp::Distinct); self.current_clauses_mut().distinct.replace(id); } Traverse::Continue @@ -742,10 +786,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { } fn exit_projection_kind(&mut self, _projection_kind: &'ast ProjectionKind) -> Traverse { - let benv = self.exit_benv(); - if !benv.is_empty() { - not_yet_implemented_fault!(self, "Subquery within project".to_string()); - } + self.exit_benv(); let env = self.exit_env(); let select: BindingsOp = match _projection_kind { @@ -795,7 +836,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { logical::BindingsOp::ProjectValue(logical::ProjectValue { expr }) } }; - let id = self.plan.add_operator(select); + let plan = cur_plan_or_fault!(self); + let id = plan.add_operator(select); self.current_clauses_mut().select_clause.replace(id); Traverse::Continue } @@ -906,24 +948,6 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { Traverse::Continue } - fn enter_in(&mut self, _in: &'ast ast::In) -> Traverse { - self.enter_env(); - Traverse::Continue - } - fn exit_in(&mut self, _in: &'ast ast::In) -> 
Traverse { - let mut env = self.exit_env(); - eq_or_fault!(self, env.len(), 2, "env.len() != 2"); - - let rhs = env.pop().unwrap(); - let lhs = env.pop().unwrap(); - self.push_vexpr(logical::ValueExpr::BinaryExpr( - logical::BinaryOp::In, - Box::new(lhs), - Box::new(rhs), - )); - Traverse::Continue - } - fn enter_like(&mut self, _like: &'ast Like) -> Traverse { self.enter_env(); Traverse::Continue @@ -1199,7 +1223,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { aggregate_exprs: self.aggregate_exprs.clone(), group_as_alias: None, }); - let id = self.plan.add_operator(group_by); + let plan = cur_plan_or_fault!(self); + let id = plan.add_operator(group_by); self.current_clauses_mut().group_by_clause.replace(id); } Traverse::Continue @@ -1357,7 +1382,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { at_key, }), }; - let id = self.plan.add_operator(bexpr); + let plan = cur_plan_or_fault!(self); + let id = plan.add_operator(bexpr); self.push_bexpr(id); Traverse::Continue } @@ -1393,17 +1419,18 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { let rid = benv.pop().unwrap(); let lid = benv.pop().unwrap(); - let left = Box::new(self.plan.operator(lid).unwrap().clone()); - let right = Box::new(self.plan.operator(rid).unwrap().clone()); + let plan = cur_plan_or_fault!(self); + let left = Box::new(plan.operator(lid).unwrap().clone()); + let right = Box::new(plan.operator(rid).unwrap().clone()); let join = logical::BindingsOp::Join(logical::Join { kind, on, left, right, }); - let join = self.plan.add_operator(join); - self.plan.add_flow_with_branch_num(lid, join, 0); - self.plan.add_flow_with_branch_num(rid, join, 1); + let join = plan.add_operator(join); + plan.add_flow_with_branch_num(lid, join, 0); + plan.add_flow_with_branch_num(rid, join, 1); self.push_bexpr(join); Traverse::Continue } @@ -1435,7 +1462,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { let filter = logical::BindingsOp::Filter(logical::Filter { expr: env.pop().unwrap(), }); - let 
id = self.plan.add_operator(filter); + let plan = cur_plan_or_fault!(self); + let id = plan.add_operator(filter); self.current_clauses_mut().where_clause.replace(id); Traverse::Continue @@ -1453,7 +1481,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { let having = BindingsOp::Having(logical::Having { expr: env.pop().unwrap(), }); - let id = self.plan.add_operator(having); + let plan = cur_plan_or_fault!(self); + let id = plan.add_operator(having); self.current_clauses_mut().having_clause.replace(id); Traverse::Continue @@ -1467,12 +1496,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { fn exit_group_by_expr(&mut self, _group_by_expr: &'ast GroupByExpr) -> Traverse { let aggregate_exprs = self.aggregate_exprs.clone(); - let benv = self.exit_benv(); - if !benv.is_empty() { - { - not_yet_implemented_fault!(self, "Subquery in group by".to_string()); - } - } + self.exit_benv(); let env = self.exit_env(); true_or_fault!(self, env.len().is_even(), "env.len() is not even"); @@ -1486,6 +1510,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { GroupingStrategy::GroupPartial => logical::GroupingStrategy::GroupPartial, }; + let mut binding = HashMap::new(); // What follows is an approach to implement section 11.2.1 of the PartiQL spec // (https://partiql.org/assets/PartiQL-Specification.pdf#subsubsection.11.2.1) // "Grouping Attributes and Direct Use of Grouping Expressions" @@ -1502,12 +1527,16 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { )); return Traverse::Stop; } - let select_clause = self - .plan - .operator_as_mut(select_clause_op_id.expect("select_clause_op_id not None")) - .unwrap(); - let mut binding = HashMap::new(); - let select_clause_exprs = match select_clause { + let plan = cur_plan_or_fault!(self); + let select_clause = + plan.operator_as_mut(select_clause_op_id.expect("select_clause_op_id not None")); + if select_clause.is_none() { + self.errors.push(AstTransformError::IllegalState( + "select_clause in plan not None".to_string(), 
+ )); + return Traverse::Stop; + } + let select_clause_exprs = match select_clause.expect("select_clause is not None") { BindingsOp::Project(ref mut project) => &mut project.exprs, BindingsOp::ProjectAll => &mut binding, BindingsOp::ProjectValue(_) => &mut binding, // TODO: replacement of SELECT VALUE expressions @@ -1563,8 +1592,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { aggregate_exprs, group_as_alias, }); - - let id = self.plan.add_operator(group_by); + let id = plan.add_operator(group_by); self.current_clauses_mut().group_by_clause.replace(id); Traverse::Continue } @@ -1592,7 +1620,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { fn exit_order_by_expr(&mut self, _order_by_expr: &'ast OrderByExpr) -> Traverse { let specs = self.exit_sort(); let order_by = logical::BindingsOp::OrderBy(logical::OrderBy { specs }); - let id = self.plan.add_operator(order_by); + let plan = cur_plan_or_fault!(self); + let id = plan.add_operator(order_by); if matches!(self.current_ctx(), Some(QueryContext::Query)) { self.current_clauses_mut().order_by_clause.replace(id); } else { @@ -1665,7 +1694,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { }; let limit_offset = logical::BindingsOp::LimitOffset(logical::LimitOffset { limit, offset }); - let id = self.plan.add_operator(limit_offset); + let plan = cur_plan_or_fault!(self); + let id = plan.add_operator(limit_offset); if matches!(self.current_ctx(), Some(QueryContext::Query)) { self.current_clauses_mut().limit_offset_clause.replace(id); } else {