Skip to content

Commit

Permalink
Feat: Support array_intersect function (#1271)
Browse files Browse the repository at this point in the history
* Feat: Support array_intersect

* Address review comment
  • Loading branch information
erenavsarogullari authored Jan 21, 2025
1 parent c3a552f commit 824ad1a
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 0 deletions.
17 changes: 17 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ use datafusion::{
use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
use datafusion_functions_nested::concat::ArrayAppend;
use datafusion_functions_nested::remove::array_remove_all_udf;
use datafusion_functions_nested::set_ops::array_intersect_udf;
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};

use crate::execution::shuffle::CompressionCodec;
Expand Down Expand Up @@ -774,6 +775,22 @@ impl PhysicalPlanner {

Ok(Arc::new(case_expr))
}
ExprStruct::ArrayIntersect(expr) => {
let left_expr =
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
let right_expr =
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
let args = vec![Arc::clone(&left_expr), right_expr];
let datafusion_array_intersect = array_intersect_udf();
let return_type = left_expr.data_type(&input_schema)?;
let array_intersect_expr = Arc::new(ScalarFunctionExpr::new(
"array_intersect",
datafusion_array_intersect,
args,
return_type,
));
Ok(array_intersect_expr)
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
Expand Down
1 change: 1 addition & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ message Expr {
ArrayInsert array_insert = 59;
BinaryExpr array_contains = 60;
BinaryExpr array_remove = 61;
BinaryExpr array_intersect = 62;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2302,6 +2302,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
expr.children(1),
inputs,
(builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
case _ if expr.prettyName == "array_intersect" =>
createBinaryExpr(
expr.children(0),
expr.children(1),
inputs,
(builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
Expand Down
16 changes: 16 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2675,4 +2675,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("array_intersect") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
spark.read.parquet(path.toString).createOrReplaceTempView("t1")
checkSparkAnswerAndOperator(
sql("SELECT array_intersect(array(_2, _3, _4), array(_3, _4)) from t1"))
checkSparkAnswerAndOperator(
sql("SELECT array_intersect(array(_2 * -1), array(_9, _10)) from t1"))
checkSparkAnswerAndOperator(sql("SELECT array_intersect(array(_18), array(_19)) from t1"))
}
}
}

}

0 comments on commit 824ad1a

Please sign in to comment.