SCALAR_IN_ARRAY: Optimization and behavioral follow-ups. (#16311)
* Four changes to scalar_in_array as follow-ups to #16306:

1) Align behavior for `null` scalars to the behavior of the native `in` and `inType` filters: return `true` if the array itself contains null, else return `null`.

2) Rename the class to more closely match the function name.

3) Add a specialization for constant arrays, where we build a `HashSet`.

4) Use `castForEqualityComparison` to properly handle cross-type comparisons.
   Additional tests verify comparisons between LONG and DOUBLE are now
   handled properly.

* Fix spelling.

* Adjustments from review.
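
The null handling, the constant-array `HashSet` specialization, and the LONG/DOUBLE comparison described above can be sketched in plain Java without any Druid classes. This is a minimal illustration, not the code from this commit: the class `ScalarInArraySketch` and the helpers `castDoubleToLongExact` and `scalarInLongArray` are hypothetical names, the array is assumed to hold LONG values, and `null` stands in for SQL `UNKNOWN`.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class ScalarInArraySketch
{
  /**
   * Hypothetical stand-in for a lossless cast: returns null when the double
   * has no exact long equivalent (for example 2.1 or NaN). Values at the very
   * edge of the long range are not handled specially in this sketch.
   */
  public static Long castDoubleToLongExact(double d)
  {
    final long asLong = (long) d;
    return ((double) asLong == d) ? Long.valueOf(asLong) : null;
  }

  /**
   * Three-valued membership check over a LONG array.
   * Returns Boolean.TRUE, Boolean.FALSE, or null (standing in for UNKNOWN).
   */
  public static Boolean scalarInLongArray(Number scalar, List<Long> array)
  {
    if (array == null) {
      return null; // null array: result is unknown
    }

    // For a constant array this set could be built once and reused.
    final Set<Long> matchValues = new HashSet<>(array);

    if (scalar == null) {
      // null scalar: true only if the array itself contains null, else unknown
      return matchValues.contains(null) ? Boolean.TRUE : null;
    }

    // Cast the scalar to the array's element type before comparing.
    final Long comparable = (scalar instanceof Double)
                            ? castDoubleToLongExact(scalar.doubleValue())
                            : Long.valueOf(scalar.longValue());

    if (comparable == null) {
      return Boolean.FALSE; // no lossless cast exists, so no element can match
    }
    return matchValues.contains(comparable);
  }

  public static void main(String[] args)
  {
    final List<Long> longs = Arrays.asList(1L, 2L, 3L);
    System.out.println(scalarInLongArray(2L, longs));
    System.out.println(scalarInLongArray(2.1, longs));
    System.out.println(scalarInLongArray(null, Arrays.asList(1L, null, 2L)));
    System.out.println(scalarInLongArray(null, longs));
  }
}
```

Building the set once is what makes the constant-array case cheaper than scanning the array for every row; the sketch rebuilds it per call only to stay short.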
gianm authored Apr 26, 2024
1 parent 9cd1890 commit db82adc
Showing 5 changed files with 113 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/querying/math-expr.md
@@ -184,7 +184,7 @@ See javadoc of java.lang.Math for detailed explanation for each function.
 | array_ordinal(arr,long) | returns the array element at the 1 based index supplied, or null for an out of range index |
 | array_contains(arr,expr) | returns 1 if the array contains the element specified by expr, or contains all elements specified by expr if expr is an array, else 0 |
 | array_overlap(arr1,arr2) | returns 1 if arr1 and arr2 have any elements in common, else 0 |
-| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 |
+| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 if the expr is non-null, or null if the expr is null |
 | array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. |
 | array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. |
 | array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array |
6 changes: 3 additions & 3 deletions docs/querying/sql-array-functions.md
@@ -52,9 +52,9 @@ The following table describes array functions. To learn more about array aggrega
 |`ARRAY_LENGTH(arr)`|Returns length of the array expression.|
 |`ARRAY_OFFSET(arr, long)`|Returns the array element at the 0-based index supplied, or null for an out of range index.|
 |`ARRAY_ORDINAL(arr, long)`|Returns the array element at the 1-based index supplied, or null for an out of range index.|
-|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.|
-|`ARRAY_OVERLAP(arr1, arr2)`|Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
-| `SCALAR_IN_ARRAY(expr, arr)`|Returns 1 if the scalar `expr` is present in `arr`. else 0.|
+|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns true if `arr` contains all elements of `expr`. Otherwise returns false.|
+|`ARRAY_OVERLAP(arr1, arr2)`|Returns true if `arr1` and `arr2` have any elements in common, else false.|
+|`SCALAR_IN_ARRAY(expr, arr)`|Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or `UNKNOWN` if the scalar `expr` is `NULL`.|
 |`ARRAY_OFFSET_OF(arr, expr)`|Returns the 0-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).|
 |`ARRAY_ORDINAL_OF(arr, expr)`|Returns the 1-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).|
 |`ARRAY_PREPEND(expr, arr)`|Adds `expr` to the beginning of `arr`, the resulting array type determined by the type of `arr`.|
9 changes: 6 additions & 3 deletions docs/querying/sql-functions.md
@@ -156,7 +156,7 @@ Concatenates array inputs into a single array.
 
 **Function type:** [Array](./sql-array-functions.md)
 
-If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.
+If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns false.
 
 
 ## ARRAY_LENGTH
@@ -204,15 +204,18 @@ Returns the 1-based index of the first occurrence of `expr` in the array. If no
 
 **Function type:** [Array](./sql-array-functions.md)
 
-Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
+Returns true if `arr1` and `arr2` have any elements in common, else false.
 
 ## SCALAR_IN_ARRAY
 
 `SCALAR_IN_ARRAY(expr, arr)`
 
 **Function type:** [Array](./sql-array-functions.md)
 
-Returns 1 if the scalar `expr` is present in `arr`, else 0.|
+Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or
+`UNKNOWN` if the scalar `expr` is `NULL`.
+
+Returns `UNKNOWN` if `arr` is `NULL`.
 
 ## ARRAY_PREPEND
 
98 changes: 92 additions & 6 deletions processing/src/main/java/org/apache/druid/math/expr/Function.java
@@ -45,6 +45,7 @@
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -3724,8 +3725,11 @@ ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
     }
   }
 
-  class ArrayScalarInFunction extends ArrayScalarFunction
+  class ScalarInArrayFunction extends ArrayScalarFunction
   {
+    private static final int SCALAR_ARG = 0;
+    private static final int ARRAY_ARG = 1;
+
     @Override
     public String name()
     {
@@ -3742,23 +3746,105 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<E
     @Override
     Expr getScalarArgument(List<Expr> args)
     {
-      return args.get(0);
+      return args.get(SCALAR_ARG);
     }
 
     @Override
     Expr getArrayArgument(List<Expr> args)
     {
-      return args.get(1);
+      return args.get(ARRAY_ARG);
     }
 
     @Override
-    ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
+    ExprEval doApply(ExprEval arrayEval, ExprEval scalarEval)
     {
-      final Object[] array = arrayExpr.castTo(scalarExpr.asArrayType()).asArray();
+      final Object[] array = arrayEval.asArray();
       if (array == null) {
         return ExprEval.ofLong(null);
       }
-      return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarExpr.value()));
+
+      if (scalarEval.value() == null) {
+        return Arrays.asList(array).contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);
+      }
+
+      final ExpressionType matchType = arrayEval.elementType();
+      final ExprEval<?> scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType);
+
+      if (scalarEvalForComparison == null) {
+        return ExprEval.ofLongBoolean(false);
+      } else {
+        return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarEvalForComparison.value()));
+      }
     }
+
+    @Override
+    public Function asSingleThreaded(List<Expr> args, Expr.InputBindingInspector inspector)
+    {
+      if (args.get(ARRAY_ARG).isLiteral()) {
+        final ExpressionType lhsType = args.get(SCALAR_ARG).getOutputType(inspector);
+        if (lhsType == null) {
+          return this;
+        }
+
+        final ExprEval<?> arrayEval = args.get(ARRAY_ARG).eval(InputBindings.nilBindings());
+        final Object[] arrayValues = arrayEval.asArray();
+
+        if (arrayValues == null) {
+          return WithNullArray.INSTANCE;
+        } else {
+          final Set<Object> matchValues = new HashSet<>(Arrays.asList(arrayValues));
+          final ExpressionType matchType = arrayEval.elementType();
+          return new WithConstantArray(matchValues, matchType);
+        }
+      }
+      return this;
+    }
+
+    /**
+     * Specialization of {@link ScalarInArrayFunction} for null {@link #ARRAY_ARG}.
+     */
+    private static final class WithNullArray extends ScalarInArrayFunction
+    {
+      private static final WithNullArray INSTANCE = new WithNullArray();
+
+      @Override
+      public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
+      {
+        return ExprEval.of(null);
+      }
+    }
+
+    /**
+     * Specialization of {@link ScalarInArrayFunction} for constant, non-null {@link #ARRAY_ARG}.
+     */
+    private static final class WithConstantArray extends ScalarInArrayFunction
+    {
+      private final Set<Object> matchValues;
+      private final ExpressionType matchType;
+
+      public WithConstantArray(Set<Object> matchValues, ExpressionType matchType)
+      {
+        this.matchValues = Preconditions.checkNotNull(matchValues, "matchValues");
+        this.matchType = Preconditions.checkNotNull(matchType, "matchType");
+      }
+
+      @Override
+      public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
+      {
+        final ExprEval scalarEval = args.get(SCALAR_ARG).eval(bindings);
+
+        if (scalarEval.value() == null) {
+          return matchValues.contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);
+        }
+
+        final ExprEval<?> scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType);
+
+        if (scalarEvalForComparison == null) {
+          return ExprEval.ofLongBoolean(false);
+        } else {
+          return ExprEval.ofLongBoolean(matchValues.contains(scalarEvalForComparison.value()));
+        }
+      }
+    }
   }
 
@@ -373,12 +373,15 @@ public void testArrayOrdinalOf()
   public void testScalarInArray()
   {
     assertExpr("scalar_in_array(2, [1, 2, 3])", 1L);
+    assertExpr("scalar_in_array(2.1, [1, 2, 3])", 0L);
+    assertExpr("scalar_in_array(2, [1.1, 2.1, 3.1])", 0L);
+    assertExpr("scalar_in_array(2, [1.1, 2.0, 3.1])", 1L);
     assertExpr("scalar_in_array(4, [1, 2, 3])", 0L);
     assertExpr("scalar_in_array(b, [3, 4])", 0L);
     assertExpr("scalar_in_array(1, null)", null);
     assertExpr("scalar_in_array(null, null)", null);
     assertExpr("scalar_in_array(null, [1, null, 2])", 1L);
-    assertExpr("scalar_in_array(null, [1, 2])", 0L);
+    assertExpr("scalar_in_array(null, [1, 2])", null);
   }
 
   @Test
@@ -1290,6 +1293,13 @@ private void assertExpr(
     final Expr singleThreaded = Expr.singleThreaded(expr, bindings);
     Assert.assertEquals(singleThreaded.stringify(), expectedResult, singleThreaded.eval(bindings).value());
 
+    final Expr singleThreadedNoFlatten = Expr.singleThreaded(exprNoFlatten, bindings);
+    Assert.assertEquals(
+        singleThreadedNoFlatten.stringify(),
+        expectedResult,
+        singleThreadedNoFlatten.eval(bindings).value()
+    );
+
     Assert.assertEquals(expr.stringify(), roundTrip.stringify());
     Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify());
     Assert.assertArrayEquals(expr.getCacheKey(), roundTrip.getCacheKey());
