Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add subquery lowering #408

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `partiql_ast_passes::static_typer` for type annotating the AST.
- Add ability to parse `ORDER BY`, `LIMIT`, `OFFSET` in children of set operators
- Add `OUTER` bag operator (`OUTER UNION`, `OUTER INTERSECT`, `OUTER EXCEPT`) implementation
- Add AST to logical plan lowering for subqueries

### Fixes
- Fixes parsing of multiple consecutive path wildcards (e.g. `a[*][*][*]`), unpivot (e.g. `a.*.*.*`), and path expressions (e.g. `a[1 + 2][3 + 4][5 + 6]`)—previously these would not parse correctly.
Expand Down
4 changes: 2 additions & 2 deletions partiql-eval/src/eval/evaluable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1422,7 +1422,7 @@ impl Evaluable for EvalOuterIntersect {
let bag: Bag = match self.setq {
SetQuantifier::All => {
let mut lhs = lhs.counts();
Bag::from_iter(rhs.filter(|elem| match lhs.get_mut(&elem) {
Bag::from_iter(rhs.filter(|elem| match lhs.get_mut(elem) {
Some(count) if *count > 0 => {
*count -= 1;
true
Expand Down Expand Up @@ -1477,7 +1477,7 @@ impl Evaluable for EvalOuterExcept {
let rhs = bagop_iter(self.r_input.take().unwrap_or(Missing));

let mut exclude = rhs.counts();
let excepted = lhs.filter(|elem| match exclude.get_mut(&elem) {
let excepted = lhs.filter(|elem| match exclude.get_mut(elem) {
Some(count) if *count > 0 => {
*count -= 1;
false
Expand Down
154 changes: 92 additions & 62 deletions partiql-logical-planner/src/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use partiql_logical as logical;
use partiql_logical::{
AggregateExpression, BagExpr, BagOp, BetweenExpr, BindingsOp, IsTypeExpr, LikeMatch,
LikeNonStringNonLiteralMatch, ListExpr, LogicalPlan, OpId, PathComponent, Pattern,
PatternMatchExpr, SortSpecOrder, TupleExpr, ValueExpr,
PatternMatchExpr, SortSpecOrder, SubQueryExpr, TupleExpr, ValueExpr,
};

use partiql_value::{BindingsName, Value};
Expand Down Expand Up @@ -80,6 +80,20 @@ macro_rules! not_yet_implemented_err {
};
}

/// Yields a mutable reference to the plan on top of `$self.plan_stack`.
///
/// If the stack is empty, an `AstTransformError::IllegalState` with the message
/// `"current plan is None"` is pushed onto `$self.errors` and the *enclosing function*
/// returns `partiql_ast::visit::Traverse::Stop`. The macro can therefore only be used
/// inside functions that return `Traverse`.
///
/// `$self` must be an identifier with `plan_stack` and `errors` fields; the two fields are
/// borrowed disjointly, so the error can be recorded while the stack lookup is in scope.
#[macro_export]
macro_rules! cur_plan_or_fault {
    ($self:ident) => {
        // A single `match` replaces the `is_none()` check followed by `unwrap()`, so there
        // is no separate panic path to keep in sync with the guard.
        match $self.plan_stack.last_mut() {
            Some(plan) => plan,
            None => {
                $self.errors.push(AstTransformError::IllegalState(
                    "current plan is None".to_string(),
                ));
                return partiql_ast::visit::Traverse::Stop;
            }
        }
    };
}

#[derive(Copy, Clone, Debug)]
enum QueryContext {
FromLet,
Expand Down Expand Up @@ -167,7 +181,7 @@ pub struct AstToLogical<'a> {
agg_id: IdGenerator,

// output
plan: LogicalPlan<BindingsOp>,
plan_stack: Vec<LogicalPlan<BindingsOp>>,

// catalog & data flow data
key_registry: name_resolver::KeyRegistry,
Expand Down Expand Up @@ -235,7 +249,7 @@ impl<'a> AstToLogical<'a> {
agg_id: Default::default(),

// output
plan: Default::default(),
plan_stack: Default::default(),

key_registry: registry,
fnsym_tab,
Expand All @@ -255,7 +269,14 @@ impl<'a> AstToLogical<'a> {
errors: self.errors,
});
}
Ok(self.plan)
if self.plan_stack.len() != 1 {
return Err(AstTransformationError {
errors: vec![AstTransformError::IllegalState(
"plan_stack.len() != 1".to_string(),
)],
});
}
Ok(self.plan_stack.pop().expect("plan"))
}

#[inline]
Expand Down Expand Up @@ -582,21 +603,30 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {

/// Visitor hook run when traversal enters the top-level query.
///
/// Calls `enter_benv` and pushes a fresh default `LogicalPlan` onto `plan_stack`, so that
/// operators added while visiting the query go into that plan (see `cur_plan_or_fault!`).
/// Always returns `Traverse::Continue`.
fn enter_top_level_query(&mut self, _query: &'ast ast::TopLevelQuery) -> Traverse {
self.enter_benv();
// This is the plan the final result is taken from: the result is only returned when
// exactly one plan remains on `plan_stack`.
self.plan_stack.push(LogicalPlan::default());
Traverse::Continue
}
fn exit_top_level_query(&mut self, _query: &'ast ast::TopLevelQuery) -> Traverse {
let mut benv = self.exit_benv();
eq_or_fault!(self, benv.len(), 1, "Expect benv.len() == 1");
let out = benv.pop().unwrap();
let sink_id = self.plan.add_operator(BindingsOp::Sink);
self.plan.add_flow(out, sink_id);
let plan = cur_plan_or_fault!(self);
let sink_id = plan.add_operator(BindingsOp::Sink);
plan.add_flow(out, sink_id);
Traverse::Continue
}

fn enter_query(&mut self, query: &'ast Query) -> Traverse {
self.enter_benv();
if let QuerySet::Select(_) = query.set.node {
self.enter_q();
if self.q_stack.len() > 1 {
// entering subquery
// TODO: `OpId`s are currently not guaranteed to be unique between the outer query
//  plan and the subquery plan. Decide whether we want `OpId`s to be unique across
//  the outer query and all of its subqueries.
Comment on lines +625 to +627
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The subquery plan and outer query plan don't currently enforce unique OpIds between the two plans. I was going off of how subqueries were modeled in the logical and eval plan as part of #227 and example tests.

I was wondering if we want unique OpIds between the two plans, which can be helpful for debugging and plan graph visualization.

self.plan_stack.push(LogicalPlan::default());
}
}
Traverse::Continue
}
Expand All @@ -606,13 +636,23 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
match query.set.node {
QuerySet::Select(_) => {
let clauses = self.exit_q();
let plan = cur_plan_or_fault!(self);
let mut clauses = clauses.evaluation_order().into_iter();
if let Some(mut src_id) = clauses.next() {
for dst_id in clauses {
self.plan.add_flow(src_id, dst_id);
plan.add_flow(src_id, dst_id);
src_id = dst_id;
}
self.push_bexpr(src_id);
if !self.q_stack.is_empty() {
// this is a subquery
true_or_fault!(self, !self.plan_stack.is_empty(), "plan_stack is empty");
let plan = self.plan_stack.pop().unwrap();
true_or_fault!(self, !self.vexpr_stack.is_empty(), "vexpr_stack is empty");
self.push_vexpr(logical::ValueExpr::SubQueryExpr(SubQueryExpr { plan }));
// don't need bexpr for sink; will use outer query's bexpr
} else {
self.push_bexpr(src_id);
}
}
}
_ => {
Expand All @@ -622,8 +662,9 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
"benv.len() is not between 1 and 3"
);
let mut out = *benv.first().unwrap();
let plan = cur_plan_or_fault!(self);
benv.into_iter().skip(1).for_each(|op| {
self.plan.add_flow(out, op);
plan.add_flow(out, op);
out = op;
});
self.push_bexpr(out);
Expand Down Expand Up @@ -653,6 +694,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
fn exit_query_set(&mut self, query_set: &'ast QuerySet) -> Traverse {
let env = self.exit_env();
let mut benv = self.exit_benv();
let plan = cur_plan_or_fault!(self);

match query_set {
QuerySet::BagOp(bag_op) => {
Expand All @@ -673,20 +715,20 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
SetQuantifier::Distinct => logical::SetQuantifier::Distinct,
};

let id = self.plan.add_operator(BindingsOp::BagOp(BagOp {
let id = plan.add_operator(BindingsOp::BagOp(BagOp {
bag_op: bag_operator,
setq,
}));
self.plan.add_flow_with_branch_num(lid, id, 0);
self.plan.add_flow_with_branch_num(rid, id, 1);
plan.add_flow_with_branch_num(lid, id, 0);
plan.add_flow_with_branch_num(rid, id, 1);
self.push_bexpr(id);
}
QuerySet::Select(_) => {}
QuerySet::Expr(_) => {
eq_or_fault!(self, env.len(), 1, "env.len() != 1");
let expr = env.into_iter().next().unwrap();
let op = BindingsOp::ExprQuery(logical::ExprQuery { expr });
let id = self.plan.add_operator(op);
let id = plan.add_operator(op);
self.push_bexpr(id);
}
QuerySet::Values(_) => {
Expand Down Expand Up @@ -728,8 +770,10 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
let env = self.exit_env();
eq_or_fault!(self, env.len(), 0, "env.len() != 0");

let plan = cur_plan_or_fault!(self);

if let Some(SetQuantifier::Distinct) = _projection.setq {
let id = self.plan.add_operator(BindingsOp::Distinct);
let id = plan.add_operator(BindingsOp::Distinct);
self.current_clauses_mut().distinct.replace(id);
}
Traverse::Continue
Expand All @@ -742,10 +786,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
}

fn exit_projection_kind(&mut self, _projection_kind: &'ast ProjectionKind) -> Traverse {
let benv = self.exit_benv();
if !benv.is_empty() {
not_yet_implemented_fault!(self, "Subquery within project".to_string());
}
self.exit_benv();
let env = self.exit_env();

let select: BindingsOp = match _projection_kind {
Expand Down Expand Up @@ -795,7 +836,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
logical::BindingsOp::ProjectValue(logical::ProjectValue { expr })
}
};
let id = self.plan.add_operator(select);
let plan = cur_plan_or_fault!(self);
let id = plan.add_operator(select);
self.current_clauses_mut().select_clause.replace(id);
Traverse::Continue
}
Expand Down Expand Up @@ -906,24 +948,6 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
Traverse::Continue
}

/// Visitor hook run when traversal enters an `IN` expression node.
///
/// Calls `enter_env` so that the operands visited next are collected separately; the
/// matching `exit_in` pops that environment. Always returns `Traverse::Continue`.
fn enter_in(&mut self, _in: &'ast ast::In) -> Traverse {
self.enter_env();
Traverse::Continue
}
/// Visitor hook run when traversal leaves an `IN` expression node.
///
/// Takes the environment returned by `exit_env`, expects exactly two value expressions
/// in it, and pushes a `ValueExpr::BinaryExpr` with `BinaryOp::In` built from them.
/// Returns `Traverse::Continue` on the path visible here.
fn exit_in(&mut self, _in: &'ast ast::In) -> Traverse {
let mut env = self.exit_env();
// NOTE(review): `eq_or_fault!` is defined outside this view; presumably it records an
// error and stops traversal when the length is not 2 — confirm.
eq_or_fault!(self, env.len(), 2, "env.len() != 2");

// Operands are popped in reverse order: the last element is the right-hand side.
let rhs = env.pop().unwrap();
let lhs = env.pop().unwrap();
self.push_vexpr(logical::ValueExpr::BinaryExpr(
logical::BinaryOp::In,
Box::new(lhs),
Box::new(rhs),
));
Traverse::Continue
}

fn enter_like(&mut self, _like: &'ast Like) -> Traverse {
self.enter_env();
Traverse::Continue
Expand Down Expand Up @@ -1199,7 +1223,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
aggregate_exprs: self.aggregate_exprs.clone(),
group_as_alias: None,
});
let id = self.plan.add_operator(group_by);
let plan = cur_plan_or_fault!(self);
let id = plan.add_operator(group_by);
self.current_clauses_mut().group_by_clause.replace(id);
}
Traverse::Continue
Expand Down Expand Up @@ -1357,7 +1382,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
at_key,
}),
};
let id = self.plan.add_operator(bexpr);
let plan = cur_plan_or_fault!(self);
let id = plan.add_operator(bexpr);
self.push_bexpr(id);
Traverse::Continue
}
Expand Down Expand Up @@ -1393,17 +1419,18 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {

let rid = benv.pop().unwrap();
let lid = benv.pop().unwrap();
let left = Box::new(self.plan.operator(lid).unwrap().clone());
let right = Box::new(self.plan.operator(rid).unwrap().clone());
let plan = cur_plan_or_fault!(self);
let left = Box::new(plan.operator(lid).unwrap().clone());
let right = Box::new(plan.operator(rid).unwrap().clone());
let join = logical::BindingsOp::Join(logical::Join {
kind,
on,
left,
right,
});
let join = self.plan.add_operator(join);
self.plan.add_flow_with_branch_num(lid, join, 0);
self.plan.add_flow_with_branch_num(rid, join, 1);
let join = plan.add_operator(join);
plan.add_flow_with_branch_num(lid, join, 0);
plan.add_flow_with_branch_num(rid, join, 1);
self.push_bexpr(join);
Traverse::Continue
}
Expand Down Expand Up @@ -1435,7 +1462,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
let filter = logical::BindingsOp::Filter(logical::Filter {
expr: env.pop().unwrap(),
});
let id = self.plan.add_operator(filter);
let plan = cur_plan_or_fault!(self);
let id = plan.add_operator(filter);

self.current_clauses_mut().where_clause.replace(id);
Traverse::Continue
Expand All @@ -1453,7 +1481,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
let having = BindingsOp::Having(logical::Having {
expr: env.pop().unwrap(),
});
let id = self.plan.add_operator(having);
let plan = cur_plan_or_fault!(self);
let id = plan.add_operator(having);

self.current_clauses_mut().having_clause.replace(id);
Traverse::Continue
Expand All @@ -1467,12 +1496,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {

fn exit_group_by_expr(&mut self, _group_by_expr: &'ast GroupByExpr) -> Traverse {
let aggregate_exprs = self.aggregate_exprs.clone();
let benv = self.exit_benv();
if !benv.is_empty() {
{
not_yet_implemented_fault!(self, "Subquery in group by".to_string());
}
}
self.exit_benv();
let env = self.exit_env();
true_or_fault!(self, env.len().is_even(), "env.len() is not even");

Expand All @@ -1486,6 +1510,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
GroupingStrategy::GroupPartial => logical::GroupingStrategy::GroupPartial,
};

let mut binding = HashMap::new();
// What follows is an approach to implement section 11.2.1 of the PartiQL spec
// (https://partiql.org/assets/PartiQL-Specification.pdf#subsubsection.11.2.1)
// "Grouping Attributes and Direct Use of Grouping Expressions"
Expand All @@ -1502,12 +1527,16 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
));
return Traverse::Stop;
}
let select_clause = self
.plan
.operator_as_mut(select_clause_op_id.expect("select_clause_op_id not None"))
.unwrap();
let mut binding = HashMap::new();
let select_clause_exprs = match select_clause {
let plan = cur_plan_or_fault!(self);
let select_clause =
plan.operator_as_mut(select_clause_op_id.expect("select_clause_op_id not None"));
if select_clause.is_none() {
self.errors.push(AstTransformError::IllegalState(
"select_clause in plan not None".to_string(),
));
return Traverse::Stop;
}
let select_clause_exprs = match select_clause.expect("select_clause is not None") {
BindingsOp::Project(ref mut project) => &mut project.exprs,
BindingsOp::ProjectAll => &mut binding,
BindingsOp::ProjectValue(_) => &mut binding, // TODO: replacement of SELECT VALUE expressions
Expand Down Expand Up @@ -1563,8 +1592,7 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
aggregate_exprs,
group_as_alias,
});

let id = self.plan.add_operator(group_by);
let id = plan.add_operator(group_by);
self.current_clauses_mut().group_by_clause.replace(id);
Traverse::Continue
}
Expand Down Expand Up @@ -1592,7 +1620,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
fn exit_order_by_expr(&mut self, _order_by_expr: &'ast OrderByExpr) -> Traverse {
let specs = self.exit_sort();
let order_by = logical::BindingsOp::OrderBy(logical::OrderBy { specs });
let id = self.plan.add_operator(order_by);
let plan = cur_plan_or_fault!(self);
let id = plan.add_operator(order_by);
if matches!(self.current_ctx(), Some(QueryContext::Query)) {
self.current_clauses_mut().order_by_clause.replace(id);
} else {
Expand Down Expand Up @@ -1665,7 +1694,8 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> {
};

let limit_offset = logical::BindingsOp::LimitOffset(logical::LimitOffset { limit, offset });
let id = self.plan.add_operator(limit_offset);
let plan = cur_plan_or_fault!(self);
let id = plan.add_operator(limit_offset);
if matches!(self.current_ctx(), Some(QueryContext::Query)) {
self.current_clauses_mut().limit_offset_clause.replace(id);
} else {
Expand Down