Skip to content

Commit

Permalink
Support arrays_overlap function
Browse files: browse the repository at this point in the history
  • Loading branch information
erenavsarogullari committed Jan 28, 2025
1 parent 497e40b commit 96f776e
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 0 deletions.
16 changes: 16 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ use datafusion::{
prelude::SessionContext,
};
use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
use datafusion_functions_nested::array_has::array_has_any_udf;
use datafusion_functions_nested::concat::ArrayAppend;
use datafusion_functions_nested::remove::array_remove_all_udf;
use datafusion_functions_nested::set_ops::array_intersect_udf;
Expand Down Expand Up @@ -818,6 +819,21 @@ impl PhysicalPlanner {
));
Ok(array_join_expr)
}
ExprStruct::ArraysOverlap(expr) => {
    // `arrays_overlap(left, right)` is planned as a call to DataFusion's
    // `array_has_any` UDF with a Boolean result type.
    //
    // `left` and `right` are `Option`s on the serialized expression. A missing
    // operand is reported as a planner error rather than panicking via `unwrap()`.
    let left_proto = expr.left.as_ref().ok_or_else(|| {
        ExecutionError::GeneralError(
            "arrays_overlap: missing left array expression".to_string(),
        )
    })?;
    let left_array_expr = self.create_expr(left_proto, Arc::clone(&input_schema))?;

    let right_proto = expr.right.as_ref().ok_or_else(|| {
        ExecutionError::GeneralError(
            "arrays_overlap: missing right array expression".to_string(),
        )
    })?;
    let right_array_expr = self.create_expr(right_proto, Arc::clone(&input_schema))?;

    // Both physical expressions are moved into the argument list; neither is used
    // afterwards, so no extra `Arc` clone is required.
    let args = vec![left_array_expr, right_array_expr];
    let array_has_any_expr = Arc::new(ScalarFunctionExpr::new(
        "array_has_any",
        array_has_any_udf(),
        args,
        DataType::Boolean,
    ));
    Ok(array_has_any_expr)
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
Expand Down
1 change: 1 addition & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ message Expr {
BinaryExpr array_remove = 61;
BinaryExpr array_intersect = 62;
ArrayJoin array_join = 63;
BinaryExpr arrays_overlap = 64;
}
}

Expand Down
16 changes: 16 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2428,6 +2428,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
None
}
case ArraysOverlap(leftArrayExpr, rightArrayExpr) =>
  // ArraysOverlap is only serialized when the "allow incompatible" flag is set;
  // otherwise the reason is recorded on the expression and None is returned.
  val incompatibleAllowed = CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()
  if (!incompatibleAllowed) {
    val reason =
      s"$expr is not supported yet. To enable all incompatible casts, set " +
        s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true"
    withInfo(expr, reason)
    None
  } else {
    // Serialize both array operands as a BinaryExpr stored in the arrays_overlap field.
    createBinaryExpr(
      expr,
      leftArrayExpr,
      rightArrayExpr,
      inputs,
      binding,
      (exprBuilder, overlapExpr) => exprBuilder.setArraysOverlap(overlapExpr))
  }
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
Expand Down
21 changes: 21 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2701,4 +2701,25 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("arrays_overlap") {
  // Queries whose array operands are built from columns and/or literals,
  // including arrays that contain a NULL element.
  val overlapQueries = Seq(
    "SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null",
    "SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null",
    "SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null")
  // The CASE has no ELSE branch, so the left operand is a NULL array whenever _2 != _3.
  val nullArrayQuery =
    "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1"

  // The serde only emits arrays_overlap when this flag is enabled.
  withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
    // Run every query against files written with dictionaryEnabled = true and false.
    for (dictionaryEnabled <- Seq(true, false)) {
      withTempDir { dir =>
        val parquetPath = new Path(dir.toURI.toString, "test.parquet")
        makeParquetFileAllTypes(parquetPath, dictionaryEnabled, 10000)
        spark.read.parquet(parquetPath.toString).createOrReplaceTempView("t1")

        overlapQueries.foreach { query =>
          checkSparkAnswerAndOperator(sql(query))
        }
        checkSparkAnswerAndOperator(spark.sql(nullArrayQuery))
      }
    }
  }
}

}

0 comments on commit 96f776e

Please sign in to comment.