Skip to content

Commit

Permalink
ScalarUDF: Remove supports_zero_argument and avoid creating null ar…
Browse files Browse the repository at this point in the history
…ray for empty args (apache#10193)

* Avoid create null array for empty args

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

* return scalar instead of array

Signed-off-by: jayzhan211 <[email protected]>

* remove supports 0 args in scalarudf

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* rm test1

Signed-off-by: jayzhan211 <[email protected]>

* invoke no args and support randomness

Signed-off-by: jayzhan211 <[email protected]>

* rm randomness

Signed-off-by: jayzhan211 <[email protected]>

* add func with no args

Signed-off-by: jayzhan211 <[email protected]>

* array

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Apr 26, 2024
1 parent 6f0c693 commit c9bd291
Show file tree
Hide file tree
Showing 13 changed files with 130 additions and 232 deletions.
4 changes: 0 additions & 4 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1379,7 +1379,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1448,7 +1447,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 3))),
Expand Down Expand Up @@ -1520,7 +1518,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1589,7 +1586,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d_new", 3))),
Expand Down
202 changes: 66 additions & 136 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
// under the License.

use arrow::compute::kernels::numeric::add;
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
};
use arrow_schema::DataType::Float64;
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
Expand All @@ -36,9 +33,7 @@ use datafusion_expr::{
Accumulator, ColumnarValue, CreateFunction, ExprSchemable, LogicalPlanBuilder,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
use std::sync::Arc;

/// test that casting happens on udfs.
Expand Down Expand Up @@ -168,6 +163,48 @@ async fn scalar_udf() -> Result<()> {
Ok(())
}

struct Simple0ArgsScalarUDF {
name: String,
signature: Signature,
return_type: DataType,
}

impl std::fmt::Debug for Simple0ArgsScalarUDF {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("ScalarUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
}

impl ScalarUDFImpl for Simple0ArgsScalarUDF {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(self.return_type.clone())
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!("{} function does not accept arguments", self.name())
}

fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
}
}

#[tokio::test]
async fn scalar_udf_zero_params() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand All @@ -179,20 +216,14 @@ async fn scalar_udf_zero_params() -> Result<()> {
let ctx = SessionContext::new();

ctx.register_batch("t", batch)?;
// create function just returns 100 regardless of inp
let myfunc = Arc::new(|_args: &[ColumnarValue]| {
Ok(ColumnarValue::Array(
Arc::new((0..1).map(|_| 100).collect::<Int32Array>()) as ArrayRef,
))
});

ctx.register_udf(create_udf(
"get_100",
vec![],
Arc::new(DataType::Int32),
Volatility::Immutable,
myfunc,
));
let get_100_udf = Simple0ArgsScalarUDF {
name: "get_100".to_string(),
signature: Signature::exact(vec![], Volatility::Immutable),
return_type: DataType::Int32,
};

ctx.register_udf(ScalarUDF::from(get_100_udf));

let result = plan_and_collect(&ctx, "select get_100() a from t").await?;
let expected = [
Expand Down Expand Up @@ -403,123 +434,6 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
Ok(())
}

#[derive(Debug)]
pub struct RandomUDF {
signature: Signature,
}

impl RandomUDF {
pub fn new() -> Self {
Self {
signature: Signature::any(0, Volatility::Volatile),
}
}
}

impl ScalarUDFImpl for RandomUDF {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"random_udf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Float64)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let len: usize = match &args[0] {
// This udf is always invoked with zero argument so its argument
// is a null array indicating the batch size.
ColumnarValue::Array(array) if array.data_type().is_null() => array.len(),
_ => {
return Err(datafusion::error::DataFusionError::Internal(
"Invalid argument type".to_string(),
))
}
};
let mut rng = thread_rng();
let values = iter::repeat_with(|| rng.gen_range(0.1..1.0)).take(len);
let array = Float64Array::from_iter_values(values);
Ok(ColumnarValue::Array(Arc::new(array)))
}
}

/// Ensure that a user defined function with zero argument will be invoked
/// with a null array indicating the batch size.
#[tokio::test]
async fn test_user_defined_functions_zero_argument() -> Result<()> {
let ctx = SessionContext::new();

let schema = Arc::new(Schema::new(vec![Field::new(
"index",
DataType::UInt8,
false,
)]));

let batch = RecordBatch::try_new(
schema,
vec![Arc::new(UInt8Array::from_iter_values([1, 2, 3]))],
)?;

ctx.register_batch("data_table", batch)?;

let random_normal_udf = ScalarUDF::from(RandomUDF::new());
ctx.register_udf(random_normal_udf);

let result = plan_and_collect(
&ctx,
"SELECT random_udf() AS random_udf, random() AS native_random FROM data_table",
)
.await?;

assert_eq!(result.len(), 1);
let batch = &result[0];
let random_udf = batch
.column(0)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let native_random = batch
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();

assert_eq!(random_udf.len(), native_random.len());

let mut previous = -1.0;
for i in 0..random_udf.len() {
assert!(random_udf.value(i) >= 0.0 && random_udf.value(i) < 1.0);
assert!(random_udf.value(i) != previous);
previous = random_udf.value(i);
}

Ok(())
}

