Skip to content

Commit

Permalink
Feat: Support array_join function (#1290)
Browse files Browse the repository at this point in the history
  • Loading branch information
erenavsarogullari authored Jan 23, 2025
1 parent b064930 commit 9a4e5b5
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 0 deletions.
27 changes: 27 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}
use datafusion_functions_nested::concat::ArrayAppend;
use datafusion_functions_nested::remove::array_remove_all_udf;
use datafusion_functions_nested::set_ops::array_intersect_udf;
use datafusion_functions_nested::string::array_to_string_udf;
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};

use crate::execution::shuffle::CompressionCodec;
Expand Down Expand Up @@ -791,6 +792,32 @@ impl PhysicalPlanner {
));
Ok(array_intersect_expr)
}
ExprStruct::ArrayJoin(expr) => {
    // Spark's array_join is mapped onto DataFusion's array_to_string UDF:
    // array_join(array, delimiter[, null_replacement]) -> Utf8.
    let array_expr =
        self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
    let delimiter_expr = self.create_expr(
        expr.delimiter_expr.as_ref().unwrap(),
        Arc::clone(&input_schema),
    )?;

    // Required arguments first. The optional null-replacement argument is
    // appended only when the protobuf field is set, so the UDF sees either
    // two or three arguments, matching the Spark call site.
    let mut args = vec![array_expr, delimiter_expr];
    if let Some(null_replacement) = expr.null_replacement_expr.as_ref() {
        args.push(self.create_expr(null_replacement, Arc::clone(&input_schema))?);
    }

    let datafusion_array_to_string = array_to_string_udf();
    let array_join_expr = Arc::new(ScalarFunctionExpr::new(
        "array_join",
        datafusion_array_to_string,
        args,
        DataType::Utf8,
    ));
    Ok(array_join_expr)
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
Expand Down
7 changes: 7 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ message Expr {
BinaryExpr array_contains = 60;
BinaryExpr array_remove = 61;
BinaryExpr array_intersect = 62;
ArrayJoin array_join = 63;
}
}

Expand Down Expand Up @@ -415,6 +416,12 @@ message ArrayInsert {
bool legacy_negative_index = 4;
}

// Spark array_join(array, delimiter[, null_replacement]).
message ArrayJoin {
// The array expression whose elements are concatenated.
Expr array_expr = 1;
// The separator placed between elements.
Expr delimiter_expr = 2;
// Optional replacement for null elements; when unset, no replacement
// argument is forwarded to the native function.
Expr null_replacement_expr = 3;
}

message DataType {
enum DataTypeId {
BOOL = 0;
Expand Down
32 changes: 32 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2312,6 +2312,38 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
expr.children(1),
inputs,
(builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) =>
  // Serialize each child; any of these may be None if the child expression
  // is not supported by Comet.
  val arrayExprProto = exprToProto(arrayExpr, inputs, binding)
  val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding)
  // Some(Some(p)) = present and serializable, Some(None) = present but
  // unsupported, None = no null-replacement argument was given.
  val nullReplacementExprProto =
    nullReplacementExpr.map(nrExpr => exprToProto(nrExpr, inputs, binding))

  // Fall back to Spark when any present child failed to serialize. The
  // original code called .get on the null-replacement proto without
  // checking it, which could throw instead of gracefully falling back.
  if (arrayExprProto.isDefined && delimiterExprProto.isDefined &&
    nullReplacementExprProto.forall(_.isDefined)) {
    val arrayJoinBuilder = ExprOuterClass.ArrayJoin
      .newBuilder()
      .setArrayExpr(arrayExprProto.get)
      .setDelimiterExpr(delimiterExprProto.get)
    // Only set the optional field when it was supplied and serialized.
    nullReplacementExprProto.flatten.foreach(arrayJoinBuilder.setNullReplacementExpr)
    Some(
      ExprOuterClass.Expr
        .newBuilder()
        .setArrayJoin(arrayJoinBuilder)
        .build())
  } else {
    // Report every child so the fallback reason is visible in the plan.
    val exprs: List[Expression] = nullReplacementExpr match {
      case Some(nrExpr) => List(arrayExpr, delimiterExpr, nrExpr)
      case None => List(arrayExpr, delimiterExpr)
    }
    withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
    None
  }
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
Expand Down
17 changes: 17 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2684,4 +2684,21 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("array_join") {
  // Exercise array_join with and without dictionary encoding, covering:
  // default delimiter-only form, explicit null replacement, a null-free
  // input, and an input that may contain nulls.
  Seq(true, false).foreach { dictionaryEnabled =>
    withTempDir { dir =>
      val path = new Path(dir.toURI.toString, "test.parquet")
      makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
      spark.read.parquet(path.toString).createOrReplaceTempView("t1")
      val queries = Seq(
        "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ') from t1",
        "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ', ' +++ ') from t1",
        "SELECT array_join(array('hello', 'world', cast(_2 as string)), ' ') from t1 where _2 is not null",
        "SELECT array_join(array('hello', '-', 'world', cast(_2 as string)), ' ') from t1")
      queries.foreach(query => checkSparkAnswerAndOperator(sql(query)))
    }
  }
}
}

0 comments on commit 9a4e5b5

Please sign in to comment.