Skip to content

Commit

Permalink
Move concat, concat_ws, ends_with, initcap to datafusion-functions (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
Omega359 authored Apr 16, 2024
1 parent 1395adf commit 8730466
Show file tree
Hide file tree
Showing 31 changed files with 1,409 additions and 1,271 deletions.
5 changes: 4 additions & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ use crate::{
physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
variable::{VarProvider, VarType},
};
use crate::{functions, functions_aggregate, functions_array};

#[cfg(feature = "array_expressions")]
use crate::functions_array;
use crate::{functions, functions_aggregate};

use arrow::datatypes::{DataType, SchemaRef};
use arrow::record_batch::RecordBatch;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use datafusion::assert_batches_eq;
use datafusion_common::DFSchema;
use datafusion_expr::expr::Alias;
use datafusion_expr::{approx_median, cast, ExprSchemable};
use datafusion_functions::unicode::expr_fn::character_length;
use datafusion_functions_array::expr_fn::array_to_string;

fn test_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Expand Down
28 changes: 27 additions & 1 deletion datafusion/core/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,30 @@ fn timestamp_nano_ts_utc_predicates() {
assert_eq!(expected, format!("{plan:?}"));
}

/// Checks the plan text that `quick_test` obtains for a `concat` call mixing
/// literals of several types with column references.
///
/// Per the expected plan below: the leading `true` appears as `Utf8("true")`,
/// `col_int32` is wrapped in a cast to `Utf8`, the run `false, null, 'hello'`
/// appears as the single literal `Utf8("falsehello")`, and the trailing
/// `12, 3.4` appears as `Utf8("123.4")`.
#[test]
fn concat_literals() -> Result<()> {
// The trailing backslash continues the string literal onto the next line.
let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \
AS col
FROM test";
// Expected debug rendering of the plan, compared verbatim by `quick_test`.
let expected =
"Projection: concat(Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"falsehello\"), test.col_utf8, Utf8(\"123.4\")) AS col\
\n TableScan: test projection=[col_int32, col_utf8]";
quick_test(sql, expected);
Ok(())
}

/// Checks the plan text that `quick_test` obtains for a `concat_ws` call with
/// a literal `'-'` separator and a mix of literal and column arguments.
///
/// Per the expected plan below: the run `false, null, 'hello'` appears as
/// `Utf8("false-hello")` (the `null` leaves no trace), while the run
/// `12, '', 3.4` appears as `Utf8("12--3.4")` (the empty string still
/// contributes a separator).
#[test]
fn concat_ws_literals() -> Result<()> {
// The trailing backslash continues the string literal onto the next line.
let sql = "SELECT concat_ws('-', true, col_int32, false, null, 'hello', col_utf8, 12, '', 3.4) \
AS col
FROM test";
// Expected debug rendering of the plan, compared verbatim by `quick_test`.
let expected =
"Projection: concat_ws(Utf8(\"-\"), Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"false-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\
\n TableScan: test projection=[col_int32, col_utf8]";
quick_test(sql, expected);
Ok(())
}

fn quick_test(sql: &str, expected_plan: &str) {
let plan = test_sql(sql).unwrap();
assert_eq!(expected_plan, format!("{:?}", plan));
Expand All @@ -97,7 +121,9 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
// create a logical query plan
let context_provider = MyContextProvider::default()
.with_udf(datetime::now())
.with_udf(datafusion_functions::core::arrow_cast());
.with_udf(datafusion_functions::core::arrow_cast())
.with_udf(datafusion_functions::string::concat())
.with_udf(datafusion_functions::string::concat_ws());
let sql_to_rel = SqlToRel::new(&context_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();

Expand Down
99 changes: 94 additions & 5 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use datafusion_expr::{
expr, table_scan, Cast, ColumnarValue, Expr, ExprSchemable, LogicalPlan,
LogicalPlanBuilder, ScalarUDF, Volatility,
};
use datafusion_functions::math;
use datafusion_functions::{math, string};
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
Expand Down Expand Up @@ -217,7 +217,7 @@ fn fold_and_simplify() {
let info: MyInfo = schema().into();

// What will it do with the expression `concat('foo', 'bar') == 'foobar'`?
let expr = concat(&[lit("foo"), lit("bar")]).eq(lit("foobar"));
let expr = concat(vec![lit("foo"), lit("bar")]).eq(lit("foobar"));

// Since datafusion applies both simplification *and* rewriting
// some expressions can be entirely simplified
Expand Down Expand Up @@ -364,13 +364,13 @@ fn test_const_evaluator() {
#[test]
fn test_const_evaluator_scalar_functions() {
// concat("foo", "bar") --> "foobar"
let expr = call_fn("concat", vec![lit("foo"), lit("bar")]).unwrap();
let expr = string::expr_fn::concat(vec![lit("foo"), lit("bar")]);
test_evaluate(expr, lit("foobar"));

// ensure arguments are also constant folded
// concat("foo", concat("bar", "baz")) --> "foobarbaz"
let concat1 = call_fn("concat", vec![lit("bar"), lit("baz")]).unwrap();
let expr = call_fn("concat", vec![lit("foo"), concat1]).unwrap();
let concat1 = string::expr_fn::concat(vec![lit("bar"), lit("baz")]);
let expr = string::expr_fn::concat(vec![lit("foo"), concat1]);
test_evaluate(expr, lit("foobarbaz"));

// Check non string arguments
Expand Down Expand Up @@ -569,3 +569,92 @@ fn test_simplify_power() {
test_simplify(expr, expected)
}
}

#[test]
fn test_simplify_concat_ws() {
    // Typed NULL string literal reused as an argument in the cases below.
    let null_utf8 = lit(ScalarValue::Utf8(None));

    // Case 1: the delimiter is a column rather than a literal.
    // Expected result keeps the call but without the null argument.
    let column_delimited =
        concat_ws(col("c"), vec![lit("a"), null_utf8.clone(), lit("b")]);
    let column_delimited_simplified = concat_ws(col("c"), vec![lit("a"), lit("b")]);
    test_simplify(column_delimited, column_delimited_simplified);

    // Case 2: the delimiter is the empty string.
    // Expected result is a plain `concat` with the adjacent literals merged.
    let empty_delimited = concat_ws(lit(""), vec![col("a"), lit("c"), lit("b")]);
    let empty_delimited_simplified = concat(vec![col("a"), lit("cb")]);
    test_simplify(empty_delimited, empty_delimited_simplified);

    // Case 3: the delimiter is a non-empty string literal.
    // Expected result drops the nulls and joins adjacent literals with the delimiter.
    let dash_args = vec![
        null_utf8.clone(),
        col("c0"),
        lit("hello"),
        null_utf8.clone(),
        lit("rust"),
        col("c1"),
        lit(""),
        lit(""),
        null_utf8,
    ];
    let dash_delimited = concat_ws(lit("-"), dash_args);
    let dash_simplified_args = vec![col("c0"), lit("hello-rust"), col("c1"), lit("-")];
    let dash_delimited_simplified = concat_ws(lit("-"), dash_simplified_args);
    test_simplify(dash_delimited, dash_delimited_simplified)
}

#[test]
fn test_simplify_concat_ws_with_null() {
    // Typed NULL string literal used both as delimiter and as argument below.
    let null_utf8 = lit(ScalarValue::Utf8(None));

    // A null delimiter is expected to simplify the whole call to null.
    let null_delimited = concat_ws(null_utf8.clone(), vec![col("c1"), col("c2")]);
    test_simplify(null_delimited, null_utf8.clone());

    // Null arguments are expected to be filtered out of the argument list.
    let with_null_arg =
        concat_ws(lit("|"), vec![col("c1"), null_utf8.clone(), col("c2")]);
    let without_null_arg = concat_ws(lit("|"), vec![col("c1"), col("c2")]);
    test_simplify(with_null_arg, without_null_arg);

    // Nested call as an argument: the inner null-delimited call is expected
    // to disappear from the outer argument list.
    let inner_as_arg = concat_ws(null_utf8.clone(), vec![col("c1"), col("c2")]);
    let outer_with_inner_arg = concat_ws(lit("|"), vec![inner_as_arg, col("c3")]);
    test_simplify(outer_with_inner_arg, concat_ws(lit("|"), vec![col("c3")]));

    // Nested call as the delimiter: the outer call is expected to become null.
    let inner_as_delimiter = concat_ws(null_utf8.clone(), vec![col("c1"), col("c2")]);
    let outer_with_inner_delimiter =
        concat_ws(inner_as_delimiter, vec![col("c3"), col("c4")]);
    test_simplify(outer_with_inner_delimiter, null_utf8);
}

#[test]
fn test_simplify_concat() {
    // Typed NULL string literal interleaved with the real arguments.
    let null_utf8 = lit(ScalarValue::Utf8(None));

    // Input mixes nulls, columns, string literals and an empty string.
    let original_args = vec![
        null_utf8.clone(),
        col("c0"),
        lit("hello "),
        null_utf8.clone(),
        lit("rust"),
        col("c1"),
        lit(""),
        null_utf8,
    ];
    let unsimplified = concat(original_args);

    // Expected: nulls and the empty string are gone, and the two literals
    // separated only by a null are merged into one.
    let simplified = concat(vec![col("c0"), lit("hello rust"), col("c1")]);
    test_simplify(unsimplified, simplified)
}
90 changes: 1 addition & 89 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::str::FromStr;
use std::sync::OnceLock;

use crate::type_coercion::functions::data_types;
use crate::{FuncMonotonicity, Signature, TypeSignature, Volatility};
use crate::{FuncMonotonicity, Signature, Volatility};

use arrow::datatypes::DataType;
use datafusion_common::{plan_err, DataFusionError, Result};
Expand All @@ -39,15 +39,6 @@ pub enum BuiltinScalarFunction {
// math functions
/// coalesce
Coalesce,
// string functions
/// concat
Concat,
/// concat_ws
ConcatWithSeparator,
/// ends_with
EndsWith,
/// initcap
InitCap,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -101,10 +92,6 @@ impl BuiltinScalarFunction {
match self {
// Immutable scalar builtins
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Concat => Volatility::Immutable,
BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
BuiltinScalarFunction::InitCap => Volatility::Immutable,
}
}

Expand All @@ -117,8 +104,6 @@ impl BuiltinScalarFunction {
/// 1. Perform additional checks on `input_expr_types` that are beyond the scope of `TypeSignature` validation.
/// 2. Deduce the output `DataType` based on the provided `input_expr_types`.
pub fn return_type(self, input_expr_types: &[DataType]) -> Result<DataType> {
use DataType::*;

// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.

Expand All @@ -130,43 +115,18 @@ impl BuiltinScalarFunction {
let coerced_types = data_types(input_expr_types, &self.signature());
coerced_types.map(|types| types[0].clone())
}
BuiltinScalarFunction::Concat => Ok(Utf8),
BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8),
BuiltinScalarFunction::InitCap => {
utf8_to_str_type(&input_expr_types[0], "initcap")
}
BuiltinScalarFunction::EndsWith => Ok(Boolean),
}
}

/// Return the argument [`Signature`] supported by this function
pub fn signature(&self) -> Signature {
use DataType::*;
use TypeSignature::*;
// note: the physical expression must accept the type returned by this function or the execution panics.

// for now, the list is small, as we do not have many built-in functions.
match self {
BuiltinScalarFunction::Concat
| BuiltinScalarFunction::ConcatWithSeparator => {
Signature::variadic(vec![Utf8], self.volatility())
}
BuiltinScalarFunction::Coalesce => {
Signature::variadic_equal(self.volatility())
}
BuiltinScalarFunction::InitCap => {
Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility())
}

BuiltinScalarFunction::EndsWith => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
],
self.volatility(),
),
}
}

Expand All @@ -182,11 +142,6 @@ impl BuiltinScalarFunction {
match self {
// conditional functions
BuiltinScalarFunction::Coalesce => &["coalesce"],

BuiltinScalarFunction::Concat => &["concat"],
BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"],
BuiltinScalarFunction::EndsWith => &["ends_with"],
BuiltinScalarFunction::InitCap => &["initcap"],
}
}
}
Expand All @@ -208,49 +163,6 @@ impl FromStr for BuiltinScalarFunction {
}
}

/// Generates a function that picks the return type of a string function from
/// the type of its first argument.
///
/// The generated function has the signature
/// `fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType>` and maps:
///
/// * `LargeUtf8` / `LargeBinary` to `$largeUtf8Type`
/// * `Utf8` / `Binary` to `$utf8Type`
/// * `Null` to `Null`
/// * `Dictionary(_, value_type)` to whatever its `value_type` maps to
///
/// Any other type yields a plan error that names the function (`name`,
/// upper-cased) and the offending type.
macro_rules! get_optimal_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            // A dictionary is classified by its value type; every other type
            // is classified as-is. This is also the type reported on error.
            let classified_type: &DataType = match arg_type {
                DataType::Dictionary(_, value_type) => &**value_type,
                other => other,
            };
            match classified_type {
                DataType::LargeUtf8 | DataType::LargeBinary => Ok($largeUtf8Type),
                DataType::Utf8 | DataType::Binary => Ok($utf8Type),
                DataType::Null => Ok(DataType::Null),
                unsupported => plan_err!(
                    "The {} function can only accept strings, but got {:?}.",
                    name.to_uppercase(),
                    unsupported
                ),
            }
        }
    };
}

// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size.
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 8730466

Please sign in to comment.