#[tokio::test]
async fn deregister_udf() -> Result<()> {
let random_normal_udf = ScalarUDF::from(RandomUDF::new());
let ctx = SessionContext::new();

ctx.register_udf(random_normal_udf.clone());

assert!(ctx.udfs().contains("random_udf"));

ctx.deregister_udf("random_udf");

assert!(!ctx.udfs().contains("random_udf"));

Ok(())
}

#[derive(Debug)]
struct CastToI64UDF {
signature: Signature,
Expand Down Expand Up @@ -615,6 +529,22 @@ async fn test_user_defined_functions_cast_to_i64() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn deregister_udf() -> Result<()> {
let cast2i64 = ScalarUDF::from(CastToI64UDF::new());
let ctx = SessionContext::new();

ctx.register_udf(cast2i64.clone());

assert!(ctx.udfs().contains("cast_to_i64"));

ctx.deregister_udf("cast_to_i64");

assert!(!ctx.udfs().contains("cast_to_i64"));

Ok(())
}

#[derive(Debug)]
struct TakeUDF {
signature: Signature,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub fn data_types(
);
}
}

let valid_types = get_valid_types(&signature.type_signature, current_types)?;

if valid_types
Expand Down
29 changes: 23 additions & 6 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::{
ScalarFunctionImplementation, Signature,
};
use arrow::datatypes::DataType;
use datafusion_common::{ExprSchema, Result};
use datafusion_common::{not_impl_err, ExprSchema, Result};
use std::any::Any;
use std::fmt;
use std::fmt::Debug;
Expand Down Expand Up @@ -180,6 +180,13 @@ impl ScalarUDF {
self.inner.invoke(args)
}

/// Invoke the function without `args` but number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
pub fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
self.inner.invoke_no_args(number_rows)
}

/// Returns a `ScalarFunctionImplementation` that can invoke the function
/// during execution
pub fn fun(&self) -> ScalarFunctionImplementation {
Expand Down Expand Up @@ -322,10 +329,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// The function will be invoked passed with the slice of [`ColumnarValue`]
/// (either scalar or array).
///
/// # Zero Argument Functions
/// If the function has zero parameters (e.g. `now()`) it will be passed a
/// single element slice which is a a null array to indicate the batch's row
/// count (so the function can know the resulting array size).
/// If the function does not take any arguments, please use [invoke_no_args]
/// instead and return [not_impl_err] for this function.
///
///
/// # Performance
///
Expand All @@ -335,7 +341,18 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue>;
///
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue>;

/// Invoke the function without `args`, instead the number of rows are provided,
/// returning the appropriate result.
fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke_no_args but called",
self.name()
)
}

/// Returns any aliases (alternate names) for this function.
///
Expand Down
4 changes: 4 additions & 0 deletions datafusion/functions-array/src/make_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ impl ScalarUDFImpl for MakeArray {
make_scalar_function(make_array_inner)(args)
}

fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
make_scalar_function(make_array_inner)(&[])
}

fn aliases(&self) -> &[String] {
&self.aliases
}
Expand Down
18 changes: 9 additions & 9 deletions datafusion/functions/src/math/pi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@
// under the License.

use std::any::Any;
use std::sync::Arc;

use arrow::array::Float64Array;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Float64;

use datafusion_common::{exec_err, Result};
use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, FuncMonotonicity, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};

Expand Down Expand Up @@ -62,12 +60,14 @@ impl ScalarUDFImpl for PiFunc {
Ok(Float64)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
if !matches!(&args[0], ColumnarValue::Array(_)) {
return exec_err!("Expect pi function to take no param");
}
let array = Float64Array::from_value(std::f64::consts::PI, 1);
Ok(ColumnarValue::Array(Arc::new(array)))
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!("{} function does not accept arguments", self.name())
}

fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(
std::f64::consts::PI,
))))
}

fn monotonicity(&self) -> Result<Option<FuncMonotonicity>> {
Expand Down
Loading

0 comments on commit c9bd291

Please sign in to comment.