Skip to content

Commit

Permalink
fix: support min/max for Float16 type
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Aug 18, 2024
1 parent 950dc73 commit 1d8eed2
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 9 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ datafusion-expr = { workspace = true }
datafusion-functions-aggregate-common = { workspace = true }
datafusion-physical-expr = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
half = { workspace = true }
log = { workspace = true }
paste = "1.0.14"
sqlparser = { workspace = true }
Expand Down
34 changes: 25 additions & 9 deletions datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,19 @@

use arrow::array::{
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray,
StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
UInt64Array, UInt8Array,
Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray,
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::compute;
use arrow::datatypes::{
DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type,
Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type,
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type,
UInt8Type,
};
use arrow_schema::IntervalUnit;
use datafusion_common::{
Expand All @@ -66,6 +67,7 @@ use datafusion_expr::GroupsAccumulator;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use half::f16;
use std::ops::Deref;

fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
Expand Down Expand Up @@ -181,6 +183,7 @@ impl AggregateUDFImpl for Max {
| UInt16
| UInt32
| UInt64
| Float16
| Float32
| Float64
| Decimal128(_, _)
Expand Down Expand Up @@ -209,6 +212,9 @@ impl AggregateUDFImpl for Max {
UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type),
UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type),
UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type),
Float16 => {
instantiate_max_accumulator!(data_type, f16, Float16Type)
}
Float32 => {
instantiate_max_accumulator!(data_type, f32, Float32Type)
}
Expand Down Expand Up @@ -339,6 +345,9 @@ macro_rules! min_max_batch {
DataType::Float32 => {
typed_min_max_batch!($VALUES, Float32Array, Float32, $OP)
}
DataType::Float16 => {
typed_min_max_batch!($VALUES, Float16Array, Float16, $OP)
}
DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP),
DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP),
DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP),
Expand Down Expand Up @@ -623,6 +632,9 @@ macro_rules! min_max {
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
typed_min_max_float!(lhs, rhs, Float32, $OP)
}
(ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => {
typed_min_max_float!(lhs, rhs, Float16, $OP)
}
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
typed_min_max!(lhs, rhs, UInt64, $OP)
}
Expand Down Expand Up @@ -950,6 +962,7 @@ impl AggregateUDFImpl for Min {
| UInt16
| UInt32
| UInt64
| Float16
| Float32
| Float64
| Decimal128(_, _)
Expand Down Expand Up @@ -978,6 +991,9 @@ impl AggregateUDFImpl for Min {
UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type),
UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type),
UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type),
Float16 => {
instantiate_min_accumulator!(data_type, f16, Float16Type)
}
Float32 => {
instantiate_min_accumulator!(data_type, f32, Float32Type)
}
Expand Down
28 changes: 28 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5643,3 +5643,31 @@ query I??III?T
select count(null), min(null), max(null), bit_and(NULL), bit_or(NULL), bit_xor(NULL), nth_value(NULL, 1), string_agg(NULL, ',');
----
0 NULL NULL NULL NULL NULL NULL NULL

# test group min/max Float16 without group expression
query RRTT
WITH data AS (
SELECT arrow_cast(1, 'Float16') AS f
UNION ALL
SELECT arrow_cast(6, 'Float16') AS f
)
SELECT MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f)) FROM data;
----
1 6 Float16 Float16

# test group min/max Float16 with group expression
query IRRTT
WITH data AS (
SELECT 1 as k, arrow_cast(1.8125, 'Float16') AS f
UNION ALL
SELECT 1 as k, arrow_cast(6.8007813, 'Float16') AS f
UNION ALL
SELECT 2 AS k, arrow_cast(8.5, 'Float16') AS f
)
SELECT k, MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f))
FROM data
GROUP BY k
ORDER BY k;
----
1 1.8125 6.8007813 Float16 Float16
2 8.5 8.5 Float16 Float16

0 comments on commit 1d8eed2

Please sign in to comment.