Skip to content

Commit

Permalink
sql: ignore NULLs in aggregate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
erikgrinaker committed Jul 14, 2024
1 parent 68f0734 commit f408a41
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 67 deletions.
32 changes: 15 additions & 17 deletions src/sql/execution/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,40 +107,38 @@ impl Accumulator {
}

/// Adds a value to the accumulator.
/// TODO: NULL values should possibly be ignored, not yield NULL (see Postgres?).
fn add(&mut self, value: Value) -> Result<()> {
use std::cmp::Ordering;
match (self, value) {
(Self::Average { sum: Value::Null, count: _ }, _) => {}
(Self::Average { sum, count: _ }, Value::Null) => *sum = Value::Null,
(Self::Average { sum, count }, value) => {

// NULL values are ignored in aggregates.
if value == Value::Null {
return Ok(());
}

match self {
Self::Average { sum, count } => {
*sum = sum.checked_add(&value)?;
*count += 1;
}

(Self::Count(_), Value::Null) => {}
(Self::Count(c), _) => *c += 1,
Self::Count(c) => *c += 1,

(Self::Max(Some(Value::Null)), _) => {}
(Self::Max(max), Value::Null) => *max = Some(Value::Null),
(Self::Max(max @ None), value) => *max = Some(value),
(Self::Max(Some(max)), value) => {
Self::Max(max @ None) => *max = Some(value),
Self::Max(Some(max)) => {
if value.cmp(max) == Ordering::Greater {
*max = value
}
}

(Self::Min(Some(Value::Null)), _) => {}
(Self::Min(min), Value::Null) => *min = Some(Value::Null),
(Self::Min(min @ None), value) => *min = Some(value),
(Self::Min(Some(min)), value) => {
Self::Min(min @ None) => *min = Some(value),
Self::Min(Some(min)) => {
if value.cmp(min) == Ordering::Less {
*min = value
}
}

(Self::Sum(sum @ None), value) => *sum = Some(Value::Integer(0).checked_add(&value)?),
(Self::Sum(Some(sum)), value) => *sum = sum.checked_add(&value)?,
Self::Sum(sum @ None) => *sum = Some(Value::Integer(0).checked_add(&value)?),
Self::Sum(Some(sum)) => *sum = sum.checked_add(&value)?,
}
Ok(())
}
Expand Down
52 changes: 5 additions & 47 deletions src/sql/testscripts/queries/aggregate
Original file line number Diff line number Diff line change
Expand Up @@ -78,29 +78,21 @@ Projection: #0
5

> SELECT MAX("bool") FROM test
> SELECT MAX("bool") FROM test WHERE "bool" IS NOT NULL
---
NULL
TRUE

> SELECT MAX("int") FROM test
> SELECT MAX("int") FROM test WHERE "int" IS NOT NULL
---
NULL
42

> SELECT MAX("float") FROM test
> SELECT MAX("float") FROM test WHERE "float" IS NOT NULL
> SELECT MAX("float") FROM test WHERE "float" IS NOT NAN AND "float" IS NOT NULL
> SELECT MAX("float") FROM test WHERE "float" IS NOT NAN
---
NULL
NaN
inf

> SELECT MAX("string") FROM test
> SELECT MAX("string") FROM test WHERE "string" IS NOT NULL
---
NULL
👋

# MIN works on constant values.
Expand Down Expand Up @@ -128,27 +120,19 @@ Projection: #0
0

> SELECT MIN("bool") FROM test
> SELECT MIN("bool") FROM test WHERE "bool" IS NOT NULL
---
NULL
FALSE

> SELECT MIN("int") FROM test
> SELECT MIN("int") FROM test WHERE "int" IS NOT NULL
---
NULL
-1

> SELECT MIN("float") FROM test
> SELECT MIN("float") FROM test WHERE "float" IS NOT NULL
---
NULL
0

> SELECT MIN("string") FROM test
> SELECT MIN("string") FROM test WHERE "string" IS NOT NULL
---
> NULL
>

# SUM works on constant values, but only numbers.
Expand Down Expand Up @@ -183,34 +167,22 @@ Projection: #0
15

!> SELECT SUM("bool") FROM test
!> SELECT SUM("bool") FROM test WHERE "bool" IS NOT NULL
> SELECT SUM("bool") FROM test WHERE "bool" IS NULL
---
Error: invalid input: can't add NULL and TRUE
Error: invalid input: can't add 0 and TRUE
NULL

> SELECT SUM("int") FROM test
> SELECT SUM("int") FROM test WHERE "int" IS NOT NULL
---
NULL
44

> SELECT SUM("float") FROM test
> SELECT SUM("float") FROM test WHERE "float" IS NOT NULL
> SELECT SUM("float") FROM test WHERE "float" IS NOT NAN AND "float" IS NOT NULL
> SELECT SUM("float") FROM test WHERE "float" IS NOT NAN
---
NULL
NaN
inf

!> SELECT SUM("string") FROM test
!> SELECT SUM("string") FROM test WHERE "string" IS NOT NULL
> SELECT SUM("string") FROM test WHERE "string" IS NULL
---
Error: invalid input: can't add NULL and
Error: invalid input: can't add 0 and
NULL

# AVG works on constant values, but only numbers.
[plan]> SELECT AVG(NULL), AVG(1), AVG(3.14), AVG(NAN) FROM test
Expand Down Expand Up @@ -243,37 +215,23 @@ Projection: #0
└─ Scan: test
2

# TODO: the first case here should error.
> SELECT AVG("bool") FROM test
!> SELECT AVG("bool") FROM test WHERE "bool" IS NOT NULL
> SELECT AVG("bool") FROM test WHERE "bool" IS NULL
!> SELECT AVG("bool") FROM test
---
NULL
Error: invalid input: can't add 0 and TRUE
NULL

> SELECT AVG("int") FROM test
> SELECT AVG("int") FROM test WHERE "int" IS NOT NULL
---
NULL
11

> SELECT AVG("float") FROM test
> SELECT AVG("float") FROM test WHERE "float" IS NOT NULL
> SELECT AVG("float") FROM test WHERE "float" IS NOT NAN AND "float" IS NOT NULL
> SELECT AVG("float") FROM test WHERE "float" IS NOT NAN
---
NULL
NaN
inf

# TODO: the first case here should error.
> SELECT AVG("string") FROM test
!> SELECT AVG("string") FROM test WHERE "string" IS NOT NULL
> SELECT AVG("string") FROM test WHERE "string" IS NULL
!> SELECT AVG("string") FROM test
---
NULL
Error: invalid input: can't add 0 and
NULL

# Constant aggregates can be used with rows.
[plan]> SELECT COUNT(1), MIN(1), MAX(1), SUM(1), AVG(1) FROM test
Expand Down
6 changes: 3 additions & 3 deletions src/sql/testscripts/queries/order
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,9 @@ Projection: #0, #1
└─ Projection: bool, #0, #0
└─ Aggregate: max(int) group by bool
└─ Scan: test
NULL, 1000
TRUE, 0
FALSE, -1
NULL, NULL

[plan]> SELECT "bool" FROM test GROUP BY "bool" ORDER BY MAX("int") DESC
---
Expand All @@ -488,9 +488,9 @@ Projection: #0
└─ Projection: bool, #0
└─ Aggregate: max(int) group by bool
└─ Scan: test
NULL
TRUE
FALSE
NULL

[plan]> SELECT "bool", MAX("int") FROM test GROUP BY "bool" ORDER BY MAX("int") - MIN("int") DESC
---
Expand All @@ -499,9 +499,9 @@ Projection: #0, #1
└─ Projection: bool, #0, #0, #1
└─ Aggregate: max(int), min(int) group by bool
└─ Scan: test
NULL, 1000
FALSE, -1
TRUE, 0
NULL, NULL

# ORDER BY works with compound expressions using complex GROUP BY expressions
# that are not on the SELECT clause.
Expand Down

0 comments on commit f408a41

Please sign in to comment.