diff --git a/docs/querying/math-expr.md b/docs/querying/math-expr.md index d5255544a03e..38ced649c06c 100644 --- a/docs/querying/math-expr.md +++ b/docs/querying/math-expr.md @@ -184,7 +184,7 @@ See javadoc of java.lang.Math for detailed explanation for each function. | array_ordinal(arr,long) | returns the array element at the 1 based index supplied, or null for an out of range index | | array_contains(arr,expr) | returns 1 if the array contains the element specified by expr, or contains all elements specified by expr if expr is an array, else 0 | | array_overlap(arr1,arr2) | returns 1 if arr1 and arr2 have any elements in common, else 0 | -| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 | +| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 if the expr is non-null, or null if the expr is null | | array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. | | array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. | | array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array | diff --git a/docs/querying/sql-array-functions.md b/docs/querying/sql-array-functions.md index ab84c664dee7..7b0f2112b6f7 100644 --- a/docs/querying/sql-array-functions.md +++ b/docs/querying/sql-array-functions.md @@ -52,9 +52,9 @@ The following table describes array functions. 
To learn more about array aggrega |`ARRAY_LENGTH(arr)`|Returns length of the array expression.| |`ARRAY_OFFSET(arr, long)`|Returns the array element at the 0-based index supplied, or null for an out of range index.| |`ARRAY_ORDINAL(arr, long)`|Returns the array element at the 1-based index supplied, or null for an out of range index.| -|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.| -|`ARRAY_OVERLAP(arr1, arr2)`|Returns 1 if `arr1` and `arr2` have any elements in common, else 0.| -| `SCALAR_IN_ARRAY(expr, arr)`|Returns 1 if the scalar `expr` is present in `arr`. else 0.| +|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns true if `arr` contains all elements of `expr`. Otherwise returns false.| +|`ARRAY_OVERLAP(arr1, arr2)`|Returns true if `arr1` and `arr2` have any elements in common, else false.| +|`SCALAR_IN_ARRAY(expr, arr)`|Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or `UNKNOWN` if the scalar `expr` is `NULL`.| |`ARRAY_OFFSET_OF(arr, expr)`|Returns the 0-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).| |`ARRAY_ORDINAL_OF(arr, expr)`|Returns the 1-based index of the first occurrence of `expr` in the array. 
If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).| |`ARRAY_PREPEND(expr, arr)`|Adds `expr` to the beginning of `arr`, the resulting array type determined by the type of `arr`.| diff --git a/docs/querying/sql-functions.md b/docs/querying/sql-functions.md index 093e7ce60fde..883f3b209ace 100644 --- a/docs/querying/sql-functions.md +++ b/docs/querying/sql-functions.md @@ -156,7 +156,7 @@ Concatenates array inputs into a single array. **Function type:** [Array](./sql-array-functions.md) -If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0. +If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns true if `arr` contains all elements of `expr`. Otherwise returns false. ## ARRAY_LENGTH @@ -204,7 +204,7 @@ Returns the 1-based index of the first occurrence of `expr` in the array. If no **Function type:** [Array](./sql-array-functions.md) -Returns 1 if `arr1` and `arr2` have any elements in common, else 0.| +Returns true if `arr1` and `arr2` have any elements in common, else false. ## SCALAR_IN_ARRAY @@ -212,7 +212,10 @@ Returns 1 if `arr1` and `arr2` have any elements in common, else 0.| **Function type:** [Array](./sql-array-functions.md) -Returns 1 if the scalar `expr` is present in `arr`, else 0.| +Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or +`UNKNOWN` if the scalar `expr` is `NULL`. + +Returns `UNKNOWN` if `arr` is `NULL`. 
## ARRAY_PREPEND diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java index aa54409e132e..48bc0570aaa3 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Function.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java @@ -45,6 +45,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.Comparator; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -3724,8 +3725,11 @@ ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) } } - class ArrayScalarInFunction extends ArrayScalarFunction + class ScalarInArrayFunction extends ArrayScalarFunction { + private static final int SCALAR_ARG = 0; + private static final int ARRAY_ARG = 1; + @Override public String name() { @@ -3742,23 +3746,105 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) { - return args.get(0); + return args.get(SCALAR_ARG); } @Override Expr getArrayArgument(List args) { - return args.get(1); + return args.get(ARRAY_ARG); } @Override - ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) + ExprEval doApply(ExprEval arrayEval, ExprEval scalarEval) { - final Object[] array = arrayExpr.castTo(scalarExpr.asArrayType()).asArray(); + final Object[] array = arrayEval.asArray(); if (array == null) { return ExprEval.ofLong(null); } - return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarExpr.value())); + + if (scalarEval.value() == null) { + return Arrays.asList(array).contains(null) ? 
ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null); + } + + final ExpressionType matchType = arrayEval.elementType(); + final ExprEval scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType); + + if (scalarEvalForComparison == null) { + return ExprEval.ofLongBoolean(false); + } else { + return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarEvalForComparison.value())); + } + } + + @Override + public Function asSingleThreaded(List args, Expr.InputBindingInspector inspector) + { + if (args.get(ARRAY_ARG).isLiteral()) { + final ExpressionType lhsType = args.get(SCALAR_ARG).getOutputType(inspector); + if (lhsType == null) { + return this; + } + + final ExprEval arrayEval = args.get(ARRAY_ARG).eval(InputBindings.nilBindings()); + final Object[] arrayValues = arrayEval.asArray(); + + if (arrayValues == null) { + return WithNullArray.INSTANCE; + } else { + final Set matchValues = new HashSet<>(Arrays.asList(arrayValues)); + final ExpressionType matchType = arrayEval.elementType(); + return new WithConstantArray(matchValues, matchType); + } + } + return this; + } + + /** + * Specialization of {@link ScalarInArrayFunction} for null {@link #ARRAY_ARG}. + */ + private static final class WithNullArray extends ScalarInArrayFunction + { + private static final WithNullArray INSTANCE = new WithNullArray(); + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + return ExprEval.of(null); + } + } + + /** + * Specialization of {@link ScalarInArrayFunction} for constant, non-null {@link #ARRAY_ARG}. 
+ */ + private static final class WithConstantArray extends ScalarInArrayFunction + { + private final Set matchValues; + private final ExpressionType matchType; + + public WithConstantArray(Set matchValues, ExpressionType matchType) + { + this.matchValues = Preconditions.checkNotNull(matchValues, "matchValues"); + this.matchType = Preconditions.checkNotNull(matchType, "matchType"); + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval scalarEval = args.get(SCALAR_ARG).eval(bindings); + + if (scalarEval.value() == null) { + return matchValues.contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null); + } + + final ExprEval scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType); + + if (scalarEvalForComparison == null) { + return ExprEval.ofLongBoolean(false); + } else { + return ExprEval.ofLongBoolean(matchValues.contains(scalarEvalForComparison.value())); + } + } } } diff --git a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java index da81a556b0b7..d6143fd1fa15 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -373,12 +373,15 @@ public void testArrayOrdinalOf() public void testScalarInArray() { assertExpr("scalar_in_array(2, [1, 2, 3])", 1L); + assertExpr("scalar_in_array(2.1, [1, 2, 3])", 0L); + assertExpr("scalar_in_array(2, [1.1, 2.1, 3.1])", 0L); + assertExpr("scalar_in_array(2, [1.1, 2.0, 3.1])", 1L); assertExpr("scalar_in_array(4, [1, 2, 3])", 0L); assertExpr("scalar_in_array(b, [3, 4])", 0L); assertExpr("scalar_in_array(1, null)", null); assertExpr("scalar_in_array(null, null)", null); assertExpr("scalar_in_array(null, [1, null, 2])", 1L); - assertExpr("scalar_in_array(null, [1, 2])", 0L); + assertExpr("scalar_in_array(null, [1, 2])", null); } @Test @@ -1290,6 +1293,13 
@@ private void assertExpr( final Expr singleThreaded = Expr.singleThreaded(expr, bindings); Assert.assertEquals(singleThreaded.stringify(), expectedResult, singleThreaded.eval(bindings).value()); + final Expr singleThreadedNoFlatten = Expr.singleThreaded(exprNoFlatten, bindings); + Assert.assertEquals( + singleThreadedNoFlatten.stringify(), + expectedResult, + singleThreadedNoFlatten.eval(bindings).value() + ); + Assert.assertEquals(expr.stringify(), roundTrip.stringify()); Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify()); Assert.assertArrayEquals(expr.getCacheKey(), roundTrip.getCacheKey());