From 1d8eed214b345072bc85f2e25abd07bf9916050d Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Sun, 18 Aug 2024 14:43:00 +0300 Subject: [PATCH] fix: support min/max for Float16 type --- datafusion-cli/Cargo.lock | 1 + datafusion/functions-aggregate/Cargo.toml | 1 + datafusion/functions-aggregate/src/min_max.rs | 34 ++++++++++++++----- .../sqllogictest/test_files/aggregate.slt | 28 +++++++++++++++ 4 files changed, 55 insertions(+), 9 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 52e4a000355d7..b5637f785fb2d 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1340,6 +1340,7 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-physical-expr", "datafusion-physical-expr-common", + "half", "log", "paste", "sqlparser", diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 636b2e42d236c..d78f68a2604e7 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -47,6 +47,7 @@ datafusion-expr = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } +half = { workspace = true } log = { workspace = true } paste = "1.0.14" sqlparser = { workspace = true } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 4dcd5ac0e9515..961e8639604c8 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -34,18 +34,19 @@ use arrow::array::{ ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, - IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray, - StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, + LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::compute; use arrow::datatypes::{ - DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, - Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type, + Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, }; use arrow_schema::IntervalUnit; use datafusion_common::{ @@ -66,6 +67,7 @@ use datafusion_expr::GroupsAccumulator; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, }; +use half::f16; use std::ops::Deref; fn get_min_max_result_type(input_types: &[DataType]) -> Result> { @@ -181,6 +183,7 @@ impl AggregateUDFImpl for Max { | UInt16 | UInt32 | UInt64 + | Float16 | Float32 | Float64 | Decimal128(_, _) @@ -209,6 +212,9 @@ impl AggregateUDFImpl for Max { UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type), UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type), UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type), + Float16 => { + instantiate_max_accumulator!(data_type, f16, Float16Type) + } Float32 => { instantiate_max_accumulator!(data_type, f32, Float32Type) } @@ -339,6 +345,9 @@ macro_rules! min_max_batch { DataType::Float32 => { typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) } + DataType::Float16 => { + typed_min_max_batch!($VALUES, Float16Array, Float16, $OP) + } DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), @@ -623,6 +632,9 @@ macro_rules! min_max { (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { typed_min_max_float!(lhs, rhs, Float32, $OP) } + (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { + typed_min_max_float!(lhs, rhs, Float16, $OP) + } (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { typed_min_max!(lhs, rhs, UInt64, $OP) } @@ -950,6 +962,7 @@ impl AggregateUDFImpl for Min { | UInt16 | UInt32 | UInt64 + | Float16 | Float32 | Float64 | Decimal128(_, _) @@ -978,6 +991,9 @@ impl AggregateUDFImpl for Min { UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type), UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type), UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type), + Float16 => { + instantiate_min_accumulator!(data_type, f16, Float16Type) + } Float32 => { instantiate_min_accumulator!(data_type, f32, Float32Type) } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 0cda24d6ff5e4..442d3d835be1c 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5643,3 +5643,31 @@ query I??III?T select count(null), min(null), max(null), bit_and(NULL), bit_or(NULL), bit_xor(NULL), nth_value(NULL, 1), string_agg(NULL, ','); ---- 0 NULL NULL NULL NULL NULL NULL NULL + +# test group min/max Float16 without group expression +query RRTT +WITH data AS ( + SELECT arrow_cast(1, 'Float16') AS f + UNION ALL + SELECT arrow_cast(6, 'Float16') AS f +) +SELECT MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f)) FROM data; +---- +1 6 Float16 Float16 + +# test group min/max Float16 with group expression +query IRRTT +WITH data AS ( + SELECT 1 as k, arrow_cast(1.8125, 'Float16') AS f + UNION ALL + SELECT 1 as k, arrow_cast(6.8007813, 'Float16') AS f + UNION ALL + SELECT 2 AS k, arrow_cast(8.5, 'Float16') AS f +) +SELECT k, MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f)) +FROM data +GROUP BY k +ORDER BY k; +---- +1 1.8125 6.8007813 Float16 Float16 +2 8.5 8.5 Float16 Float16