From 4bed04e4e312a0b125306944aee94a93c2ff6c4f Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Thu, 11 Jul 2024 19:26:46 +0300 Subject: [PATCH] Add customizable equality and hash functions to UDFs (#11392) * Add customizable equality and hash functions to UDFs * Improve equals and hash_value documentation * Add tests for parameterized UDFs --- .../user_defined/user_defined_aggregates.rs | 79 ++++++++++- .../user_defined_scalar_functions.rs | 128 +++++++++++++++++- datafusion/expr/src/udaf.rs | 73 ++++++++-- datafusion/expr/src/udf.rs | 62 +++++++-- datafusion/expr/src/udwf.rs | 69 ++++++++-- 5 files changed, 367 insertions(+), 44 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index d591c662d877..96de865b6554 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -18,14 +18,19 @@ //! This module contains end to end demonstrations of creating //! user defined aggregate functions -use arrow::{array::AsArray, datatypes::Fields}; -use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::Schema; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; +use arrow::{array::AsArray, datatypes::Fields}; +use arrow_array::{ + types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray, +}; +use arrow_schema::Schema; + +use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; use datafusion::test_util::plan_and_collect; use datafusion::{ @@ -45,8 +50,8 @@ use datafusion::{ }; use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, - SimpleAggregateUDF, + col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, + LogicalPlanBuilder, SimpleAggregateUDF, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -377,6 +382,55 @@ async fn test_groups_accumulator() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_parameterized_aggregate_udf() -> Result<()> { + let batch = RecordBatch::try_from_iter([( + "text", + Arc::new(StringArray::from(vec!["foo"])) as ArrayRef, + )])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable); + let udf1 = AggregateUDF::from(TestGroupsAccumulator { + signature: signature.clone(), + result: 1, + }); + let udf2 = AggregateUDF::from(TestGroupsAccumulator { + signature: signature.clone(), + result: 2, + }); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .aggregate( + [col("text")], + [ + udf1.call(vec![col("text")]).alias("a"), + udf2.call(vec![col("text")]).alias("b"), + ], + )? + .build()?; + + assert_eq!( + format!("{plan:?}"), + "Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n TableScan: t projection=[text]" + ); + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + let expected = [ + "+------+---+---+", + "| text | a | b |", + "+------+---+---+", + "| foo | 1 | 2 |", + "+------+---+---+", + ]; + assert_batches_eq!(expected, &actual); + + ctx.deregister_table("t")?; + Ok(()) +} + /// Returns an context with a table "t" and the "first" and "time_sum" /// aggregate functions registered. /// @@ -735,6 +789,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator { ) -> Result> { Ok(Box::new(self.clone())) } + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.result == other.result && self.signature == other.signature + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.signature.hash(hasher); + self.result.hash(hasher); + hasher.finish() + } } impl Accumulator for TestGroupsAccumulator { diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 1733068debb9..5847952ae6a6 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -16,11 +16,20 @@ // under the License. use std::any::Any; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use arrow::compute::kernels::numeric::add; -use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch}; +use arrow_array::builder::BooleanBuilder; +use arrow_array::cast::AsArray; +use arrow_array::{ + Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, +}; use arrow_schema::{DataType, Field, Schema}; +use parking_lot::Mutex; +use regex::Regex; +use sqlparser::ast::Ident; + use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; @@ -37,8 +46,6 @@ use datafusion_expr::{ Volatility, }; use datafusion_functions_array::range::range_udf; -use parking_lot::Mutex; -use sqlparser::ast::Ident; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and @@ -1021,6 +1028,121 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<( Ok(()) } +#[derive(Debug)] +struct MyRegexUdf { + signature: Signature, + regex: Regex, +} + +impl MyRegexUdf { + fn new(pattern: &str) -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + regex: Regex::new(pattern).expect("regex"), + } + } + + fn matches(&self, value: Option<&str>) -> Option { + Some(self.regex.is_match(value?)) + } +} + +impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regex_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> Result { + if matches!(args, [DataType::Utf8]) { + Ok(DataType::Boolean) + } else { + plan_err!("regex_udf only accepts a Utf8 argument") + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args { + [ColumnarValue::Scalar(ScalarValue::Utf8(value))] => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean( + self.matches(value.as_deref()), + ))) + } + [ColumnarValue::Array(values)] => { + let mut builder = BooleanBuilder::with_capacity(values.len()); + for value in values.as_string::() { + builder.append_option(self.matches(value)) + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + _ => exec_err!("regex_udf only accepts a Utf8 arguments"), + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.regex.as_str() == other.regex.as_str() + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.regex.as_str().hash(hasher); + hasher.finish() + } +} + +#[tokio::test] +async fn test_parameterized_scalar_udf() -> Result<()> { + let batch = RecordBatch::try_from_iter([( + "text", + Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef, + )])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}")); + let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar")); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .filter( + foo_udf + .call(vec![col("text")]) + .and(bar_udf.call(vec![col("text")])), + )? + .filter(col("text").is_not_null())? + .build()?; + + assert_eq!( + format!("{plan:?}"), + "Filter: t.text IS NOT NULL\n Filter: regex_udf(t.text) AND regex_udf(t.text)\n TableScan: t projection=[text]" + ); + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + let expected = [ + "+--------+", + "| text |", + "+--------+", + "| foobar |", + "| barfoo |", + "+--------+", + ]; + assert_batches_eq!(expected, &actual); + + ctx.deregister_table("t")?; + Ok(()) +} + fn create_udf_context() -> SessionContext { let ctx = SessionContext::new(); // register a custom UDF diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 7a054abea75b..1657e034fbe2 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,6 +17,17 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::sync::Arc; +use std::vec; + +use arrow::datatypes::{DataType, Field}; +use sqlparser::ast::NullTreatment; + +use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; + use crate::expr::AggregateFunction; use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, @@ -26,13 +37,6 @@ use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; use crate::{Accumulator, Expr}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; -use sqlparser::ast::NullTreatment; -use std::any::Any; -use std::fmt::{self, Debug, Formatter}; -use std::sync::Arc; -use std::vec; /// Logical representation of a user-defined [aggregate function] (UDAF). /// @@ -72,20 +76,19 @@ pub struct AggregateUDF { impl PartialEq for AggregateUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for AggregateUDF {} -impl std::hash::Hash for AggregateUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for AggregateUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } -impl std::fmt::Display for AggregateUDF { +impl fmt::Display for AggregateUDF { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "{}", self.name()) } @@ -280,7 +283,7 @@ where /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { /// signature: Signature -/// }; +/// } /// /// impl GeoMeanUdf { /// fn new() -> Self { @@ -507,6 +510,33 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } + + /// Return true if this aggregate UDF is equal to the other. + /// + /// Allows customizing the equality of aggregate UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this aggregate UDF. + /// + /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } pub enum ReversedUDAF { @@ -562,6 +592,21 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 68d3af6ace3c..1fbb3cc584b3 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -19,8 +19,13 @@ use std::any::Any; use std::fmt::{self, Debug, Formatter}; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; +use arrow::datatypes::DataType; + +use datafusion_common::{not_impl_err, ExprSchema, Result}; + use crate::expr::create_name; use crate::interval_arithmetic::Interval; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; @@ -29,9 +34,6 @@ use crate::{ ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature, }; -use arrow::datatypes::DataType; -use datafusion_common::{not_impl_err, ExprSchema, Result}; - /// Logical representation of a Scalar User Defined Function. /// /// A scalar function produces a single row output for each row of input. This @@ -59,16 +61,15 @@ pub struct ScalarUDF { impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for ScalarUDF {} -impl std::hash::Hash for ScalarUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for ScalarUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } @@ -294,7 +295,7 @@ where /// #[derive(Debug)] /// struct AddOne { /// signature: Signature -/// }; +/// } /// /// impl AddOne { /// fn new() -> Self { @@ -540,6 +541,33 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } + + /// Return true if this scalar UDF is equal to the other. + /// + /// Allows customizing the equality of scalar UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this scalar UDF. + /// + /// Allows customizing the hash code of scalar UDFs. Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -557,7 +585,6 @@ impl AliasedScalarUDFImpl { ) -> Self { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } } } @@ -586,6 +613,21 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } } /// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 70b44e5e307a..1a6b21e3dd29 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -17,18 +17,22 @@ //! [`WindowUDF`]: User Defined Window Functions -use crate::{ - function::WindowFunctionSimplification, Expr, PartitionEvaluator, - PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, -}; -use arrow::datatypes::DataType; -use datafusion_common::Result; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; +use arrow::datatypes::DataType; + +use datafusion_common::Result; + +use crate::{ + function::WindowFunctionSimplification, Expr, PartitionEvaluator, + PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, +}; + /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. /// @@ -62,16 +66,15 @@ impl Display for WindowUDF { impl PartialEq for WindowUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for WindowUDF {} -impl std::hash::Hash for WindowUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for WindowUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } @@ -212,7 +215,7 @@ where /// #[derive(Debug, Clone)] /// struct SmoothIt { /// signature: Signature -/// }; +/// } /// /// impl SmoothIt { /// fn new() -> Self { @@ -296,6 +299,33 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn simplify(&self) -> Option { None } + + /// Return true if this window UDF is equal to the other. + /// + /// Allows customizing the equality of window UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this window UDF. + /// + /// Allows customizing the hash code of window UDFs. Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } /// WindowUDF that adds an alias to the underlying function. It is better to @@ -342,6 +372,21 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } } /// Implementation of [`WindowUDFImpl`] that wraps the function style pointers