From 824ad1a6d0ea0325e1339347db1bc0a50d3122ed Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Tue, 21 Jan 2025 14:45:33 -0800 Subject: [PATCH] Feat: Support array_intersect function (#1271) * Feat: Support array_intersect * Address review comment --- native/core/src/execution/planner.rs | 17 +++++++++++++++++ native/proto/src/proto/expr.proto | 1 + .../org/apache/comet/serde/QueryPlanSerde.scala | 6 ++++++ .../org/apache/comet/CometExpressionSuite.scala | 16 ++++++++++++++++ 4 files changed, 40 insertions(+) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index ccf7e31e4e..c7df503a7a 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -67,6 +67,7 @@ use datafusion::{ use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr}; use datafusion_functions_nested::concat::ArrayAppend; use datafusion_functions_nested::remove::array_remove_all_udf; +use datafusion_functions_nested::set_ops::array_intersect_udf; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use crate::execution::shuffle::CompressionCodec; @@ -774,6 +775,22 @@ impl PhysicalPlanner { Ok(Arc::new(case_expr)) } + ExprStruct::ArrayIntersect(expr) => { + let left_expr = + self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; + let right_expr = + self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + let args = vec![Arc::clone(&left_expr), right_expr]; + let datafusion_array_intersect = array_intersect_udf(); + let return_type = left_expr.data_type(&input_schema)?; + let array_intersect_expr = Arc::new(ScalarFunctionExpr::new( + "array_intersect", + datafusion_array_intersect, + args, + return_type, + )); + Ok(array_intersect_expr) + } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", expr diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 8e3bc60b0f..0b7d24d9f6 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -86,6 +86,7 @@ message Expr { ArrayInsert array_insert = 59; BinaryExpr array_contains = 60; BinaryExpr array_remove = 61; + BinaryExpr array_intersect = 62; } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8c37abf3c9..124e3be85b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2302,6 +2302,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim expr.children(1), inputs, (builder, binaryExpr) => builder.setArrayAppend(binaryExpr)) + case _ if expr.prettyName == "array_intersect" => + createBinaryExpr( + expr.children(0), + expr.children(1), + inputs, + (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr)) case _ => withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) None diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index b4724b5416..b59830b268 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2675,4 +2675,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("array_intersect") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator( + sql("SELECT array_intersect(array(_2, _3, _4), array(_3, _4)) from t1")) + checkSparkAnswerAndOperator( + sql("SELECT array_intersect(array(_2 * -1), array(_9, _10)) from t1")) + checkSparkAnswerAndOperator(sql("SELECT array_intersect(array(_18), array(_19)) from t1")) + } + } + } + }