From 8730466c9b84c7ecb97a240898d25256184083bb Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Tue, 16 Apr 2024 09:53:01 -0400 Subject: [PATCH] Move concat, concat_ws, ends_with, initcap to datafusion-functions (#10089) --- datafusion/core/src/execution/context/mod.rs | 5 +- .../tests/dataframe/dataframe_functions.rs | 2 +- .../core/tests/optimizer_integration.rs | 28 +- datafusion/core/tests/simplification.rs | 99 +++- datafusion/expr/src/built_in_function.rs | 90 +--- datafusion/expr/src/expr_fn.rs | 78 +-- datafusion/functions/Cargo.toml | 5 + .../benches/concat.rs | 4 +- datafusion/functions/src/string/common.rs | 95 +++- datafusion/functions/src/string/concat.rs | 262 +++++++++ datafusion/functions/src/string/concat_ws.rs | 423 +++++++++++++++ datafusion/functions/src/string/ends_with.rs | 161 ++++++ datafusion/functions/src/string/initcap.rs | 159 ++++++ datafusion/functions/src/string/mod.rs | 34 ++ .../optimizer/src/analyzer/type_coercion.rs | 58 +- .../simplify_expressions/expr_simplifier.rs | 143 +---- .../src/simplify_expressions/utils.rs | 123 +---- .../optimizer/tests/optimizer_integration.rs | 26 - datafusion/physical-expr/Cargo.toml | 4 - datafusion/physical-expr/src/functions.rs | 312 +++-------- datafusion/physical-expr/src/lib.rs | 1 - datafusion/physical-expr/src/planner.rs | 2 +- .../physical-expr/src/string_expressions.rs | 495 ------------------ datafusion/physical-expr/src/utils/mod.rs | 2 +- datafusion/proto/proto/datafusion.proto | 8 +- datafusion/proto/src/generated/pbjson.rs | 12 - datafusion/proto/src/generated/prost.rs | 20 +- .../proto/src/logical_plan/from_proto.rs | 20 +- datafusion/proto/src/logical_plan/to_proto.rs | 4 - .../proto/src/physical_plan/from_proto.rs | 2 +- datafusion/sql/tests/sql_integration.rs | 3 +- 31 files changed, 1409 insertions(+), 1271 deletions(-) rename datafusion/{physical-expr => functions}/benches/concat.rs (93%) create mode 100644 datafusion/functions/src/string/concat.rs create mode 100644 
datafusion/functions/src/string/concat_ws.rs create mode 100644 datafusion/functions/src/string/ends_with.rs create mode 100644 datafusion/functions/src/string/initcap.rs delete mode 100644 datafusion/physical-expr/src/string_expressions.rs diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index fc2cdbb7518d..9b5a5fef8cb9 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -58,7 +58,10 @@ use crate::{ physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}, variable::{VarProvider, VarType}, }; -use crate::{functions, functions_aggregate, functions_array}; + +#[cfg(feature = "array_expressions")] +use crate::functions_array; +use crate::{functions, functions_aggregate}; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 4371cce856ce..c97735ce9cf1 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -37,7 +37,7 @@ use datafusion::assert_batches_eq; use datafusion_common::DFSchema; use datafusion_expr::expr::Alias; use datafusion_expr::{approx_median, cast, ExprSchemable}; -use datafusion_functions::unicode::expr_fn::character_length; +use datafusion_functions_array::expr_fn::array_to_string; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs index 6e938361ddb4..5a7870b7a01c 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer_integration.rs @@ -83,6 +83,30 @@ fn timestamp_nano_ts_utc_predicates() { assert_eq!(expected, format!("{plan:?}")); } +#[test] +fn concat_literals() -> Result<()> { + let sql = "SELECT concat(true, col_int32, false, null, 'hello', 
col_utf8, 12, 3.4) \ + AS col + FROM test"; + let expected = + "Projection: concat(Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"falsehello\"), test.col_utf8, Utf8(\"123.4\")) AS col\ + \n TableScan: test projection=[col_int32, col_utf8]"; + quick_test(sql, expected); + Ok(()) +} + +#[test] +fn concat_ws_literals() -> Result<()> { + let sql = "SELECT concat_ws('-', true, col_int32, false, null, 'hello', col_utf8, 12, '', 3.4) \ + AS col + FROM test"; + let expected = + "Projection: concat_ws(Utf8(\"-\"), Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"false-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\ + \n TableScan: test projection=[col_int32, col_utf8]"; + quick_test(sql, expected); + Ok(()) +} + fn quick_test(sql: &str, expected_plan: &str) { let plan = test_sql(sql).unwrap(); assert_eq!(expected_plan, format!("{:?}", plan)); @@ -97,7 +121,9 @@ fn test_sql(sql: &str) -> Result { // create a logical query plan let context_provider = MyContextProvider::default() .with_udf(datetime::now()) - .with_udf(datafusion_functions::core::arrow_cast()); + .with_udf(datafusion_functions::core::arrow_cast()) + .with_udf(datafusion_functions::string::concat()) + .with_udf(datafusion_functions::string::concat_ws()); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index dc075e669564..c5ce5d2652e0 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/simplification.rs @@ -31,7 +31,7 @@ use datafusion_expr::{ expr, table_scan, Cast, ColumnarValue, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, ScalarUDF, Volatility, }; -use datafusion_functions::math; +use datafusion_functions::{math, string}; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use 
datafusion_optimizer::{OptimizerContext, OptimizerRule}; @@ -217,7 +217,7 @@ fn fold_and_simplify() { let info: MyInfo = schema().into(); // What will it do with the expression `concat('foo', 'bar') == 'foobar')`? - let expr = concat(&[lit("foo"), lit("bar")]).eq(lit("foobar")); + let expr = concat(vec![lit("foo"), lit("bar")]).eq(lit("foobar")); // Since datafusion applies both simplification *and* rewriting // some expressions can be entirely simplified @@ -364,13 +364,13 @@ fn test_const_evaluator() { #[test] fn test_const_evaluator_scalar_functions() { // concat("foo", "bar") --> "foobar" - let expr = call_fn("concat", vec![lit("foo"), lit("bar")]).unwrap(); + let expr = string::expr_fn::concat(vec![lit("foo"), lit("bar")]); test_evaluate(expr, lit("foobar")); // ensure arguments are also constant folded // concat("foo", concat("bar", "baz")) --> "foobarbaz" - let concat1 = call_fn("concat", vec![lit("bar"), lit("baz")]).unwrap(); - let expr = call_fn("concat", vec![lit("foo"), concat1]).unwrap(); + let concat1 = string::expr_fn::concat(vec![lit("bar"), lit("baz")]); + let expr = string::expr_fn::concat(vec![lit("foo"), concat1]); test_evaluate(expr, lit("foobarbaz")); // Check non string arguments @@ -569,3 +569,92 @@ fn test_simplify_power() { test_simplify(expr, expected) } } + +#[test] +fn test_simplify_concat_ws() { + let null = lit(ScalarValue::Utf8(None)); + // the delimiter is not a literal + { + let expr = concat_ws(col("c"), vec![lit("a"), null.clone(), lit("b")]); + let expected = concat_ws(col("c"), vec![lit("a"), lit("b")]); + test_simplify(expr, expected); + } + + // the delimiter is an empty string + { + let expr = concat_ws(lit(""), vec![col("a"), lit("c"), lit("b")]); + let expected = concat(vec![col("a"), lit("cb")]); + test_simplify(expr, expected); + } + + // the delimiter is a not-empty string + { + let expr = concat_ws( + lit("-"), + vec![ + null.clone(), + col("c0"), + lit("hello"), + null.clone(), + lit("rust"), + col("c1"), + lit(""), + 
lit(""), + null, + ], + ); + let expected = concat_ws( + lit("-"), + vec![col("c0"), lit("hello-rust"), col("c1"), lit("-")], + ); + test_simplify(expr, expected) + } +} + +#[test] +fn test_simplify_concat_ws_with_null() { + let null = lit(ScalarValue::Utf8(None)); + // null delimiter -> null + { + let expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); + test_simplify(expr, null.clone()); + } + + // filter out null args + { + let expr = concat_ws(lit("|"), vec![col("c1"), null.clone(), col("c2")]); + let expected = concat_ws(lit("|"), vec![col("c1"), col("c2")]); + test_simplify(expr, expected); + } + + // nested test + { + let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); + let expr = concat_ws(lit("|"), vec![sub_expr, col("c3")]); + test_simplify(expr, concat_ws(lit("|"), vec![col("c3")])); + } + + // null delimiter (nested) + { + let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); + let expr = concat_ws(sub_expr, vec![col("c3"), col("c4")]); + test_simplify(expr, null); + } +} + +#[test] +fn test_simplify_concat() { + let null = lit(ScalarValue::Utf8(None)); + let expr = concat(vec![ + null.clone(), + col("c0"), + lit("hello "), + null.clone(), + lit("rust"), + col("c1"), + lit(""), + null, + ]); + let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]); + test_simplify(expr, expected) +} diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 7ec544a57edb..83eb2f722b08 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -23,7 +23,7 @@ use std::str::FromStr; use std::sync::OnceLock; use crate::type_coercion::functions::data_types; -use crate::{FuncMonotonicity, Signature, TypeSignature, Volatility}; +use crate::{FuncMonotonicity, Signature, Volatility}; use arrow::datatypes::DataType; use datafusion_common::{plan_err, DataFusionError, Result}; @@ -39,15 +39,6 @@ pub enum BuiltinScalarFunction { // math 
functions /// coalesce Coalesce, - // string functions - /// concat - Concat, - /// concat_ws - ConcatWithSeparator, - /// ends_with - EndsWith, - /// initcap - InitCap, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -101,10 +92,6 @@ impl BuiltinScalarFunction { match self { // Immutable scalar builtins BuiltinScalarFunction::Coalesce => Volatility::Immutable, - BuiltinScalarFunction::Concat => Volatility::Immutable, - BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, - BuiltinScalarFunction::EndsWith => Volatility::Immutable, - BuiltinScalarFunction::InitCap => Volatility::Immutable, } } @@ -117,8 +104,6 @@ impl BuiltinScalarFunction { /// 1. Perform additional checks on `input_expr_types` that are beyond the scope of `TypeSignature` validation. /// 2. Deduce the output `DataType` based on the provided `input_expr_types`. pub fn return_type(self, input_expr_types: &[DataType]) -> Result { - use DataType::*; - // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. @@ -130,43 +115,18 @@ impl BuiltinScalarFunction { let coerced_types = data_types(input_expr_types, &self.signature()); coerced_types.map(|types| types[0].clone()) } - BuiltinScalarFunction::Concat => Ok(Utf8), - BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8), - BuiltinScalarFunction::InitCap => { - utf8_to_str_type(&input_expr_types[0], "initcap") - } - BuiltinScalarFunction::EndsWith => Ok(Boolean), } } /// Return the argument [`Signature`] supported by this function pub fn signature(&self) -> Signature { - use DataType::*; - use TypeSignature::*; // note: the physical expression must accept the type returned by this function or the execution panics. // for now, the list is small, as we do not have many built-in functions. 
match self { - BuiltinScalarFunction::Concat - | BuiltinScalarFunction::ConcatWithSeparator => { - Signature::variadic(vec![Utf8], self.volatility()) - } BuiltinScalarFunction::Coalesce => { Signature::variadic_equal(self.volatility()) } - BuiltinScalarFunction::InitCap => { - Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) - } - - BuiltinScalarFunction::EndsWith => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - self.volatility(), - ), } } @@ -182,11 +142,6 @@ impl BuiltinScalarFunction { match self { // conditional functions BuiltinScalarFunction::Coalesce => &["coalesce"], - - BuiltinScalarFunction::Concat => &["concat"], - BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], - BuiltinScalarFunction::EndsWith => &["ends_with"], - BuiltinScalarFunction::InitCap => &["initcap"], } } } @@ -208,49 +163,6 @@ impl FromStr for BuiltinScalarFunction { } } -/// Creates a function to identify the optimal return type of a string function given -/// the type of its first argument. -/// -/// If the input type is `LargeUtf8` or `LargeBinary` the return type is -/// `$largeUtf8Type`, -/// -/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, -macro_rules! 
get_optimal_return_type { - ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { - fn $FUNC(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - // LargeBinary inputs are automatically coerced to Utf8 - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - // Binary inputs are automatically coerced to Utf8 - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - DataType::Dictionary(_, value_type) => match **value_type { - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - _ => { - return plan_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - **value_type - ); - } - }, - data_type => { - return plan_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - data_type - ); - } - }) - } - }; -} - -// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. 
-get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index f7900f6b197d..567f260daaf9 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -26,8 +26,8 @@ use crate::function::{ }; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, - logical_plan::Subquery, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, - Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, + logical_plan::Subquery, AggregateUDF, Expr, LogicalPlan, Operator, + ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::{DataType, Field}; @@ -277,26 +277,6 @@ pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { Expr::InList(InList::new(Box::new(expr), list, negated)) } -/// Concatenates the text representations of all the arguments. NULL arguments are ignored. -pub fn concat(args: &[Expr]) -> Expr { - Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::Concat, - args.to_vec(), - )) -} - -/// Concatenates all but the first argument, with separators. -/// The first argument is used as the separator. -/// NULL arguments in `values` are ignored. -pub fn concat_ws(sep: Expr, values: Vec) -> Expr { - let mut args = values; - args.insert(0, sep); - Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::ConcatWithSeparator, - args, - )) -} - /// Returns the approximate number of distinct input values. /// This function provides an approximation of count(DISTINCT x). /// Zero is returned if all input values are null. @@ -498,18 +478,6 @@ pub fn is_not_unknown(expr: Expr) -> Expr { Expr::IsNotUnknown(Box::new(expr)) } -macro_rules! 
scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident)*, $DOC:expr) => { - #[doc = $DOC] - pub fn $FUNC($($arg: Expr),*) -> Expr { - Expr::ScalarFunction(ScalarFunction::new( - built_in_function::BuiltinScalarFunction::$ENUM, - vec![$($arg),*], - )) - } - }; -} - macro_rules! nary_scalar_expr { ($ENUM:ident, $FUNC:ident, $DOC:expr) => { #[doc = $DOC ] @@ -525,16 +493,7 @@ macro_rules! nary_scalar_expr { // generate methods for creating the supported unary/binary expressions // math functions -scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); -scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); -//there is a func concat_ws before, so use concat_ws_expr as name.c -nary_scalar_expr!( - ConcatWithSeparator, - concat_ws_expr, - "concatenates several strings, placing a seperator between each one" -); -nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { @@ -843,18 +802,9 @@ impl WindowUDFImpl for SimpleWindowUDF { } } -/// Calls a named built in function -pub fn call_fn(name: impl AsRef, args: Vec) -> Result { - match name.as_ref().parse::() { - Ok(fun) => Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))), - Err(e) => Err(e), - } -} - #[cfg(test)] mod test { use super::*; - use crate::ScalarFunctionDefinition; #[test] fn filter_is_null_and_is_not_null() { @@ -866,28 +816,4 @@ mod test { "col2 IS NOT NULL" ); } - - macro_rules! 
test_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( - $( - col(stringify!($arg.to_string())) - ),* - ); - if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; -} - - #[test] - fn scalar_function_definitions() { - test_scalar_expr!(InitCap, initcap, string); - test_scalar_expr!(EndsWith, ends_with, string, characters); - } } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index cf15b490b69f..f9985069413b 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -89,6 +89,11 @@ rand = { workspace = true } rstest = { workspace = true } tokio = { workspace = true, features = ["macros", "rt", "sync"] } +[[bench]] +harness = false +name = "concat" +required-features = ["string_expressions"] + [[bench]] harness = false name = "to_timestamp" diff --git a/datafusion/physical-expr/benches/concat.rs b/datafusion/functions/benches/concat.rs similarity index 93% rename from datafusion/physical-expr/benches/concat.rs rename to datafusion/functions/benches/concat.rs index cdd54d767f1f..e7b00a6d540a 100644 --- a/datafusion/physical-expr/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -19,7 +19,7 @@ use arrow::util::bench_util::create_string_array_with_len; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr::string_expressions::concat; +use datafusion_functions::string::concat; use std::sync::Arc; fn create_args(size: usize, str_len: usize) -> Vec { @@ -37,7 +37,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args(size, 32); let mut 
group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { - b.iter(|| criterion::black_box(concat(&args).unwrap())) + b.iter(|| criterion::black_box(concat().invoke(&args).unwrap())) }); group.finish(); } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 97f9e1d93be5..d36bd5cecc47 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -19,10 +19,10 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::{ - new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder, - OffsetSizeTrait, + new_null_array, Array, ArrayDataBuilder, ArrayRef, GenericStringArray, + GenericStringBuilder, OffsetSizeTrait, StringArray, }; -use arrow::buffer::Buffer; +use arrow::buffer::{Buffer, MutableBuffer, NullBuffer}; use arrow::datatypes::DataType; use datafusion_common::cast::as_generic_string_array; @@ -155,6 +155,95 @@ where } } +pub(crate) enum ColumnarValueRef<'a> { + Scalar(&'a [u8]), + NullableArray(&'a StringArray), + NonNullableArray(&'a StringArray), +} + +impl<'a> ColumnarValueRef<'a> { + #[inline] + pub fn is_valid(&self, i: usize) -> bool { + match &self { + Self::Scalar(_) | Self::NonNullableArray(_) => true, + Self::NullableArray(array) => array.is_valid(i), + } + } + + #[inline] + pub fn nulls(&self) -> Option { + match &self { + Self::Scalar(_) | Self::NonNullableArray(_) => None, + Self::NullableArray(array) => array.nulls().cloned(), + } + } +} + +/// Optimized version of the StringBuilder in Arrow that: +/// 1. Precalculates the expected length of the result, avoiding reallocations. +/// 2. 
Avoids creating / incrementally creating a `NullBufferBuilder` +pub(crate) struct StringArrayBuilder { + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, +} + +impl StringArrayBuilder { + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let mut offsets_buffer = MutableBuffer::with_capacity( + (item_capacity + 1) * std::mem::size_of::(), + ); + // SAFETY: the first offset value is definitely not going to exceed the bounds. + unsafe { offsets_buffer.push_unchecked(0_i32) }; + Self { + offsets_buffer, + value_buffer: MutableBuffer::with_capacity(data_capacity), + } + } + + pub fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) { + match column { + ColumnarValueRef::Scalar(s) => { + self.value_buffer.extend_from_slice(s); + } + ColumnarValueRef::NullableArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NonNullableArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + } + + pub fn append_offset(&mut self) { + let next_offset: i32 = self + .value_buffer + .len() + .try_into() + .expect("byte array offset overflow"); + unsafe { self.offsets_buffer.push_unchecked(next_offset) }; + } + + pub fn finish(self, null_buffer: Option) -> StringArray { + let array_builder = ArrayDataBuilder::new(DataType::Utf8) + .len(self.offsets_buffer.len() / std::mem::size_of::() - 1) + .add_buffer(self.offsets_buffer.into()) + .add_buffer(self.value_buffer.into()) + .nulls(null_buffer); + // SAFETY: all data that was appended was valid UTF8 and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + StringArray::from(array_data) + } +} + fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result where O: OffsetSizeTrait, diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs 
new file mode 100644 index 000000000000..55b7c2f22249 --- /dev/null +++ b/datafusion/functions/src/string/concat.rs @@ -0,0 +1,262 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Utf8; + +use datafusion_common::cast::as_string_array; +use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::{lit, ColumnarValue, Expr, ScalarFunctionDefinition, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; +use crate::string::concat; + +#[derive(Debug)] +pub struct ConcatFunc { + signature: Signature, +} + +impl Default for ConcatFunc { + fn default() -> Self { + ConcatFunc::new() + } +} + +impl ConcatFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ConcatFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "concat" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn 
return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + /// Concatenates the text representations of all the arguments. NULL arguments are ignored. + /// concat('abcde', 2, NULL, 22) = 'abcde222' + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let array_len = args + .iter() + .filter_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .next(); + + // Scalar + if array_len.is_none() { + let mut result = String::new(); + for arg in args { + if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg { + result.push_str(v); + } + } + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); + } + + // Array + let len = array_len.unwrap(); + let mut data_size = 0; + let mut columns = Vec::with_capacity(args.len()); + + for arg in args { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + if let Some(s) = maybe_value { + data_size += s.len() * len; + columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + } + } + ColumnarValue::Array(array) => { + let string_array = as_string_array(array)?; + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + }; + columns.push(column); + } + _ => unreachable!(), + } + } + + let mut builder = StringArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + columns + .iter() + .for_each(|column| builder.write::(column, i)); + builder.append_offset(); + } + Ok(ColumnarValue::Array(Arc::new(builder.finish(None)))) + } + + /// Simplify the `concat` function by + /// 1. filtering out all `null` literals + /// 2. 
concatenating contiguous literal arguments + /// + /// For example: + /// `concat(col(a), 'hello ', 'world', col(b), null)` + /// will be optimized to + /// `concat(col(a), 'hello world', col(b))` + fn simplify( + &self, + args: Vec, + _info: &dyn SimplifyInfo, + ) -> Result { + simplify_concat(args) + } +} + +pub fn simplify_concat(args: Vec) -> Result { + let mut new_args = Vec::with_capacity(args.len()); + let mut contiguous_scalar = "".to_string(); + + for arg in args.clone() { + match arg { + // filter out `null` args + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} + // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. + // Concatenate it with the `contiguous_scalar`. + Expr::Literal( + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), + ) => contiguous_scalar += &v, + Expr::Literal(x) => { + return internal_err!( + "The scalar {x} should be casted to string type during the type coercion." + ) + } + // If the arg is not a literal, we should first push the current `contiguous_scalar` + // to the `new_args` (if it is not empty) and reset it to empty string. + // Then push this arg to the `new_args`. 
+ arg => { + if !contiguous_scalar.is_empty() { + new_args.push(lit(contiguous_scalar)); + contiguous_scalar = "".to_string(); + } + new_args.push(arg); + } + } + } + + if !contiguous_scalar.is_empty() { + new_args.push(lit(contiguous_scalar)); + } + + if !args.eq(&new_args) { + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(concat()), + args: new_args, + }, + ))) + } else { + Ok(ExprSimplifyResult::Original(args)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::test::test_function; + use arrow::array::Array; + use arrow::array::{ArrayRef, StringArray}; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + ConcatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::from("bb")), + ColumnarValue::Scalar(ScalarValue::from("cc")), + ], + Ok(Some("aabbcc")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from("cc")), + ], + Ok(Some("aacc")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + + Ok(()) + } + + #[test] + fn concat() -> Result<()> { + let c0 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + let args = &[c0, c1, c2]; + + let result = ConcatFunc::new().invoke(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!(), + } + Ok(()) + } +} diff --git 
a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs new file mode 100644 index 000000000000..1d27712b2c93 --- /dev/null +++ b/datafusion/functions/src/string/concat_ws.rs @@ -0,0 +1,423 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::StringArray; +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Utf8; + +use datafusion_common::cast::as_string_array; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::{lit, ColumnarValue, Expr, ScalarFunctionDefinition, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; +use crate::string::concat::simplify_concat; +use crate::string::concat_ws; + +#[derive(Debug)] +pub struct ConcatWsFunc { + signature: Signature, +} + +impl Default for ConcatWsFunc { + fn default() -> Self { + ConcatWsFunc::new() + } +} + +impl ConcatWsFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ConcatWsFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "concat_ws" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. + /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' + fn invoke(&self, args: &[ColumnarValue]) -> Result { + // do not accept 0 or 1 arguments. + if args.len() < 2 { + return exec_err!( + "concat_ws was called with {} arguments. 
It requires at least 2.", + args.len() + ); + } + + let array_len = args + .iter() + .filter_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .next(); + + // Scalar + if array_len.is_none() { + let sep = match &args[0] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s, + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + _ => unreachable!(), + }; + + let mut result = String::new(); + let iter = &mut args[1..].iter(); + + for arg in iter.by_ref() { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + result.push_str(s); + break; + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + _ => unreachable!(), + } + } + + for arg in iter.by_ref() { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + result.push_str(sep); + result.push_str(s); + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + _ => unreachable!(), + } + } + + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); + } + + // Array + let len = array_len.unwrap(); + let mut data_size = 0; + + // parse sep + let sep = match &args[0] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + data_size += s.len() * len * (args.len() - 2); // estimate + ColumnarValueRef::Scalar(s.as_bytes()) + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len)))); + } + ColumnarValue::Array(array) => { + let string_array = as_string_array(array)?; + data_size += string_array.values().len() * (args.len() - 2); // estimate + if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + } + } + _ => unreachable!(), + }; + + let mut columns = Vec::with_capacity(args.len() - 1); + for arg in &args[1..] 
{ + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + if let Some(s) = maybe_value { + data_size += s.len() * len; + columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + } + } + ColumnarValue::Array(array) => { + let string_array = as_string_array(array)?; + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + }; + columns.push(column); + } + _ => unreachable!(), + } + } + + let mut builder = StringArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + if !sep.is_valid(i) { + builder.append_offset(); + continue; + } + + let mut iter = columns.iter(); + for column in iter.by_ref() { + if column.is_valid(i) { + builder.write::(column, i); + break; + } + } + + for column in iter { + if column.is_valid(i) { + builder.write::(&sep, i); + builder.write::(column, i); + } + } + + builder.append_offset(); + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())))) + } + + /// Simply the `concat_ws` function by + /// 1. folding to `null` if the delimiter is null + /// 2. filtering out `null` arguments + /// 3. using `concat` to replace `concat_ws` if the delimiter is an empty string + /// 4. concatenating contiguous literals if the delimiter is a literal. + fn simplify( + &self, + args: Vec, + _info: &dyn SimplifyInfo, + ) -> Result { + match &args[..] { + [delimiter, vals @ ..] 
=> simplify_concat_ws(delimiter, vals), + _ => Ok(ExprSimplifyResult::Original(args)), + } + } +} + +fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { + match delimiter { + Expr::Literal( + ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter), + ) => { + match delimiter { + // when the delimiter is an empty string, + // we can use `concat` to replace `concat_ws` + Some(delimiter) if delimiter.is_empty() => simplify_concat(args.to_vec()), + Some(delimiter) => { + let mut new_args = Vec::with_capacity(args.len()); + new_args.push(lit(delimiter)); + let mut contiguous_scalar = None; + for arg in args { + match arg { + // filter out null args + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} + Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => { + match contiguous_scalar { + None => contiguous_scalar = Some(v.to_string()), + Some(mut pre) => { + pre += delimiter; + pre += v; + contiguous_scalar = Some(pre) + } + } + } + Expr::Literal(s) => return internal_err!("The scalar {s} should be casted to string type during the type coercion."), + // If the arg is not a literal, we should first push the current `contiguous_scalar` + // to the `new_args` and reset it to None. + // Then pushing this arg to the `new_args`. + arg => { + if let Some(val) = contiguous_scalar { + new_args.push(lit(val)); + } + new_args.push(arg.clone()); + contiguous_scalar = None; + } + } + } + if let Some(val) = contiguous_scalar { + new_args.push(lit(val)); + } + + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(concat_ws()), + args: new_args, + }, + ))) + } + // if the delimiter is null, then the value of the whole expression is null. + None => Ok(ExprSimplifyResult::Simplified(Expr::Literal( + ScalarValue::Utf8(None), + ))), + } + } + Expr::Literal(d) => internal_err!( + "The scalar {d} should be casted to string type during the type coercion." 
+ ), + _ => { + let mut args = args + .iter() + .filter(|&x| !is_null(x)) + .cloned() + .collect::>(); + args.insert(0, delimiter.clone()); + Ok(ExprSimplifyResult::Original(args)) + } + } +} + +fn is_null(expr: &Expr) -> bool { + match expr { + Expr::Literal(v) => v.is_null(), + _ => false, + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use crate::string::concat_ws::ConcatWsFunc; + use datafusion_common::Result; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + ConcatWsFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("|")), + ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::from("bb")), + ColumnarValue::Scalar(ScalarValue::from("cc")), + ], + Ok(Some("aa|bb|cc")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWsFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("|")), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWsFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::from("bb")), + ColumnarValue::Scalar(ScalarValue::from("cc")), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWsFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("|")), + ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from("cc")), + ], + Ok(Some("aa|cc")), + &str, + Utf8, + StringArray + ); + + Ok(()) + } + + #[test] + fn concat_ws() -> Result<()> { + // sep is scalar + let c0 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let c1 = + 
ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + let args = &[c0, c1, c2]; + + let result = ConcatWsFunc::new().invoke(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!(), + } + + // sep is nullable array + let c0 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some(","), + None, + Some("+"), + ]))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + Some("y"), + Some("z"), + ]))); + let args = &[c0, c1, c2]; + + let result = ConcatWsFunc::new().invoke(args)?; + let expected = + Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")])) + as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!(), + } + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs new file mode 100644 index 000000000000..b72cf0f66fa6 --- /dev/null +++ b/datafusion/functions/src/string/ends_with.rs @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Boolean; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub struct EndsWithFunc { + signature: Signature, +} + +impl Default for EndsWithFunc { + fn default() -> Self { + EndsWithFunc::new() + } +} + +impl EndsWithFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for EndsWithFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ends_with" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(ends_with::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(ends_with::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function ends_with") + } + } + } +} + +/// Returns true if string ends with suffix. 
+/// ends_with('alphabet', 'abet') = 't' +pub fn ends_with(args: &[ArrayRef]) -> Result { + let left = as_generic_string_array::(&args[0])?; + let right = as_generic_string_array::(&args[1])?; + + let result = arrow::compute::kernels::comparison::ends_with(left, right)?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, BooleanArray}; + use arrow::datatypes::DataType::Boolean; + + use datafusion_common::Result; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::string::ends_with::EndsWithFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + EndsWithFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from("alph")), + ], + Ok(Some(false)), + bool, + Boolean, + BooleanArray + ); + test_function!( + EndsWithFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from("bet")), + ], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + test_function!( + EndsWithFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from("alph")), + ], + Ok(None), + bool, + Boolean, + BooleanArray + ); + test_function!( + EndsWithFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + bool, + Boolean, + BooleanArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs new file mode 100644 index 000000000000..864179d130fd --- /dev/null +++ b/datafusion/functions/src/string/initcap.rs @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub struct InitcapFunc { + signature: Signature, +} + +impl Default for InitcapFunc { + fn default() -> Self { + InitcapFunc::new() + } +} + +impl InitcapFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for InitcapFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "initcap" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "initcap") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(initcap::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(initcap::, vec![])(args), + other => { + 
exec_err!("Unsupported data type {other:?} for function initcap") + } + } + } +} + +/// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters. +/// initcap('hi THOMAS') = 'Hi Thomas' +fn initcap(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + + // first map is the iterator, second is for the `Option<_>` + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| { + let mut char_vector = Vec::::new(); + let mut previous_character_letter_or_number = false; + for c in string.chars() { + if previous_character_letter_or_number { + char_vector.push(c.to_ascii_lowercase()); + } else { + char_vector.push(c.to_ascii_uppercase()); + } + previous_character_letter_or_number = c.is_ascii_uppercase() + || c.is_ascii_lowercase() + || c.is_ascii_digit(); + } + char_vector.iter().collect::() + }) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use crate::string::initcap::InitcapFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))], + Ok(Some("Hi Thomas")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from(""))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from(""))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git 
a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 81639c45f7ff..9eb2a7426fba 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -26,6 +26,10 @@ mod bit_length; mod btrim; mod chr; mod common; +mod concat; +mod concat_ws; +mod ends_with; +mod initcap; mod levenshtein; mod lower; mod ltrim; @@ -45,6 +49,10 @@ make_udf_function!(ascii::AsciiFunc, ASCII, ascii); make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length); make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); make_udf_function!(chr::ChrFunc, CHR, chr); +make_udf_function!(concat::ConcatFunc, CONCAT, concat); +make_udf_function!(concat_ws::ConcatWsFunc, CONCAT_WS, concat_ws); +make_udf_function!(ends_with::EndsWithFunc, ENDS_WITH, ends_with); +make_udf_function!(initcap::InitcapFunc, INITCAP, initcap); make_udf_function!(levenshtein::LevenshteinFunc, LEVENSHTEIN, levenshtein); make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); make_udf_function!(lower::LowerFunc, LOWER, lower); @@ -82,6 +90,28 @@ pub mod expr_fn { super::chr().call(vec![arg]) } + #[doc = "Concatenates the text representations of all the arguments. NULL arguments are ignored"] + pub fn concat(args: Vec) -> Expr { + super::concat().call(args) + } + + #[doc = "Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. 
Other NULL arguments are ignored."] + pub fn concat_ws(delimiter: Expr, args: Vec) -> Expr { + let mut args = args; + args.insert(0, delimiter); + super::concat_ws().call(args) + } + + #[doc = "Returns true if the `string` ends with the `suffix`, false otherwise."] + pub fn ends_with(string: Expr, suffix: Expr) -> Expr { + super::ends_with().call(vec![string, suffix]) + } + + #[doc = "Converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"] + pub fn initcap(string: Expr) -> Expr { + super::initcap().call(vec![string]) + } + #[doc = "Returns the Levenshtein distance between the two given strings"] pub fn levenshtein(arg1: Expr, arg2: Expr) -> Expr { super::levenshtein().call(vec![arg1, arg2]) @@ -160,6 +190,10 @@ pub fn functions() -> Vec> { bit_length(), btrim(), chr(), + concat(), + concat_ws(), + ends_with(), + initcap(), levenshtein(), lower(), ltrim(), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 1ea8b9534e80..ac96decbdd80 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -757,8 +757,9 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { #[cfg(test)] mod test { use std::any::Any; - use std::sync::{Arc, OnceLock}; + use std::sync::Arc; + use arrow::datatypes::DataType::Utf8; use arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion_common::tree_node::{TransformedResult, TreeNode}; @@ -766,10 +767,10 @@ mod test { use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection}; use datafusion_expr::{ - cast, col, concat, concat_ws, create_udaf, is_true, lit, - AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, Case, - ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan, Operator, ScalarUDF, - ScalarUDFImpl, Signature, SimpleAggregateUDF, 
Subquery, Volatility, + cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, + AggregateFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr, + ExprSchemable, Filter, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, + Signature, SimpleAggregateUDF, Subquery, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -821,10 +822,11 @@ mod test { assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) } - static TEST_SIGNATURE: OnceLock = OnceLock::new(); + #[derive(Debug, Clone)] + struct TestScalarUDF { + signature: Signature, + } - #[derive(Debug, Clone, Default)] - struct TestScalarUDF {} impl ScalarUDFImpl for TestScalarUDF { fn as_any(&self) -> &dyn Any { self @@ -833,11 +835,11 @@ mod test { fn name(&self) -> &str { "TestScalarUDF" } + fn signature(&self) -> &Signature { - TEST_SIGNATURE.get_or_init(|| { - Signature::uniform(1, vec![DataType::Float32], Volatility::Stable) - }) + &self.signature } + fn return_type(&self, _args: &[DataType]) -> Result { Ok(DataType::Utf8) } @@ -851,7 +853,10 @@ mod test { fn scalar_udf() -> Result<()> { let empty = empty(); - let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit(123_i32)]); + let udf = ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + }) + .call(vec![lit(123_i32)]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); let expected = "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; @@ -861,7 +866,10 @@ mod test { #[test] fn scalar_udf_invalid_input() -> Result<()> { let empty = empty(); - let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit("Apple")]); + let udf = ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + }) + .call(vec![lit("Apple")]); let plan_err = Projection::try_new(vec![udf], empty) .expect_err("Expected an error due to incorrect function input"); @@ -876,7 
+884,9 @@ mod test { // test that automatic argument type coercion for scalar functions work let empty = empty(); let lit_expr = lit(10i64); - let fun = ScalarUDF::new_from_impl(TestScalarUDF {}); + let fun = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + }); let scalar_function_expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr])); let plan = LogicalPlan::Projection(Projection::try_new( @@ -1233,24 +1243,16 @@ mod test { let empty = empty_with_type(DataType::Utf8); let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)]; - // concat + // concat-type signature { - let expr = concat(&args); - + let expr = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + }) + .call(args.to_vec()); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty.clone())?); let expected = - "Projection: concat(a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; - } - - // concat_ws - { - let expr = concat_ws(lit("-"), args.to_vec()); - - let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: concat_ws(Utf8(\"-\"), a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation"; + "Projection: TestScalarUDF(a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 3198807b04cf..bb14f75446df 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -21,18 +21,12 @@ use std::borrow::Cow; use std::collections::HashSet; use std::ops::Not; -use super::inlist_simplifier::ShortenInListSimplifier; -use super::utils::*; -use crate::analyzer::type_coercion::TypeCoercionRewriter; -use crate::simplify_expressions::guarantees::GuaranteeRewriter; -use crate::simplify_expressions::regex::simplify_regex_expr; -use crate::simplify_expressions::SimplifyInfo; - use arrow::{ array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; + use datafusion_common::{ cast::{as_large_list_array, as_list_array}, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, @@ -43,12 +37,20 @@ use datafusion_common::{ use datafusion_expr::expr::{InList, InSubquery}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, - Operator, ScalarFunctionDefinition, Volatility, + and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, + ScalarFunctionDefinition, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; +use crate::analyzer::type_coercion::TypeCoercionRewriter; +use crate::simplify_expressions::guarantees::GuaranteeRewriter; +use crate::simplify_expressions::regex::simplify_regex_expr; +use crate::simplify_expressions::SimplifyInfo; + +use super::inlist_simplifier::ShortenInListSimplifier; +use super::utils::*; + /// This structure handles API for expression simplification /// /// Provides simplification information based on DFSchema and @@ -1304,7 +1306,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // Do a first pass at simplification out_expr.rewrite(self)? 
} - Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::UDF(udf), args, @@ -1318,29 +1319,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr), }, - // concat - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), - args, - }) => Transformed::yes(simpl_concat(args)?), - - // concat_ws - Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::ConcatWithSeparator, - ), - args, - }) => match &args[..] { - [delimiter, vals @ ..] => { - Transformed::yes(simpl_concat_ws(delimiter, vals)?) - } - _ => Transformed::yes(Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::ConcatWithSeparator, - args, - ))), - }, - // // Rules for Between // @@ -1712,15 +1690,17 @@ mod tests { sync::Arc, }; - use super::*; - use crate::simplify_expressions::SimplifyContext; - use crate::test::test_table_scan_with_name; - use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{assert_contains, ToDFSchema}; use datafusion_expr::{interval_arithmetic::Interval, *}; use datafusion_physical_expr::execution_props::ExecutionProps; + use crate::simplify_expressions::SimplifyContext; + use crate::test::test_table_scan_with_name; + + use super::*; + // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ @@ -2653,95 +2633,6 @@ mod tests { assert_eq!(simplify(expr_eq), lit(true)); } - #[test] - fn test_simplify_concat_ws() { - let null = lit(ScalarValue::Utf8(None)); - // the delimiter is not a literal - { - let expr = concat_ws(col("c"), vec![lit("a"), null.clone(), lit("b")]); - let expected = concat_ws(col("c"), vec![lit("a"), lit("b")]); - assert_eq!(simplify(expr), expected); - } - - // the delimiter is an empty string - { - let expr = concat_ws(lit(""), vec![col("a"), lit("c"), lit("b")]); - let 
expected = concat(&[col("a"), lit("cb")]); - assert_eq!(simplify(expr), expected); - } - - // the delimiter is a not-empty string - { - let expr = concat_ws( - lit("-"), - vec![ - null.clone(), - col("c0"), - lit("hello"), - null.clone(), - lit("rust"), - col("c1"), - lit(""), - lit(""), - null, - ], - ); - let expected = concat_ws( - lit("-"), - vec![col("c0"), lit("hello-rust"), col("c1"), lit("-")], - ); - assert_eq!(simplify(expr), expected) - } - } - - #[test] - fn test_simplify_concat_ws_with_null() { - let null = lit(ScalarValue::Utf8(None)); - // null delimiter -> null - { - let expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); - assert_eq!(simplify(expr), null); - } - - // filter out null args - { - let expr = concat_ws(lit("|"), vec![col("c1"), null.clone(), col("c2")]); - let expected = concat_ws(lit("|"), vec![col("c1"), col("c2")]); - assert_eq!(simplify(expr), expected); - } - - // nested test - { - let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); - let expr = concat_ws(lit("|"), vec![sub_expr, col("c3")]); - assert_eq!(simplify(expr), concat_ws(lit("|"), vec![col("c3")])); - } - - // null delimiter (nested) - { - let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); - let expr = concat_ws(sub_expr, vec![col("c3"), col("c4")]); - assert_eq!(simplify(expr), null); - } - } - - #[test] - fn test_simplify_concat() { - let null = lit(ScalarValue::Utf8(None)); - let expr = concat(&[ - null.clone(), - col("c0"), - lit("hello "), - null.clone(), - lit("rust"), - col("c1"), - lit(""), - null, - ]); - let expected = concat(&[col("c0"), lit("hello rust"), col("c1")]); - assert_eq!(simplify(expr), expected) - } - #[test] fn test_simplify_regex() { // malformed regex diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index f0ad4738633f..5da727cb5990 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ 
b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -19,9 +19,9 @@ use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::{ - expr::{Between, BinaryExpr, InList, ScalarFunction}, - expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or}, - lit, BuiltinScalarFunction, Expr, Like, Operator, + expr::{Between, BinaryExpr, InList}, + expr_fn::{and, bitwise_and, bitwise_or, or}, + Expr, Like, Operator, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -341,120 +341,3 @@ pub fn distribute_negation(expr: Expr) -> Expr { _ => Expr::Negative(Box::new(expr)), } } - -/// Simplify the `concat` function by -/// 1. filtering out all `null` literals -/// 2. concatenating contiguous literal arguments -/// -/// For example: -/// `concat(col(a), 'hello ', 'world', col(b), null)` -/// will be optimized to -/// `concat(col(a), 'hello world', col(b))` -pub fn simpl_concat(args: Vec) -> Result { - let mut new_args = Vec::with_capacity(args.len()); - let mut contiguous_scalar = "".to_string(); - for arg in args { - match arg { - // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} - // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. - // Concatenate it with the `contiguous_scalar`. - Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), - ) => contiguous_scalar += &v, - Expr::Literal(x) => { - return internal_err!( - "The scalar {x} should be casted to string type during the type coercion." - ) - } - // If the arg is not a literal, we should first push the current `contiguous_scalar` - // to the `new_args` (if it is not empty) and reset it to empty string. - // Then pushing this arg to the `new_args`. 
- arg => { - if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); - contiguous_scalar = "".to_string(); - } - new_args.push(arg); - } - } - } - if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); - } - - Ok(Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::Concat, - new_args, - ))) -} - -/// Simply the `concat_ws` function by -/// 1. folding to `null` if the delimiter is null -/// 2. filtering out `null` arguments -/// 3. using `concat` to replace `concat_ws` if the delimiter is an empty string -/// 4. concatenating contiguous literals if the delimiter is a literal. -pub fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { - match delimiter { - Expr::Literal( - ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter), - ) => { - match delimiter { - // when the delimiter is an empty string, - // we can use `concat` to replace `concat_ws` - Some(delimiter) if delimiter.is_empty() => simpl_concat(args.to_vec()), - Some(delimiter) => { - let mut new_args = Vec::with_capacity(args.len()); - new_args.push(lit(delimiter)); - let mut contiguous_scalar = None; - for arg in args { - match arg { - // filter out null args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => { - match contiguous_scalar { - None => contiguous_scalar = Some(v.to_string()), - Some(mut pre) => { - pre += delimiter; - pre += v; - contiguous_scalar = Some(pre) - } - } - } - Expr::Literal(s) => return internal_err!("The scalar {s} should be casted to string type during the type coercion."), - // If the arg is not a literal, we should first push the current `contiguous_scalar` - // to the `new_args` and reset it to None. - // Then pushing this arg to the `new_args`. 
- arg => { - if let Some(val) = contiguous_scalar { - new_args.push(lit(val)); - } - new_args.push(arg.clone()); - contiguous_scalar = None; - } - } - } - if let Some(val) = contiguous_scalar { - new_args.push(lit(val)); - } - Ok(Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::ConcatWithSeparator, - new_args, - ))) - } - // if the delimiter is null, then the value of the whole expression is null. - None => Ok(Expr::Literal(ScalarValue::Utf8(None))), - } - } - Expr::Literal(d) => internal_err!( - "The scalar {d} should be casted to string type during the type coercion." - ), - d => Ok(concat_ws( - d.clone(), - args.iter() - .filter(|&x| !is_null(x)) - .cloned() - .collect::>(), - )), - } -} diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 01db5e817c56..dcaadaa8209c 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -206,32 +206,6 @@ fn between_date64_plus_interval() -> Result<()> { Ok(()) } -#[test] -fn concat_literals() -> Result<()> { - let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \ - AS col - FROM test"; - let plan = test_sql(sql)?; - let expected = - "Projection: concat(Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"falsehello\"), test.col_utf8, Utf8(\"123.4\")) AS col\ - \n TableScan: test projection=[col_int32, col_utf8]"; - assert_eq!(expected, format!("{plan:?}")); - Ok(()) -} - -#[test] -fn concat_ws_literals() -> Result<()> { - let sql = "SELECT concat_ws('-', true, col_int32, false, null, 'hello', col_utf8, 12, '', 3.4) \ - AS col - FROM test"; - let plan = test_sql(sql)?; - let expected = - "Projection: concat_ws(Utf8(\"-\"), Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"false-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\ - \n TableScan: test projection=[col_int32, col_utf8]"; - assert_eq!(expected, format!("{plan:?}")); - Ok(()) 
-} - #[test] fn propagate_empty_relation() { let sql = "SELECT test.col_int32 FROM test JOIN ( SELECT col_int32 FROM test WHERE false ) AS ta1 ON test.col_int32 = ta1.col_int32;"; diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index ba8d237bb276..fe72a7a46fcb 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -77,7 +77,3 @@ tokio = { workspace = true, features = ["rt-multi-thread"] } [[bench]] harness = false name = "in_list" - -[[bench]] -harness = false -name = "concat" diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 6efbc4179ff4..656ce711a0b0 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -33,29 +33,24 @@ use std::ops::Neg; use std::sync::Arc; -use arrow::{ - array::ArrayRef, - datatypes::{DataType, Schema}, -}; +use arrow::{array::ArrayRef, datatypes::Schema}; use arrow_array::Array; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{DFSchema, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; pub use datafusion_expr::FuncMonotonicity; -use datafusion_expr::ScalarFunctionDefinition; use datafusion_expr::{ type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, }; +use datafusion_expr::{Expr, ScalarFunctionDefinition, ScalarUDF}; use crate::sort_properties::SortProperties; -use crate::{ - conditional_expressions, string_expressions, PhysicalExpr, ScalarFunctionExpr, -}; +use crate::{conditional_expressions, PhysicalExpr, ScalarFunctionExpr}; /// Create a physical (function) expression. /// This function errors when `args`' can't be coerced to a valid argument type of the function. 
-pub fn create_physical_expr( +pub fn create_builtin_physical_expr( fun: &BuiltinScalarFunction, input_phy_exprs: &[Arc], input_schema: &Schema, @@ -84,6 +79,38 @@ pub fn create_physical_expr( ))) } +/// Create a physical (function) expression. +/// This function errors when `args`' can't be coerced to a valid argument type of the function. +pub fn create_physical_expr( + fun: &ScalarUDF, + input_phy_exprs: &[Arc], + input_schema: &Schema, + args: &[Expr], + input_dfschema: &DFSchema, +) -> Result> { + let input_expr_types = input_phy_exprs + .iter() + .map(|e| e.data_type(input_schema)) + .collect::>>()?; + + // verify that input data types is consistent with function's `TypeSignature` + data_types(&input_expr_types, fun.signature())?; + + // Since we have arg_types, we don't need args and schema. + let return_type = + fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; + + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone())); + Ok(Arc::new(ScalarFunctionExpr::new( + fun.name(), + fun_def, + input_phy_exprs.to_vec(), + return_type, + fun.monotonicity()?, + fun.signature().type_signature.supports_zero_argument(), + ))) +} + #[derive(Debug, Clone, Copy)] pub enum Hint { /// Indicates the argument needs to be padded if it is scalar @@ -179,32 +206,6 @@ pub fn create_physical_fun( Ok(match fun { // string functions BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), - BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), - BuiltinScalarFunction::ConcatWithSeparator => { - Arc::new(string_expressions::concat_ws) - } - BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::initcap::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::initcap::)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function initcap") - } - }), - 
BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::ends_with::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::ends_with::)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function ends_with") - } - }), }) } @@ -272,219 +273,51 @@ fn func_order_in_one_dimension( #[cfg(test)] mod tests { use arrow::{ - array::{Array, ArrayRef, BooleanArray, Int32Array, StringArray, UInt64Array}, - datatypes::Field, - record_batch::RecordBatch, + array::{Array, ArrayRef, UInt64Array}, + datatypes::{DataType, Field}, }; + use arrow_schema::DataType::Utf8; use datafusion_common::cast::as_uint64_array; use datafusion_common::{internal_err, plan_err}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::functions::data_types; - use datafusion_expr::Signature; + use datafusion_expr::{Signature, Volatility}; - use crate::expressions::lit; use crate::expressions::try_cast; + use crate::utils::tests::TestScalarUDF; use super::*; - /// $FUNC function to test - /// $ARGS arguments (vec) to pass to function - /// $EXPECTED a Result> where Result allows testing errors and Option allows testing Null - /// $EXPECTED_TYPE is the expected value type - /// $DATA_TYPE is the function to test result type - /// $ARRAY_TYPE is the column type after function applied - macro_rules! 
test_function { - ($FUNC:ident, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $DATA_TYPE: ident, $ARRAY_TYPE:ident) => { - // used to provide type annotation - let expected: Result> = $EXPECTED; - let execution_props = ExecutionProps::new(); - - // any type works here: we evaluate against a literal of `value` - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; - - let expr = - create_physical_expr_with_type_coercion(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?; - - // type is correct - assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - - match expected { - Ok(expected) => { - let result = expr.evaluate(&batch)?; - let result = result.into_array(batch.num_rows()).expect("Failed to convert to array"); - let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - - // value is correct - match expected { - Some(v) => assert_eq!(result.value(0), v), - None => assert!(result.is_null(0)), - }; - } - Err(expected_error) => { - // evaluate is expected error - cannot use .expect_err() due to Debug not being implemented - match expr.evaluate(&batch) { - Ok(_) => assert!(false, "expected error"), - Err(error) => { - assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); - } - } - } - }; - }; - } - - #[test] - fn test_functions() -> Result<()> { - test_function!( - Concat, - &[lit("aa"), lit("bb"), lit("cc"),], - Ok(Some("aabbcc")), - &str, - Utf8, - StringArray - ); - test_function!( - Concat, - &[lit("aa"), lit(ScalarValue::Utf8(None)), lit("cc"),], - Ok(Some("aacc")), - &str, - Utf8, - StringArray - ); - test_function!( - Concat, - &[lit(ScalarValue::Utf8(None))], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - test_function!( - ConcatWithSeparator, - &[lit("|"), lit("aa"), lit("bb"), lit("cc"),], - Ok(Some("aa|bb|cc")), - &str, - Utf8, - 
StringArray - ); - test_function!( - ConcatWithSeparator, - &[lit("|"), lit(ScalarValue::Utf8(None)),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - test_function!( - ConcatWithSeparator, - &[ - lit(ScalarValue::Utf8(None)), - lit("aa"), - lit("bb"), - lit("cc"), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - ConcatWithSeparator, - &[lit("|"), lit("aa"), lit(ScalarValue::Utf8(None)), lit("cc"),], - Ok(Some("aa|cc")), - &str, - Utf8, - StringArray - ); - test_function!( - InitCap, - &[lit("hi THOMAS")], - Ok(Some("Hi Thomas")), - &str, - Utf8, - StringArray - ); - test_function!(InitCap, &[lit("")], Ok(Some("")), &str, Utf8, StringArray); - test_function!(InitCap, &[lit("")], Ok(Some("")), &str, Utf8, StringArray); - test_function!( - InitCap, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - EndsWith, - &[lit("alphabet"), lit("alph"),], - Ok(Some(false)), - bool, - Boolean, - BooleanArray - ); - test_function!( - EndsWith, - &[lit("alphabet"), lit("bet"),], - Ok(Some(true)), - bool, - Boolean, - BooleanArray - ); - test_function!( - EndsWith, - &[lit(ScalarValue::Utf8(None)), lit("alph"),], - Ok(None), - bool, - Boolean, - BooleanArray - ); - test_function!( - EndsWith, - &[lit("alphabet"), lit(ScalarValue::Utf8(None)),], - Ok(None), - bool, - Boolean, - BooleanArray - ); - - Ok(()) - } - #[test] fn test_empty_arguments_error() -> Result<()> { - let execution_props = ExecutionProps::new(); let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let udf = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + }); + let expr = create_physical_expr_with_type_coercion( + &udf, + &[], + &schema, + &[], + &DFSchema::empty(), + ); - // pick some arbitrary functions to test - let funs = [BuiltinScalarFunction::Concat]; - - for fun in funs.iter() { - let expr = create_physical_expr_with_type_coercion( - fun, - &[], 
- &schema, - &execution_props, - ); - - match expr { - Ok(..) => { - return plan_err!( - "Builtin scalar function {fun} does not support empty arguments" - ); - } - Err(DataFusionError::Plan(_)) => { - // Continue the loop - } - Err(..) => { - return internal_err!( - "Builtin scalar function {fun} didn't got the right error with empty arguments"); - } + match expr { + Ok(..) => { + return plan_err!( + "ScalarUDF function {udf:?} does not support empty arguments" + ); + } + Err(DataFusionError::Plan(_)) => { + // Continue the loop + } + Err(..) => { + return internal_err!( + "ScalarUDF function {udf:?} didn't got the right error with empty arguments"); } } + Ok(()) } @@ -517,14 +350,21 @@ mod tests { // Helper function just for testing. // The type coercion will be done in the logical phase, should do the type coercion for the test fn create_physical_expr_with_type_coercion( - fun: &BuiltinScalarFunction, + fun: &ScalarUDF, input_phy_exprs: &[Arc], input_schema: &Schema, - execution_props: &ExecutionProps, + args: &[Expr], + input_dfschema: &DFSchema, ) -> Result> { let type_coerced_phy_exprs = - coerce(input_phy_exprs, input_schema, &fun.signature()).unwrap(); - create_physical_expr(fun, &type_coerced_phy_exprs, input_schema, execution_props) + coerce(input_phy_exprs, input_schema, fun.signature()).unwrap(); + create_physical_expr( + fun, + &type_coerced_phy_exprs, + input_schema, + args, + input_dfschema, + ) } fn dummy_function(args: &[ArrayRef]) -> Result { diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 7b81e8f8a5c4..aabcf42fe7c4 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -28,7 +28,6 @@ mod partitioning; mod physical_expr; pub mod planner; mod scalar_function; -pub mod string_expressions; pub mod udf; pub mod utils; pub mod window; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index aefbd54f8e99..20626818c83b 100644 
--- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -307,7 +307,7 @@ pub fn create_physical_expr( match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { - functions::create_physical_expr( + functions::create_builtin_physical_expr( fun, &physical_args, input_schema, diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs deleted file mode 100644 index fd6c8eb6b1d9..000000000000 --- a/datafusion/physical-expr/src/string_expressions.rs +++ /dev/null @@ -1,495 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Some of these functions reference the Postgres documentation -// or implementation to ensure compatibility and are subject to -// the Postgres license. - -//! 
String expressions - -use std::sync::Arc; - -use arrow::array::ArrayDataBuilder; -use arrow::{ - array::{ - Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, OffsetSizeTrait, - StringArray, - }, - datatypes::DataType, -}; -use arrow_buffer::{MutableBuffer, NullBuffer}; - -use datafusion_common::Result; -use datafusion_common::{ - cast::{as_generic_string_array, as_string_array}, - exec_err, ScalarValue, -}; -use datafusion_expr::ColumnarValue; - -enum ColumnarValueRef<'a> { - Scalar(&'a [u8]), - NullableArray(&'a StringArray), - NonNullableArray(&'a StringArray), -} - -impl<'a> ColumnarValueRef<'a> { - #[inline] - fn is_valid(&self, i: usize) -> bool { - match &self { - Self::Scalar(_) | Self::NonNullableArray(_) => true, - Self::NullableArray(array) => array.is_valid(i), - } - } - - #[inline] - fn nulls(&self) -> Option { - match &self { - Self::Scalar(_) | Self::NonNullableArray(_) => None, - Self::NullableArray(array) => array.nulls().cloned(), - } - } -} - -/// Optimized version of the StringBuilder in Arrow that: -/// 1. Precalculating the expected length of the result, avoiding reallocations. -/// 2. Avoids creating / incrementally creating a `NullBufferBuilder` -struct StringArrayBuilder { - offsets_buffer: MutableBuffer, - value_buffer: MutableBuffer, -} - -impl StringArrayBuilder { - fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { - let mut offsets_buffer = MutableBuffer::with_capacity( - (item_capacity + 1) * std::mem::size_of::(), - ); - // SAFETY: the first offset value is definitely not going to exceed the bounds. 
- unsafe { offsets_buffer.push_unchecked(0_i32) }; - Self { - offsets_buffer, - value_buffer: MutableBuffer::with_capacity(data_capacity), - } - } - - fn write(&mut self, column: &ColumnarValueRef, i: usize) { - match column { - ColumnarValueRef::Scalar(s) => { - self.value_buffer.extend_from_slice(s); - } - ColumnarValueRef::NullableArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NonNullableArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - } - - fn append_offset(&mut self) { - let next_offset: i32 = self - .value_buffer - .len() - .try_into() - .expect("byte array offset overflow"); - unsafe { self.offsets_buffer.push_unchecked(next_offset) }; - } - - fn finish(self, null_buffer: Option) -> StringArray { - let array_builder = ArrayDataBuilder::new(DataType::Utf8) - .len(self.offsets_buffer.len() / std::mem::size_of::() - 1) - .add_buffer(self.offsets_buffer.into()) - .add_buffer(self.value_buffer.into()) - .nulls(null_buffer); - // SAFETY: all data that was appended was valid UTF8 and the values - // and offsets were created correctly - let array_data = unsafe { array_builder.build_unchecked() }; - StringArray::from(array_data) - } -} - -/// Concatenates the text representations of all the arguments. NULL arguments are ignored. 
-/// concat('abcde', 2, NULL, 22) = 'abcde222' -pub fn concat(args: &[ColumnarValue]) -> Result { - let array_len = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .next(); - - // Scalar - if array_len.is_none() { - let mut result = String::new(); - for arg in args { - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg { - result.push_str(v); - } - } - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); - } - - // Array - let len = array_len.unwrap(); - let mut data_size = 0; - let mut columns = Vec::with_capacity(args.len()); - - for arg in args { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { - if let Some(s) = maybe_value { - data_size += s.len() * len; - columns.push(ColumnarValueRef::Scalar(s.as_bytes())); - } - } - ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) - }; - columns.push(column); - } - _ => unreachable!(), - } - } - - let mut builder = StringArrayBuilder::with_capacity(len, data_size); - for i in 0..len { - columns - .iter() - .for_each(|column| builder.write::(column, i)); - builder.append_offset(); - } - Ok(ColumnarValue::Array(Arc::new(builder.finish(None)))) -} - -/// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. -/// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' -pub fn concat_ws(args: &[ColumnarValue]) -> Result { - // do not accept 0 or 1 arguments. - if args.len() < 2 { - return exec_err!( - "concat_ws was called with {} arguments. 
It requires at least 2.", - args.len() - ); - } - - let array_len = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .next(); - - // Scalar - if array_len.is_none() { - let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s, - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); - } - _ => unreachable!(), - }; - - let mut result = String::new(); - let iter = &mut args[1..].iter(); - - for arg in iter.by_ref() { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - result.push_str(s); - break; - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} - _ => unreachable!(), - } - } - - for arg in iter.by_ref() { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - result.push_str(sep); - result.push_str(s); - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} - _ => unreachable!(), - } - } - - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); - } - - // Array - let len = array_len.unwrap(); - let mut data_size = 0; - - // parse sep - let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - data_size += s.len() * len * (args.len() - 2); // estimate - ColumnarValueRef::Scalar(s.as_bytes()) - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len)))); - } - ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len() * (args.len() - 2); // estimate - if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) - } - } - _ => unreachable!(), - }; - - let mut columns = Vec::with_capacity(args.len() - 1); - for arg in &args[1..] 
{ - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { - if let Some(s) = maybe_value { - data_size += s.len() * len; - columns.push(ColumnarValueRef::Scalar(s.as_bytes())); - } - } - ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) - }; - columns.push(column); - } - _ => unreachable!(), - } - } - - let mut builder = StringArrayBuilder::with_capacity(len, data_size); - for i in 0..len { - if !sep.is_valid(i) { - builder.append_offset(); - continue; - } - - let mut iter = columns.iter(); - for column in iter.by_ref() { - if column.is_valid(i) { - builder.write::(column, i); - break; - } - } - - for column in iter { - if column.is_valid(i) { - builder.write::(&sep, i); - builder.write::(column, i); - } - } - - builder.append_offset(); - } - - Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())))) -} - -/// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters. 
-/// initcap('hi THOMAS') = 'Hi Thomas' -pub fn initcap(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - - // first map is the iterator, second is for the `Option<_>` - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - let mut char_vector = Vec::::new(); - let mut previous_character_letter_or_number = false; - for c in string.chars() { - if previous_character_letter_or_number { - char_vector.push(c.to_ascii_lowercase()); - } else { - char_vector.push(c.to_ascii_uppercase()); - } - previous_character_letter_or_number = c.is_ascii_uppercase() - || c.is_ascii_lowercase() - || c.is_ascii_digit(); - } - char_vector.iter().collect::() - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns the position of the first occurrence of substring in string. -/// The position is counted from 1. If the substring is not found, returns 0. -/// For example, instr('Helloworld', 'world') = 6. -pub fn instr(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let substr_array = as_generic_string_array::(&args[1])?; - - match args[0].data_type() { - DataType::Utf8 => { - let result = string_array - .iter() - .zip(substr_array.iter()) - .map(|(string, substr)| match (string, substr) { - (Some(string), Some(substr)) => string - .find(substr) - .map_or(Some(0), |index| Some((index + 1) as i32)), - _ => None, - }) - .collect::(); - - Ok(Arc::new(result) as ArrayRef) - } - DataType::LargeUtf8 => { - let result = string_array - .iter() - .zip(substr_array.iter()) - .map(|(string, substr)| match (string, substr) { - (Some(string), Some(substr)) => string - .find(substr) - .map_or(Some(0), |index| Some((index + 1) as i64)), - _ => None, - }) - .collect::(); - - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!( - "instr was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." 
- ) - } - } -} - -/// Returns true if string starts with prefix. -/// starts_with('alphabet', 'alph') = 't' -pub fn starts_with(args: &[ArrayRef]) -> Result { - let left = as_generic_string_array::(&args[0])?; - let right = as_generic_string_array::(&args[1])?; - - let result = arrow::compute::kernels::comparison::starts_with(left, right)?; - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns true if string ends with suffix. -/// ends_with('alphabet', 'abet') = 't' -pub fn ends_with(args: &[ArrayRef]) -> Result { - let left = as_generic_string_array::(&args[0])?; - let right = as_generic_string_array::(&args[1])?; - - let result = arrow::compute::kernels::comparison::ends_with(left, right)?; - - Ok(Arc::new(result) as ArrayRef) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn concat() -> Result<()> { - let c0 = - ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); - let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); - let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ - Some("x"), - None, - Some("z"), - ]))); - let args = &[c0, c1, c2]; - - let result = super::concat(args)?; - let expected = - Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; - match &result { - ColumnarValue::Array(array) => { - assert_eq!(&expected, array); - } - _ => panic!(), - } - Ok(()) - } - - #[test] - fn concat_ws() -> Result<()> { - // sep is scalar - let c0 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); - let c1 = - ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); - let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ - Some("x"), - None, - Some("z"), - ]))); - let args = &[c0, c1, c2]; - - let result = super::concat_ws(args)?; - let expected = - Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; - match &result { - ColumnarValue::Array(array) => { - assert_eq!(&expected, array); - } - _ => panic!(), - } - - 
// sep is nullable array - let c0 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ - Some(","), - None, - Some("+"), - ]))); - let c1 = - ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); - let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ - Some("x"), - Some("y"), - Some("z"), - ]))); - let args = &[c0, c1, c2]; - - let result = super::concat_ws(args)?; - let expected = - Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")])) - as ArrayRef; - match &result { - ColumnarValue::Array(array) => { - assert_eq!(&expected, array); - } - _ => panic!(), - } - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index d7bebbff891c..a0d6436586a2 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -276,7 +276,7 @@ pub(crate) mod tests { #[derive(Debug, Clone)] pub struct TestScalarUDF { - signature: Signature, + pub(crate) signature: Signature, } impl TestScalarUDF { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 6578c64cff1f..13709bf394bf 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -567,11 +567,11 @@ enum ScalarFunction { // 23 was Btrim // 24 was CharacterLength // 25 was Chr - Concat = 26; - ConcatWithSeparator = 27; + // 26 was Concat + // 27 was ConcatWithSeparator // 28 was DatePart // 29 was DateTrunc - InitCap = 30; + // 30 was InitCap // 31 was Left // 32 was Lpad // 33 was Lower @@ -670,7 +670,7 @@ enum ScalarFunction { // 128 was ArraySort // 129 was ArrayDistinct // 130 was ArrayResize - EndsWith = 131; + // 131 was EndsWith // 132 was InStr // 133 was MakeDate // 134 was ArrayReverse diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 1546d75f2acd..3a2be9907354 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ 
b/datafusion/proto/src/generated/pbjson.rs @@ -22792,11 +22792,7 @@ impl serde::Serialize for ScalarFunction { { let variant = match self { Self::Unknown => "unknown", - Self::Concat => "Concat", - Self::ConcatWithSeparator => "ConcatWithSeparator", - Self::InitCap => "InitCap", Self::Coalesce => "Coalesce", - Self::EndsWith => "EndsWith", }; serializer.serialize_str(variant) } @@ -22809,11 +22805,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { { const FIELDS: &[&str] = &[ "unknown", - "Concat", - "ConcatWithSeparator", - "InitCap", "Coalesce", - "EndsWith", ]; struct GeneratedVisitor; @@ -22855,11 +22847,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { { match value { "unknown" => Ok(ScalarFunction::Unknown), - "Concat" => Ok(ScalarFunction::Concat), - "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), - "InitCap" => Ok(ScalarFunction::InitCap), "Coalesce" => Ok(ScalarFunction::Coalesce), - "EndsWith" => Ok(ScalarFunction::EndsWith), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c752743cbdce..487cfe01fba5 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2867,11 +2867,11 @@ pub enum ScalarFunction { /// 23 was Btrim /// 24 was CharacterLength /// 25 was Chr - Concat = 26, - ConcatWithSeparator = 27, + /// 26 was Concat + /// 27 was ConcatWithSeparator /// 28 was DatePart /// 29 was DateTrunc - InitCap = 30, + /// 30 was InitCap /// 31 was Left /// 32 was Lpad /// 33 was Lower @@ -2904,7 +2904,7 @@ pub enum ScalarFunction { /// 60 was Translate /// Trim = 61; /// Upper = 62; - Coalesce = 63, + /// /// 64 was Power /// 65 was StructFun /// 66 was FromUnixtime @@ -2970,7 +2970,7 @@ pub enum ScalarFunction { /// 128 was ArraySort /// 129 was ArrayDistinct /// 130 was ArrayResize - /// + /// 131 was EndsWith /// 132 was InStr /// 133 was MakeDate /// 134 was 
ArrayReverse @@ -2978,7 +2978,7 @@ pub enum ScalarFunction { /// 136 was ToChar /// 137 was ToDate /// 138 was ToUnixtime - EndsWith = 131, + Coalesce = 63, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2988,22 +2988,14 @@ impl ScalarFunction { pub fn as_str_name(&self) -> &'static str { match self { ScalarFunction::Unknown => "unknown", - ScalarFunction::Concat => "Concat", - ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", - ScalarFunction::InitCap => "InitCap", ScalarFunction::Coalesce => "Coalesce", - ScalarFunction::EndsWith => "EndsWith", } } /// Creates an enum from field names used in the ProtoBuf definition. pub fn from_str_name(value: &str) -> ::core::option::Option { match value { "unknown" => Some(Self::Unknown), - "Concat" => Some(Self::Concat), - "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), - "InitCap" => Some(Self::InitCap), "Coalesce" => Some(Self::Coalesce), - "EndsWith" => Some(Self::EndsWith), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index e66bd1a5f0a9..4ccff9e7aa62 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -37,9 +37,8 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - coalesce, concat_expr, concat_ws_expr, ends_with, + coalesce, expr::{self, InList, Sort, WindowFunction}, - initcap, logical_plan::{PlanType, StringifiedPlan}, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, @@ -418,10 +417,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { use protobuf::ScalarFunction; match f { ScalarFunction::Unknown => todo!(), - ScalarFunction::Concat => 
Self::Concat, - ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, - ScalarFunction::EndsWith => Self::EndsWith, - ScalarFunction::InitCap => Self::InitCap, ScalarFunction::Coalesce => Self::Coalesce, } } @@ -1287,19 +1282,6 @@ pub fn parse_expr( match scalar_function { ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")), - ScalarFunction::InitCap => { - Ok(initcap(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Concat => { - Ok(concat_expr(parse_exprs(args, registry, codec)?)) - } - ScalarFunction::ConcatWithSeparator => { - Ok(concat_ws_expr(parse_exprs(args, registry, codec)?)) - } - ScalarFunction::EndsWith => Ok(ends_with( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Coalesce => { Ok(coalesce(parse_exprs(args, registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 4916b4bed9a3..7ad39df2c7ed 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1407,10 +1407,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { fn try_from(scalar: &BuiltinScalarFunction) -> Result { let scalar_function = match scalar { - BuiltinScalarFunction::Concat => Self::Concat, - BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, - BuiltinScalarFunction::EndsWith => Self::EndsWith, - BuiltinScalarFunction::InitCap => Self::InitCap, BuiltinScalarFunction::Coalesce => Self::Coalesce, }; diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 81e4c92ffc68..ffc165e725b0 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -352,7 +352,7 @@ pub fn parse_physical_expr( // TODO Do not create new the ExecutionProps let execution_props = ExecutionProps::new(); - 
functions::create_physical_expr( + functions::create_builtin_physical_expr( &(&scalar_function).into(), &args, input_schema, diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 19288123558a..410cbdad747a 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -38,7 +38,7 @@ use datafusion_sql::{ planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}, }; -use datafusion_functions::unicode; +use datafusion_functions::{string, unicode}; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use sqlparser::parser::Parser; @@ -2678,6 +2678,7 @@ fn logical_plan_with_dialect_and_options( ) -> Result { let context = MockContextProvider::default() .with_udf(unicode::character_length().as_ref().clone()) + .with_udf(string::concat().as_ref().clone()) .with_udf(make_udf( "nullif", vec![DataType::Int32, DataType::Int32],