Skip to content

Commit

Permalink
perf: Eliminate duplicate aggregations (#132)
Browse files Browse the repository at this point in the history
* perf: Eliminate duplicate aggregations

* code fmt
  • Loading branch information
KKould authored Feb 7, 2024
1 parent b11a8c9 commit bb9ccef
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/binder/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
for column_name in column_names.iter().map(|ident| ident.value.to_lowercase()) {
if let Some(column) = columns
.iter_mut()
.find(|column| column.name() == column_name.to_string())
.find(|column| column.name() == column_name)
{
if *is_primary {
column.desc.is_primary = true;
Expand Down
2 changes: 1 addition & 1 deletion src/binder/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
.find(|(_, column)| column.desc.is_primary)
.map(|(_, column)| Arc::clone(column))
.unwrap();
let mut plan = ScanOperator::build(table_name.clone(), &table_catalog);
let mut plan = ScanOperator::build(table_name.clone(), table_catalog);

if let Some(alias) = alias {
self.context
Expand Down
4 changes: 2 additions & 2 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
table_name: TableName,
) -> Result<(), DatabaseError> {
if !alias_column.is_empty() {
let aliases = alias_column.into_iter().map(lower_ident).collect_vec();
let aliases = alias_column.iter().map(lower_ident).collect_vec();
let table = self
.context
.table(table_name.clone())
Expand Down Expand Up @@ -222,7 +222,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
let table_name = Arc::new(table.to_string());

let table_catalog = self.context.table_and_bind(table_name.clone(), join_type)?;
let scan_op = ScanOperator::build(table_name.clone(), &table_catalog);
let scan_op = ScanOperator::build(table_name.clone(), table_catalog);

if let Some(TableAlias { name, columns }) = alias {
self.register_alias(columns, name.value.to_lowercase(), table_name.clone())?;
Expand Down
1 change: 1 addition & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ impl<S: Storage> Database<S> {
HepBatchStrategy::fix_point_topdown(10),
vec![
NormalizationRuleImpl::CollapseProject,
NormalizationRuleImpl::CollapseGroupByAgg,
NormalizationRuleImpl::CombineFilter,
],
)
Expand Down
79 changes: 79 additions & 0 deletions src/optimizer/rule/normalization/combine_operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::optimizer::rule::normalization::is_subset_exprs;
use crate::planner::operator::Operator;
use crate::types::LogicalType;
use lazy_static::lazy_static;
use std::collections::HashSet;

lazy_static! {
static ref COLLAPSE_PROJECT_RULE: Pattern = {
Expand All @@ -27,6 +28,21 @@ lazy_static! {
}]),
}
};
static ref COLLAPSE_GROUP_BY_AGG: Pattern = {
Pattern {
predicate: |op| match op {
Operator::Aggregate(agg_op) => !agg_op.groupby_exprs.is_empty(),
_ => false,
},
children: PatternChildrenPredicate::Predicate(vec![Pattern {
predicate: |op| match op {
Operator::Aggregate(agg_op) => !agg_op.groupby_exprs.is_empty(),
_ => false,
},
children: PatternChildrenPredicate::None,
}]),
}
};
}

/// Combine two adjacent project operators into one.
Expand Down Expand Up @@ -87,6 +103,47 @@ impl NormalizationRule for CombineFilter {
}
}

/// Normalization rule that removes a redundant `GROUP BY`-only aggregation
/// stacked directly on top of another aggregation (produced e.g. by
/// `SELECT DISTINCT ... GROUP BY ...`, where DISTINCT binds as a second
/// group-by over the same keys).
pub struct CollapseGroupByAgg;

impl MatchPattern for CollapseGroupByAgg {
    /// Matches an Aggregate node with a non-empty GROUP BY whose eldest child
    /// is also such an Aggregate (see `COLLAPSE_GROUP_BY_AGG` pattern above).
    fn pattern(&self) -> &Pattern {
        &COLLAPSE_GROUP_BY_AGG
    }
}

impl NormalizationRule for CollapseGroupByAgg {
    /// Removes a `GROUP BY`-only aggregation whose child is another aggregation
    /// grouping on the same set of expressions: the child already emits one row
    /// per group, so the parent aggregation is a no-op and can be dropped.
    fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> {
        if let Operator::Aggregate(op) = graph.operator(node_id).clone() {
            // A node with agg_calls computes new values (SUM, COUNT, ...) and is
            // never redundant; only a pure GROUP BY may be collapsed.
            if !op.agg_calls.is_empty() {
                return Ok(());
            }

            // The child is only inspected, never mutated, so an immutable
            // borrow suffices (was `operator_mut` + `and_then(.. Some(..))`).
            if let Some(Operator::Aggregate(child_op)) = graph
                .eldest_child_at(node_id)
                .map(|child_id| graph.operator(child_id))
            {
                // Compare the two key lists as sets: duplicates inside one list
                // do not change the produced groups, so set equality is the
                // exact condition under which the parent is redundant. (The old
                // insert-then-remove scheme could wrongly collapse when a list
                // contained duplicate expressions.)
                let parent_keys: HashSet<_> = op.groupby_exprs.iter().collect();
                let child_keys: HashSet<_> = child_op.groupby_exprs.iter().collect();

                if parent_keys == child_keys {
                    graph.remove_node(node_id, false);
                }
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
use crate::binder::test::select_sql_run;
Expand Down Expand Up @@ -181,4 +238,26 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_collapse_group_by_agg() -> Result<(), DatabaseError> {
    // DISTINCT over the same keys as the GROUP BY binds as two stacked
    // aggregations; the rule must merge them into a single one.
    let plan = select_sql_run("select distinct c1, c2 from t1 group by c1, c2").await?;

    let optimizer = HepOptimizer::new(plan.clone()).batch(
        "test_collapse_group_by_agg".to_string(),
        HepBatchStrategy::once_topdown(),
        vec![NormalizationRuleImpl::CollapseGroupByAgg],
    );

    let best_plan = optimizer.find_best::<KipTransaction>(None)?;

    // Expect exactly one Aggregate below the root: the top child must still be
    // an aggregation, but its own child must not be.
    let top = &best_plan.childrens[0];
    match &top.operator {
        Operator::Aggregate(_) => match &top.childrens[0].operator {
            Operator::Aggregate(_) => unreachable!("Should not be a agg operator"),
            _ => Ok(()),
        },
        _ => unreachable!("Should be a agg operator"),
    }
}
}
7 changes: 6 additions & 1 deletion src/optimizer/rule/normalization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use crate::optimizer::core::pattern::Pattern;
use crate::optimizer::core::rule::{MatchPattern, NormalizationRule};
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
use crate::optimizer::rule::normalization::column_pruning::ColumnPruning;
use crate::optimizer::rule::normalization::combine_operators::{CollapseProject, CombineFilter};
use crate::optimizer::rule::normalization::combine_operators::{
CollapseGroupByAgg, CollapseProject, CombineFilter,
};
use crate::optimizer::rule::normalization::pushdown_limit::{
EliminateLimits, LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin,
};
Expand All @@ -24,6 +26,7 @@ pub enum NormalizationRuleImpl {
ColumnPruning,
// Combine operators
CollapseProject,
CollapseGroupByAgg,
CombineFilter,
// PushDown limit
LimitProjectTranspose,
Expand All @@ -44,6 +47,7 @@ impl MatchPattern for NormalizationRuleImpl {
match self {
NormalizationRuleImpl::ColumnPruning => ColumnPruning.pattern(),
NormalizationRuleImpl::CollapseProject => CollapseProject.pattern(),
NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.pattern(),
NormalizationRuleImpl::CombineFilter => CombineFilter.pattern(),
NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(),
NormalizationRuleImpl::EliminateLimits => EliminateLimits.pattern(),
Expand All @@ -62,6 +66,7 @@ impl NormalizationRule for NormalizationRuleImpl {
match self {
NormalizationRuleImpl::ColumnPruning => ColumnPruning.apply(node_id, graph),
NormalizationRuleImpl::CollapseProject => CollapseProject.apply(node_id, graph),
NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(node_id, graph),
NormalizationRuleImpl::CombineFilter => CombineFilter.apply(node_id, graph),
NormalizationRuleImpl::LimitProjectTranspose => {
LimitProjectTranspose.apply(node_id, graph)
Expand Down
20 changes: 10 additions & 10 deletions tests/slt/group_by.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ statement ok
insert into t values (0,1,1), (1,2,1), (2,3,2), (3,4,2), (4,5,3)

# TODO: check on binder
# statement error
# select v2 + 1, v1 from t group by v2 + 1
statement error
select v2 + 1, v1 from t group by v2 + 1

# statement error
# select v2 + 1 as a, v1 as b from t group by a
statement error
select v2 + 1 as a, v1 as b from t group by a

# statement error
# select v2, v2 + 1, sum(v1) from t group by v2 + 1
statement error
select v2, v2 + 1, sum(v1) from t group by v2 + 1

# statement error
# select v2 + 2 + count(*) from t group by v2 + 1
statement error
select v2 + 2 + count(*) from t group by v2 + 1

# statement error
# select v2 + count(*) from t group by v2 order by v1;
statement error
select v2 + count(*) from t group by v2 order by v1;

query II rowsort
select v2 + 1, sum(v1) from t group by v2 + 1
Expand Down

0 comments on commit bb9ccef

Please sign in to comment.