diff --git a/src/sql/execution/aggregate.rs b/src/sql/execution/aggregate.rs index 5e9edb130..ab6c48578 100644 --- a/src/sql/execution/aggregate.rs +++ b/src/sql/execution/aggregate.rs @@ -34,8 +34,14 @@ struct Aggregator { impl Aggregator { /// Creates a new aggregator for the given aggregates and GROUP BY buckets. fn new(aggregates: Vec, group_by: Vec) -> Self { + use Aggregate::*; let accumulators = aggregates.iter().map(Accumulator::new).collect(); - let exprs = aggregates.into_iter().map(|a| a.into_inner()).collect(); + let exprs = aggregates + .into_iter() + .map(|aggregate| match aggregate { + Average(expr) | Count(expr) | Max(expr) | Min(expr) | Sum(expr) => expr, + }) + .collect(); Self { buckets: BTreeMap::new(), empty: accumulators, group_by, exprs } } @@ -101,7 +107,7 @@ impl Accumulator { } /// Adds a value to the accumulator. - /// TODO: have this take &Value. + /// TODO: NULL values should possibly be ignored, not yield NULL (see Postgres?). fn add(&mut self, value: Value) -> Result<()> { use std::cmp::Ordering; match (self, value) { diff --git a/src/sql/parser/ast.rs b/src/sql/parser/ast.rs index 54ccb45d5..8e763b7cd 100644 --- a/src/sql/parser/ast.rs +++ b/src/sql/parser/ast.rs @@ -1,4 +1,3 @@ -use crate::error::Result; use crate::sql::types::DataType; use std::collections::BTreeMap; @@ -101,9 +100,6 @@ pub enum Order { pub enum Expression { /// A field reference, with an optional table qualifier. Field(Option, String), - /// A column index (only used during planning to break off subtrees). - /// TODO: get rid of this, planning shouldn't modify the AST. - Column(usize), /// A literal value. Literal(Literal), /// A function call (name and parameters). @@ -188,82 +184,6 @@ pub enum Operator { } impl Expression { - /// Transforms the expression tree depth-first by applying a closure before - /// and after descending. - /// - /// TODO: make closures non-mut. 
- pub fn transform(mut self, before: &mut B, after: &mut A) -> Result - where - B: FnMut(Self) -> Result, - A: FnMut(Self) -> Result, - { - use Operator::*; - self = before(self)?; - - // Helper for transforming a boxed expression. - let mut transform = |mut expr: Box| -> Result> { - *expr = expr.transform(before, after)?; - Ok(expr) - }; - - self = match self { - Self::Literal(_) | Self::Field(_, _) | Self::Column(_) => self, - - Self::Function(name, exprs) => Self::Function( - name, - exprs.into_iter().map(|e| e.transform(before, after)).collect::>()?, - ), - - Self::Operator(op) => Self::Operator(match op { - Add(lhs, rhs) => Add(transform(lhs)?, transform(rhs)?), - And(lhs, rhs) => And(transform(lhs)?, transform(rhs)?), - Divide(lhs, rhs) => Divide(transform(lhs)?, transform(rhs)?), - Equal(lhs, rhs) => Equal(transform(lhs)?, transform(rhs)?), - Exponentiate(lhs, rhs) => Exponentiate(transform(lhs)?, transform(rhs)?), - Factorial(expr) => Factorial(transform(expr)?), - GreaterThan(lhs, rhs) => GreaterThan(transform(lhs)?, transform(rhs)?), - GreaterThanOrEqual(lhs, rhs) => { - GreaterThanOrEqual(transform(lhs)?, transform(rhs)?) 
- } - Identity(expr) => Identity(transform(expr)?), - IsNaN(expr) => IsNaN(transform(expr)?), - IsNull(expr) => IsNull(transform(expr)?), - LessThan(lhs, rhs) => LessThan(transform(lhs)?, transform(rhs)?), - LessThanOrEqual(lhs, rhs) => LessThanOrEqual(transform(lhs)?, transform(rhs)?), - Like(lhs, rhs) => Like(transform(lhs)?, transform(rhs)?), - Modulo(lhs, rhs) => Modulo(transform(lhs)?, transform(rhs)?), - Multiply(lhs, rhs) => Multiply(transform(lhs)?, transform(rhs)?), - Negate(expr) => Negate(transform(expr)?), - Not(expr) => Not(transform(expr)?), - NotEqual(lhs, rhs) => NotEqual(transform(lhs)?, transform(rhs)?), - Or(lhs, rhs) => Or(transform(lhs)?, transform(rhs)?), - Subtract(lhs, rhs) => Subtract(transform(lhs)?, transform(rhs)?), - }), - }; - self = after(self)?; - Ok(self) - } - - /// Transforms an expression using a mutable reference. - /// TODO: try to get rid of this and replace_with(). - pub fn transform_mut(&mut self, before: &mut B, after: &mut A) -> Result<()> - where - B: FnMut(Self) -> Result, - A: FnMut(Self) -> Result, - { - self.replace_with(|e| e.transform(before, after)) - } - - /// Replaces the expression with result of the closure. Helper function for - /// transform(). - fn replace_with(&mut self, mut f: impl FnMut(Self) -> Result) -> Result<()> { - // Temporarily replace expression with a null value, in case closure panics. May consider - // replace_with crate if this hampers performance. - let expr = std::mem::replace(self, Expression::Literal(Literal::Null)); - *self = f(expr)?; - Ok(()) - } - /// Walks the expression tree depth-first, calling a closure for every node. /// Halts and returns false if the closure returns false. 
pub fn walk(&self, visitor: &mut impl FnMut(&Expression) -> bool) -> bool { @@ -297,7 +217,7 @@ impl Expression { Self::Function(_, exprs) => exprs.iter().any(|expr| expr.walk(visitor)), - Self::Literal(_) | Self::Field(_, _) | Self::Column(_) => true, + Self::Literal(_) | Self::Field(_, _) => true, } } @@ -348,7 +268,7 @@ impl Expression { Self::Function(_, exprs) => exprs.iter().for_each(|expr| expr.collect(visitor, c)), - Self::Literal(_) | Self::Field(_, _) | Self::Column(_) => {} + Self::Literal(_) | Self::Field(_, _) => {} } } } diff --git a/src/sql/planner/plan.rs b/src/sql/planner/plan.rs index 00f790f81..aa7e219e4 100644 --- a/src/sql/planner/plan.rs +++ b/src/sql/planner/plan.rs @@ -85,7 +85,9 @@ impl Plan { #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum Node { /// Computes aggregate values for the given expressions and group_by buckets - /// across all rows in the source node. + /// across all rows in the source node. The aggregate columns are output + /// first, followed by the group_by columns, in the given order. + /// TODO: reverse the order. Aggregate { source: Box, aggregates: Vec, group_by: Vec }, /// Filters source rows, by only emitting rows for which the predicate /// evaluates to true. @@ -465,25 +467,6 @@ impl std::fmt::Display for Aggregate { } } -impl Aggregate { - /// Returns the inner aggregate expression. Currently, all aggregate - /// functions take a single input expression. - pub fn into_inner(self) -> Expression { - match self { - Self::Average(expr) - | Self::Count(expr) - | Self::Max(expr) - | Self::Min(expr) - | Self::Sum(expr) => expr, - } - } - - // TODO: get rid of this. - pub(super) fn is(name: &str) -> bool { - ["avg", "count", "max", "min", "sum"].contains(&name) - } -} - /// A sort order direction. 
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum Direction { diff --git a/src/sql/planner/planner.rs b/src/sql/planner/planner.rs index 515b7b8c8..c4bdf5ab3 100644 --- a/src/sql/planner/planner.rs +++ b/src/sql/planner/planner.rs @@ -148,11 +148,11 @@ impl<'a, C: Catalog> Planner<'a, C> { #[allow(clippy::too_many_arguments)] fn build_select( &self, - mut select: Vec<(ast::Expression, Option)>, + select: Vec<(ast::Expression, Option)>, from: Vec, r#where: Option, group_by: Vec, - mut having: Option, + having: Option, order: Vec<(ast::Expression, ast::Order)>, offset: Option, limit: Option, @@ -174,40 +174,27 @@ impl<'a, C: Catalog> Planner<'a, C> { node = Node::Filter { source: Box::new(node), predicate }; }; + // Build aggregate functions and GROUP BY clause. + let aggregates = Self::collect_aggregates(&select, &having, &order); + if !group_by.is_empty() || !aggregates.is_empty() { + node = self.build_aggregate(&mut scope, node, group_by, aggregates)?; + } + // Build SELECT clause. let mut hidden = 0; if !select.is_empty() { - // Inject hidden SELECT columns for fields and aggregates used in ORDER BY and - // HAVING expressions but not present in existing SELECT output. These will be - // removed again by a later projection. - if let Some(ref mut expr) = having { - hidden += self.inject_hidden(expr, &mut select)?; - } - - // Extract any aggregate functions and GROUP BY expressions and - // build an aggregation node for them, replacing them with Column - // placeholders. - // - // TODO: handle ORDER BY aggregates. - let aggregates = self.extract_aggregates(&scope, &mut select)?; - let groups = self.extract_groups(&mut select, group_by, aggregates.len())?; - if !aggregates.is_empty() || !groups.is_empty() { - node = self.build_aggregation(&mut scope, node, groups, aggregates)?; - } - - // Build the remaining non-aggregate projection. 
let labels = select.iter().map(|(_, l)| Label::maybe_name(l.clone())).collect_vec(); - let mut expressions = select + let mut expressions: Vec<_> = select .into_iter() .map(|(e, _)| Self::build_expression(e, &scope)) - .collect::>>()?; + .collect::>()?; let parent_scope = scope; scope = parent_scope.project(&expressions, &labels)?; - // Add hidden columns for any ORDER BY fields not in the projection. + // Add hidden columns for HAVING and ORDER BY fields not in SELECT. // TODO: track hidden fields in Scope. let size = expressions.len(); - for (expr, _) in &order { + for expr in having.iter().chain(order.iter().map(|(e, _)| e)) { self.build_hidden(&mut scope, &parent_scope, &mut expressions, expr); } hidden += expressions.len() - size; @@ -345,118 +332,123 @@ impl<'a, C: Catalog> Planner<'a, C> { Ok(node) } - /// Builds an aggregate node. + /// Builds an aggregate node, computing aggregate functions for a set of + /// GROUP BY buckets. + /// + /// The aggregate functions have been collected from the SELECT, HAVING, and + /// ORDER BY clauses (all of which can contain their own aggregate + /// functions). /// - /// TODO: revisit this. - fn build_aggregation( + /// The ast::Expression for each aggregate function and each GROUP BY + /// expression (except trivial column names) is stored in the Scope along + /// with the column index. Later nodes (i.e. SELECT, HAVING, and ORDER BY) + /// can look up the column index of aggregate expressions via Scope. + /// Similarly, they are allowed to reference GROUP BY expressions by + /// specifying the exact same expression. + /// + /// TODO: consider avoiding the expr cloning by taking &Expression in + /// various places. 
+ fn build_aggregate( &self, scope: &mut Scope, source: Node, - groups: Vec<(ast::Expression, Option)>, - aggregates: Vec, + group_by: Vec, + aggregates: Vec<&ast::Expression>, ) -> Result { - let mut group_by = Vec::new(); - let mut expressions = Vec::new(); - let mut labels = Vec::new(); - for _ in &aggregates { - expressions.push(Expression::Constant(Value::Null)); - labels.push(Label::None); + // Construct a child scope with the group_by and aggregate AST + // expressions, such that downstream nodes can identify and reference + // them. Discard redundant expressions. + // + // TODO: reverse the order of the emitted columns: group_by then + // aggregates. + let mut child_scope = scope.project(&[], &[])?; // project to keep tables + let aggregates = aggregates + .into_iter() + .filter(|&expr| { + if child_scope.lookup_aggregate(expr).is_some() { + return false; + } + child_scope.add_aggregate((expr).clone(), Label::None); + true + }) + .collect_vec(); + let group_by = group_by + .into_iter() + .filter(|expr| { + if child_scope.lookup_aggregate(expr).is_some() { + return false; // already exists in child scope + } + let mut label = Label::None; + if let ast::Expression::Field(table, column) = expr { + if let Ok(index) = scope.lookup_column(table.as_deref(), column.as_str()) { + label = scope.get_column_label(index).unwrap(); + } + } + child_scope.add_aggregate(expr.clone(), label); + true + }) + .collect_vec(); + + // Build the node from the remaining expressions. + let aggregates = aggregates + .into_iter() + .map(|expr| Self::build_aggregate_function(scope, expr.clone())) + .collect::>()?; + let group_by = group_by + .into_iter() + .map(|expr| Self::build_expression(expr, scope)) + .collect::>()?; + + *scope = child_scope; + Ok(Node::Aggregate { source: Box::new(source), group_by, aggregates }) + } + + /// Builds an aggregate function from an AST expression. 
+ fn build_aggregate_function(scope: &Scope, expr: ast::Expression) -> Result { + let ast::Expression::Function(name, mut args) = expr else { + panic!("aggregate expression must be function"); + }; + if args.len() != 1 { + return errinput!("{name} takes 1 argument"); } - for (expr, label) in groups { - let expr = Self::build_expression(expr, scope)?; - expressions.push(expr.clone()); - group_by.push(expr); - labels.push(Label::maybe_name(label)); + if args[0].contains(&|expr| Self::is_aggregate_function(expr)) { + return errinput!("aggregate functions can't be nested"); } - let node = Node::Aggregate { source: Box::new(source), group_by, aggregates }; - *scope = scope.project(&expressions, &labels)?; - Ok(node) + let expr = Self::build_expression(args.remove(0), scope)?; + Ok(match name.as_str() { + "avg" => Aggregate::Average(expr), + "count" => Aggregate::Count(expr), + "min" => Aggregate::Min(expr), + "max" => Aggregate::Max(expr), + "sum" => Aggregate::Sum(expr), + name => return errinput!("unknown aggregate function {name}"), + }) } - /// Extracts aggregate functions from an AST expression tree. This finds the aggregate - /// function calls, replaces them with ast::Expression::Column(i), maps the aggregate functions - /// to aggregates, and returns them along with their argument expressions. 
- fn extract_aggregates( - &self, - scope: &Scope, - exprs: &mut [(ast::Expression, Option)], - ) -> Result> { - let mut aggregates = Vec::new(); - for (expr, _) in exprs { - expr.transform_mut( - &mut |e| match e { - ast::Expression::Function(f, mut args) - if Aggregate::is(&f) && args.len() == 1 => - { - let expr = Self::build_expression(args.remove(0), scope)?; - aggregates.push(match f.as_str() { - "avg" => Aggregate::Average(expr), - "count" => Aggregate::Count(expr), - "min" => Aggregate::Min(expr), - "max" => Aggregate::Max(expr), - "sum" => Aggregate::Sum(expr), - f => panic!("invalid aggregate function {f}"), - }); - Ok(ast::Expression::Column(aggregates.len() - 1)) - } - _ => Ok(e), - }, - &mut Ok, - )?; + /// Checks whether a given AST expression is an aggregate function. + fn is_aggregate_function(expr: &ast::Expression) -> bool { + if let ast::Expression::Function(name, _) = expr { + ["avg", "count", "max", "min", "sum"].contains(&name.as_str()) + } else { + false } - Ok(aggregates) } - /// Extracts group by expressions, and replaces them with column references with the given - /// offset. These can be either an arbitray expression, a reference to a SELECT column, or the - /// same expression as a SELECT column. 
The following are all valid: - /// - /// SELECT released / 100 AS century, COUNT(*) FROM movies GROUP BY century - /// SELECT released / 100, COUNT(*) FROM movies GROUP BY released / 100 - /// SELECT COUNT(*) FROM movies GROUP BY released / 100 - fn extract_groups( - &self, - exprs: &mut [(ast::Expression, Option)], - group_by: Vec, - offset: usize, - ) -> Result)>> { - let mut groups = Vec::new(); - for g in group_by { - // Look for references to SELECT columns with AS labels - if let ast::Expression::Field(None, label) = &g { - if let Some(i) = exprs.iter().position(|(_, l)| l.as_deref() == Some(label)) { - groups.push(( - std::mem::replace( - &mut exprs[i].0, - ast::Expression::Column(offset + groups.len()), - ), - exprs[i].1.clone(), - )); - continue; - } - } - // Look for expressions exactly equal to the group expression - if let Some(i) = exprs.iter().position(|(e, _)| e == &g) { - groups.push(( - std::mem::replace( - &mut exprs[i].0, - ast::Expression::Column(offset + groups.len()), - ), - exprs[i].1.clone(), - )); - continue; - } - // Otherwise, just use the group expression directly - groups.push((g, None)) - } - // Make sure no group expressions contain Column references, which would be placed here - // during extract_aggregates(). - for (expr, _) in &groups { - if Self::is_aggregate(expr) { - return errinput!("group expression cannot contain aggregates"); - } + /// Collects aggregate functions from SELECT, HAVING, and ORDER BY clauses. 
+ fn collect_aggregates<'c>( + select: &'c [(ast::Expression, Option)], + having: &'c Option, + order_by: &'c [(ast::Expression, ast::Order)], + ) -> Vec<&'c ast::Expression> { + let select = select.iter().map(|(expr, _)| expr); + let having = having.iter(); + let order_by = order_by.iter().map(|(expr, _)| expr); + + let mut aggregates = Vec::new(); + for expr in select.chain(having).chain(order_by) { + expr.collect(&|e| Self::is_aggregate_function(e), &mut aggregates) } - Ok(groups) + aggregates } /// Adds hidden columns to a projection to pass through fields that are used @@ -477,9 +469,20 @@ impl<'a, C: Catalog> Planner<'a, C> { projection: &mut Vec, expr: &ast::Expression, ) { - expr.walk(&mut |e| { - // Look for field references. - let ast::Expression::Field(table, name) = e else { + expr.walk(&mut |expr| { + // If this is an aggregate function or GROUP BY expression that + // isn't already available in the child scope, pass it through. + if let Some(index) = parent_scope.lookup_aggregate(expr) { + if scope.lookup_aggregate(expr).is_none() { + let label = parent_scope.get_column_label(index).unwrap(); + scope.add_aggregate(expr.clone(), label); + projection.push(Expression::Field(index, Label::None)); + return true; + } + } + + // Otherwise, only look for field references. + let ast::Expression::Field(table, name) = expr else { return true; }; // If the field already exists post-projection, do nothing. @@ -492,79 +495,27 @@ impl<'a, C: Catalog> Planner<'a, C> { let Ok(index) = parent_scope.lookup_column(table.as_deref(), name) else { return true; }; - // Add a hidden column to the projection. - let label = Label::maybe_qualified(table.clone(), name.clone()); - scope.add_column(label.clone()); - projection.push(Expression::Field(index, label)); + // Add a hidden column to the projection. Use the given label for + // the projection, but the qualified label for the scope. 
+ scope.add_column(parent_scope.get_column_label(index).unwrap()); + projection.push(Expression::Field( + index, + Label::maybe_qualified(table.clone(), name.clone()), + )); true }); } - /// Injects hidden expressions into SELECT expressions. This is used for ORDER BY and HAVING, in - /// order to apply these to fields or aggregates that are not present in the SELECT output, e.g. - /// to order on a column that is not selected. This is done by replacing the relevant parts of - /// the given expression with Column references to either existing columns or new, hidden - /// columns in the select expressions. Returns the number of hidden columns added. - fn inject_hidden( - &self, - expr: &mut ast::Expression, - select: &mut Vec<(ast::Expression, Option)>, - ) -> Result { - // Replace any identical expressions or label references with column references. - // - // TODO: instead of trying to deduplicate here, before the optimizer has - // normalized expressions and such, we should just go ahead and add new - // columns for all fields and expressions, and have a separate optimizer - // that looks at duplicate expressions in a single projection and - // collapses them, rewriting downstream field references. - for (i, (sexpr, label)) in select.iter().enumerate() { - if expr == sexpr { - *expr = ast::Expression::Column(i); - continue; - } - if let Some(label) = label { - expr.transform_mut( - &mut |e| match e { - ast::Expression::Field(None, ref l) if l == label => { - Ok(ast::Expression::Column(i)) - } - e => Ok(e), - }, - &mut Ok, - )?; - } - } - // Any remaining aggregate functions and field references must be extracted as hidden - // columns. 
- let mut hidden = 0; - expr.transform_mut( - &mut |e| match &e { - ast::Expression::Function(f, a) if Aggregate::is(f) => { - if let ast::Expression::Column(c) = a[0] { - if Self::is_aggregate(&select[c].0) { - return errinput!("aggregate function cannot reference aggregate"); - } - } - select.push((e, None)); - hidden += 1; - Ok(ast::Expression::Column(select.len() - 1)) - } - ast::Expression::Field(_, _) => { - select.push((e, None)); - hidden += 1; - Ok(ast::Expression::Column(select.len() - 1)) - } - _ => Ok(e), - }, - &mut Ok, - )?; - Ok(hidden) - } - /// Builds an expression from an AST expression. pub fn build_expression(expr: ast::Expression, scope: &Scope) -> Result { use Expression::*; + // Look up aggregate functions or GROUP BY expressions. These were added + // to the parent scope when building the Aggregate node, if any. + if let Some(index) = scope.lookup_aggregate(&expr) { + return Ok(Field(index, scope.get_column_label(index)?)); + } + // Helper for building a boxed expression. let build = |expr: Box| -> Result> { Ok(Box::new(Self::build_expression(*expr, scope)?)) @@ -578,13 +529,11 @@ impl<'a, C: Catalog> Planner<'a, C> { ast::Literal::Float(f) => Value::Float(f), ast::Literal::String(s) => Value::String(s), }), - ast::Expression::Column(i) => Field(i, scope.get_column_label(i)?), ast::Expression::Field(table, name) => Field( scope.lookup_column(table.as_deref(), &name)?, Label::maybe_qualified(table, name), ), - // All functions are currently aggregate functions, which should be - // processed separately. + // Currently, all functions are aggregates, and processed above. // TODO: consider adding some basic functions for fun. 
ast::Expression::Function(name, _) => return errinput!("unknown function {name}"), ast::Expression::Operator(op) => match op { @@ -625,11 +574,6 @@ impl<'a, C: Catalog> Planner<'a, C> { fn evaluate_constant(expr: ast::Expression) -> Result { Self::build_expression(expr, &Scope::new())?.evaluate(None) } - - /// Checks whether a given expression is an aggregate expression. - fn is_aggregate(expr: &ast::Expression) -> bool { - expr.contains(&|e| matches!(e, ast::Expression::Function(f, _) if Aggregate::is(f))) - } } /// A scope maps column/table names to input column indexes for expressions, @@ -641,6 +585,9 @@ impl<'a, C: Catalog> Planner<'a, C> { /// currently visible and what names they have. During expression planning, the /// scope is used to resolve column names to column indexes, which are placed in /// the plan and used during execution. +/// +/// It also keeps track of output columns for aggregate functions and GROUP BY +/// expressions in Aggregate nodes. See aggregates field. pub struct Scope { /// The currently visible columns. If empty, only constant expressions can /// be used (no field references). @@ -653,6 +600,22 @@ pub struct Scope { /// Index of unqualified column names to column indexes. If a name points /// to multiple columns, lookups will fail with an ambiguous name error. unqualified: HashMap>, + /// Index of aggregate expressions to column indexes. This is used to track + /// output columns of Aggregate nodes, e.g. SUM(2 * a + b), and look them up + /// from expressions in downstream SELECT, HAVING, and ORDER BY columns, + /// e.g. SELECT SUM(2 * a + b) / COUNT(*) FROM table. When build_expression + /// encounters an aggregate function, it's mapped onto an aggregate column + /// index via this index. + /// + /// This is also used to map GROUP BY expressions to the corresponding + /// Aggregate node output column when evaluating downstream node expressions + /// in SELECT, HAVING, and ORDER BY. For trivial column references, e.g. 
+ /// GROUP BY a, b, the column can be accessed and looked up as normal via + /// lookup_column() in downstream node expressions, but for more complex + /// expressions like GROUP BY a * b / 2, the group column can be accessed + /// using the same expression in other nodes, e.g. GROUP BY a * b / 2 ORDER + /// BY a * b / 2. + aggregates: HashMap, } impl Scope { @@ -663,6 +626,7 @@ impl Scope { tables: HashSet::new(), qualified: HashMap::new(), unqualified: HashMap::new(), + aggregates: HashMap::new(), } } @@ -701,26 +665,30 @@ impl Scope { /// Looks up a column index by name, if possible. fn lookup_column(&self, table: Option<&str>, name: &str) -> Result { + let fmtname = || table.map(|table| format!("{table}.{name}")).unwrap_or(name.to_string()); if self.columns.is_empty() { - let field = table.map(|t| format!("{t}.{name}")).unwrap_or(name.to_string()); - return errinput!("expression must be constant, found field {field}"); + return errinput!("expression must be constant, found field {}", fmtname()); } if let Some(table) = table { if !self.tables.contains(table) { return errinput!("unknown table {table}"); } - self.qualified - .get(&(table.to_string(), name.to_string())) - .copied() - .ok_or(errinput!("unknown field {table}.{name}")) + if let Some(index) = self.qualified.get(&(table.to_string(), name.to_string())) { + return Ok(*index); + } } else if let Some(indexes) = self.unqualified.get(name) { if indexes.len() > 1 { return errinput!("ambiguous field {name}"); } - Ok(indexes[0]) - } else { - errinput!("unknown field {name}") + return Ok(indexes[0]); } + if !self.aggregates.is_empty() { + return errinput!( + "field {} must be used in an aggregate or GROUP BY expression", + fmtname() + ); + } + errinput!("unknown field {}", fmtname()) } /// Fetches a column label by index, if any. @@ -735,6 +703,31 @@ impl Scope { } } + /// Adds an aggregate expression to the scope, returning the column index. 
+ /// This is either an aggregate function or a GROUP BY expression (i.e. not + /// just a simple column name). It is used to access the aggregate output or + /// GROUP BY column in downstream nodes like SELECT, HAVING, and ORDER BY. + /// + /// If the expression already exists, the current index is returned. + fn add_aggregate(&mut self, expr: ast::Expression, label: Label) -> usize { + if let Some(index) = self.aggregates.get(&expr) { + return *index; + } + let index = self.add_column(label); + self.aggregates.insert(expr, index); + index + } + + /// Looks up an aggregate column index by aggregate function or GROUP BY + /// expression, if any. Trivial GROUP BY column names are accessed via + /// lookup_column() as normal. + /// + /// Unlike lookup_column(), this returns an option since the caller is + /// expected to fall back to normal expressions building. + fn lookup_aggregate(&self, expr: &ast::Expression) -> Option { + self.aggregates.get(expr).copied() + } + /// Number of columns currently in the scope. fn len(&self) -> usize { self.columns.len() diff --git a/src/sql/testscripts/queries/aggregate b/src/sql/testscripts/queries/aggregate index 9f3725b87..8e18ca2af 100644 --- a/src/sql/testscripts/queries/aggregate +++ b/src/sql/testscripts/queries/aggregate @@ -17,7 +17,7 @@ ok # COUNT(*) returns the row count. -# TODO: revisit the plan here. +# TODO: revisit the plan here. This can be eliminated by short-circuiting optimizer. [plan]> SELECT COUNT(*) FROM test --- Projection: #0 @@ -26,7 +26,7 @@ Projection: #0 6 # COUNT works on constant values. -# TODO: revisit the plan here. +# TODO: revisit the plan here. This can be eliminated by short-circuiting optimizer. [plan,header]> SELECT COUNT(NULL), COUNT(TRUE), COUNT(1), COUNT(3.14), COUNT(NAN), COUNT('') --- Projection: #0, #1, #2, #3, #4, #5 @@ -44,7 +44,7 @@ Projection: #0, #1, #2, #3 0, 0, 0, 0 # COUNT returns number of non-NULL values. -# TODO: revisit the plan here, the last projection is unnecessary. 
+# TODO: revisit the plan here. This can be eliminated by short-circuiting optimizer. [plan,header]> SELECT COUNT(id), COUNT("bool"), COUNT("float"), COUNT("string") FROM test --- Projection: #0, #1, #2, #3 @@ -283,14 +283,12 @@ Projection: #0, #1, #2, #3, #4 └─ Scan: test 6, 1, 1, 6, 1 -# Constant aggregates can be used with rows including values. -# TODO: this doesn't work with SELECT *. It also doesn't work with fields. -# It shouldn't, but the error message count be better. +# Constant aggregates can't be used with value rows. [plan]!> SELECT *, COUNT(1), MIN(1), MAX(1), SUM(1), AVG(1) FROM test [plan]!> SELECT id, COUNT(1), MIN(1), MAX(1), SUM(1), AVG(1) FROM test --- Error: invalid input: unexpected token , -Error: invalid input: unknown field id +Error: invalid input: field id must be used in an aggregate or GROUP BY expression # Aggregate can be expression, both inside and outside the aggregate. [plan]> SELECT SUM("int" * 10) / COUNT("int") + 7 FROM test WHERE "int" IS NOT NULL @@ -301,7 +299,6 @@ Projection: #0 / #1 + 7 117 # Aggregate functions can't be nested. -# TODO: improve the error message here. !> SELECT MIN(MAX("int")) FROM test --- -Error: invalid input: unknown function max +Error: invalid input: aggregate functions can't be nested diff --git a/src/sql/testscripts/queries/group_by b/src/sql/testscripts/queries/group_by index 6576c0388..9ede52c85 100644 --- a/src/sql/testscripts/queries/group_by +++ b/src/sql/testscripts/queries/group_by @@ -28,7 +28,7 @@ Projection: #0, #1, #2, #3, #4 └─ Aggregate: count(id), min(id), max(id), sum(id), avg(id) group by id └─ Nothing -# Simple GROUP BY. +# Simple GROUP BY, including NULL group. 
[plan]> SELECT "group", COUNT(*) FROM test GROUP BY "group" --- Projection: test.group, #0 @@ -39,39 +39,32 @@ a, 3 b, 3 [plan]> SELECT "group", COUNT(*), MIN("bool"), MAX("string"), SUM("int"), AVG("float") \ - FROM test WHERE id != 0 GROUP BY "group" + FROM test GROUP BY "group" --- Projection: test.group, #0, #1, #2, #3, #4 └─ Aggregate: count(TRUE), min(bool), max(string), sum(int), avg(float) group by group - └─ Scan: test (NOT id = 0) + └─ Scan: test +NULL, 1, NULL, NULL, NULL, NULL a, 3, FALSE, AB, 9, NaN b, 3, FALSE, 👋, 41, NaN -# GROUP BY works with a NULL group. -[plan]> SELECT "group", COUNT(*) FROM test GROUP BY "group" ---- -Projection: test.group, #0 -└─ Aggregate: count(TRUE) group by group - └─ Scan: test -NULL, 1 -a, 3 -b, 3 - # GROUP BY works on booleans. -[plan]> SELECT "bool", COUNT(*) FROM test WHERE id != 0 GROUP BY "bool" +[plan]> SELECT "bool", COUNT(*) FROM test GROUP BY "bool" --- Projection: test.bool, #0 └─ Aggregate: count(TRUE) group by bool - └─ Scan: test (NOT id = 0) + └─ Scan: test +NULL, 1 FALSE, 3 TRUE, 3 # GROUP BY works on integers. -[plan]> SELECT "int", COUNT(*) FROM test WHERE id != 0 GROUP BY "int" +[plan]> SELECT "int", COUNT(*) FROM test GROUP BY "int" --- Projection: test.int, #0 └─ Aggregate: count(TRUE) group by int - └─ Scan: test (NOT id = 0) + └─ Scan: test +NULL, 1 -1, 2 0, 1 3, 1 @@ -79,54 +72,53 @@ Projection: test.int, #0 42, 1 # GROUP BY works with floats, including a NAN group and -0.0 and 0.0 being equal. -[plan]> SELECT "float", COUNT(*) FROM test WHERE id != 0 GROUP BY "float" +[plan]> SELECT "float", COUNT(*) FROM test GROUP BY "float" --- Projection: test.float, #0 └─ Aggregate: count(TRUE) group by float - └─ Scan: test (NOT id = 0) + └─ Scan: test +NULL, 1 0, 2 3.14, 1 inf, 1 NaN, 2 # GROUP BY works on strings. 
-[plan]> SELECT "string", COUNT(*) FROM test WHERE id != 0 GROUP BY "string" +[plan]> SELECT "string", COUNT(*) FROM test GROUP BY "string" --- Projection: test.string, #0 └─ Aggregate: count(TRUE) group by string - └─ Scan: test (NOT id = 0) + └─ Scan: test +NULL, 1 , 2 AB, 1 abc, 2 👋, 1 # GROUP BY works even if the group column isn't in the result. -[plan]> SELECT COUNT(*) FROM test WHERE id != 0 GROUP BY "group" +[plan]> SELECT COUNT(*) FROM test GROUP BY "group" --- Projection: #0 └─ Aggregate: count(TRUE) group by group - └─ Scan: test (NOT id = 0) + └─ Scan: test +1 3 3 # GROUP BY works when there is no aggregate function. -[plan]> SELECT "group" FROM test WHERE id != 0 GROUP BY "group" +[plan]> SELECT "group" FROM test GROUP BY "group" --- Projection: test.group └─ Aggregate: group by group - └─ Scan: test (NOT id = 0) + └─ Scan: test +NULL a b -# GROUP BY works on aliases, in which case the original table is unknown. -[plan]> SELECT "group" AS g, COUNT(*) FROM test GROUP BY g +# GROUP BY does not work with SELECT aliases (also the case in e.g. SQL server). +!> SELECT "group" AS g, COUNT(*) FROM test GROUP BY g --- -Projection: g, #0 -└─ Aggregate: count(TRUE) group by group - └─ Scan: test -NULL, 1 -a, 3 -b, 3 +Error: invalid input: unknown field g [plan]> SELECT "group", COUNT(*) FROM test AS t GROUP BY t."group" --- @@ -142,14 +134,13 @@ b, 3 Error: invalid input: unknown table test # It errors when there is a non-group column. -# TODO: improve the error message. -!> SELECT "group", id FROM test WHERE id != 0 GROUP BY "group" +!> SELECT "group", id FROM test GROUP BY "group" --- -Error: invalid input: unknown field id +Error: invalid input: field id must be used in an aggregate or GROUP BY expression # It errors on unknown tables and columns. 
-!> SELECT COUNT(*) FROM test WHERE id != 0 GROUP BY unknown -!> SELECT COUNT(*) FROM test WHERE id != 0 GROUP BY unknown.id +!> SELECT COUNT(*) FROM test GROUP BY unknown +!> SELECT COUNT(*) FROM test GROUP BY unknown.id --- Error: invalid input: unknown field unknown Error: invalid input: unknown table unknown @@ -179,32 +170,29 @@ Projection: #1, #0 0, 4 1, 3 -# GROUP BY can use an aliased expression. -[plan]> SELECT id % 2 AS mod, COUNT(*) FROM test GROUP BY mod +# GROUP BY can't use an aliased expression. +!> SELECT id % 2 AS mod, COUNT(*) FROM test GROUP BY mod --- -Projection: mod, #0 -└─ Aggregate: count(TRUE) group by id % 2 - └─ Scan: test -0, 4 -1, 3 +Error: invalid input: unknown field mod # GROUP BY can't use aggregate functions. !> SELECT COUNT(*) FROM test GROUP BY MIN(id) --- -Error: invalid input: group expression cannot contain aggregates +Error: invalid input: unknown function min # GROUP BY works with multiple groups. -[plan]> SELECT "group", "bool", COUNT(*) FROM test WHERE id != 0 GROUP BY "group", "bool" +[plan]> SELECT "group", "bool", COUNT(*) FROM test GROUP BY "group", "bool" --- Projection: test.group, test.bool, #0 └─ Aggregate: count(TRUE) group by group, bool - └─ Scan: test (NOT id = 0) + └─ Scan: test +NULL, NULL, 1 a, FALSE, 1 a, TRUE, 2 b, FALSE, 2 b, TRUE, 1 -# GROUP BY work with joins. +# GROUP BY works with joins. [plan]> SELECT t.id % 2, COUNT(*) FROM test t JOIN other o ON t.id % 2 = o.id GROUP BY t.id % 2 --- Projection: #1, #0 diff --git a/src/sql/testscripts/queries/having b/src/sql/testscripts/queries/having index 82665b9d8..22a23aca3 100644 --- a/src/sql/testscripts/queries/having +++ b/src/sql/testscripts/queries/having @@ -32,21 +32,37 @@ Scan: test (id > 3) [plan]> SELECT id FROM test HAVING "int" > 3 --- Projection: #0 -└─ Filter: test.int > 3 +└─ Filter: int > 3 └─ Projection: id, int └─ Scan: test 4 5 -# Having works with an aggregate function. -# TODO: it's unnecessary do duplicate the aggregation here. 
-# TODO: test without a GROUP BY clause. +# Having works with an aggregate function, even if it's not in SELECT. [plan]> SELECT "group", MAX("int") FROM test GROUP BY "group" HAVING MAX("int") > 10 --- Projection: #0, #1 └─ Filter: #2 > 10 - └─ Projection: test.group, #0, #1 - └─ Aggregate: max(int), max(int) group by group + └─ Projection: test.group, #0, #0 + └─ Aggregate: max(int) group by group + └─ Scan: test +b, 42 + +[plan]> SELECT "group" FROM test GROUP BY "group" HAVING MAX("int") > 10 +--- +Projection: #0 +└─ Filter: #1 > 10 + └─ Projection: test.group, #0 + └─ Aggregate: max(int) group by group + └─ Scan: test +b + +[plan]> SELECT "group", MAX("int") FROM test GROUP BY "group" HAVING MAX("int") - MIN("int") > 10 +--- +Projection: #0, #1 +└─ Filter: #2 - #3 > 10 + └─ Projection: test.group, #0, #0, #1 + └─ Aggregate: max(int), min(int) group by group └─ Scan: test b, 42 @@ -74,16 +90,47 @@ b, 3 --- Projection: #0, #1 └─ Filter: #2 / #3 > 3 - └─ Projection: test.group, #0, #1, #2 - └─ Aggregate: count(TRUE), max(int), count(TRUE) group by group + └─ Projection: test.group, #0, #1, #0 + └─ Aggregate: count(TRUE), max(int) group by group └─ Scan: test b, 3 +# Having works with compound expressions using complex GROUP BY expressions +# that are not in the SELECT clause. +[plan]> SELECT COUNT(*) FROM test GROUP BY id % 2 HAVING 2 - id % 2 + 1 > 1 +--- +Projection: #0 +└─ Filter: 2 - #1 + 1 > 1 + └─ Projection: #0, #1 + └─ Aggregate: count(TRUE) group by id % 2 + └─ Scan: test +4 +3 + +# Having can use (un)qualified expressions for an (un)qualified GROUP BY. 
+[plan]> SELECT COUNT(*) FROM test GROUP BY "group" HAVING test."group" = 'a' +--- +Projection: #0 +└─ Filter: test.group = a + └─ Projection: #0, test.group + └─ Aggregate: count(TRUE) group by group + └─ Scan: test +3 + +[plan]> SELECT COUNT(*) FROM test GROUP BY test."group" HAVING "group" = 'a' +--- +Projection: #0 +└─ Filter: group = a + └─ Projection: #0, group + └─ Aggregate: count(TRUE) group by test.group + └─ Scan: test +3 + # Having errors on nested aggregate functions. # TODO: fix the error message. !> SELECT "group", COUNT(*) FROM test GROUP BY "group" HAVING MAX(MIN("int")) > 0 --- -Error: invalid input: unknown function min +Error: invalid input: aggregate functions can't be nested # Having errors on fields not in the SELECT or GROUP BY clauses. # TODO: improve the error message. diff --git a/src/sql/testscripts/queries/order b/src/sql/testscripts/queries/order index f68cdded4..7578bb084 100644 --- a/src/sql/testscripts/queries/order +++ b/src/sql/testscripts/queries/order @@ -470,11 +470,70 @@ Order: o.id desc, t.id asc 9, NULL, NULL, NULL, 👍, 1, 1, a # Order by aggregates, both when in SELECT and otherwise. -# TODO: fix these. 
-[plan]!> SELECT "bool", MAX("int") FROM test GROUP BY "bool" ORDER BY MAX("int") DESC +[plan]> SELECT "bool", MAX("int") FROM test GROUP BY "bool" ORDER BY MAX("int") DESC --- -Error: invalid input: unknown function max +Projection: #0, #1 +└─ Order: #2 desc + └─ Projection: test.bool, #0, #0 + └─ Aggregate: max(int) group by bool + └─ Scan: test +TRUE, 0 +FALSE, -1 +NULL, NULL + +[plan]> SELECT "bool" FROM test GROUP BY "bool" ORDER BY MAX("int") DESC +--- +Projection: #0 +└─ Order: #1 desc + └─ Projection: test.bool, #0 + └─ Aggregate: max(int) group by bool + └─ Scan: test +TRUE +FALSE +NULL + +[plan]> SELECT "bool", MAX("int") FROM test GROUP BY "bool" ORDER BY MAX("int") - MIN("int") DESC +--- +Projection: #0, #1 +└─ Order: #2 - #3 desc + └─ Projection: test.bool, #0, #0, #1 + └─ Aggregate: max(int), min(int) group by bool + └─ Scan: test +FALSE, -1 +TRUE, 0 +NULL, NULL + +# ORDER BY works with compound expressions using complex GROUP BY expressions +# that are not in the SELECT clause. +[plan]> SELECT COUNT(*) FROM test GROUP BY id % 2 ORDER BY 2 - id % 2 + 1 > 1 +--- +Projection: #0 +└─ Order: 2 - #1 + 1 > 1 asc + └─ Projection: #0, #1 + └─ Aggregate: count(TRUE) group by id % 2 + └─ Scan: test +5 +5 -[plan]!> SELECT "bool" FROM test GROUP BY "bool" ORDER BY MAX("int") DESC +# ORDER BY can use (un)qualified expressions for an (un)qualified GROUP BY. +[plan]> SELECT COUNT(*) FROM test GROUP BY "bool" ORDER BY test."bool" --- -Error: invalid input: unknown function max +Projection: #0 +└─ Order: test.bool asc + └─ Projection: #0, test.bool + └─ Aggregate: count(TRUE) group by bool + └─ Scan: test +8 +1 +1 + +[plan]> SELECT COUNT(*) FROM test GROUP BY test."bool" ORDER BY "bool" +--- +Projection: #0 +└─ Order: bool asc + └─ Projection: #0, bool + └─ Aggregate: count(TRUE) group by test.bool + └─ Scan: test +8 +1 +1