Skip to content

Commit

Permalink
various fixes and improvements to vectorization fallback (#17098) (#1…
Browse files Browse the repository at this point in the history
…7142)

changes:
* add `ApplyFunction` support to vectorization fallback, allowing many of the remaining expressions to be vectorized
* add `CastToObjectVectorProcessor` so that vector engine can correctly cast any type
* add support for array and complex vector constants
* reduce number of cases which can block vectorization in expression planner to be unknown inputs (such as unknown multi-valuedness)
* fix array constructor expression, apply map expression to make actual evaluated type match the output type inference
* fix bug in array_contains where something like array_contains([null], 'hello') would return true if the array was a numeric array since the non-null string value would cast to a null numeric
* fix isNull/isNotNull to correctly handle any type of input argument
  • Loading branch information
clintropolis authored Sep 24, 2024
1 parent 0ae9988 commit cf00b4c
Show file tree
Hide file tree
Showing 37 changed files with 487 additions and 388 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;

Expand Down Expand Up @@ -135,10 +136,28 @@ ExprEval applyMap(@Nullable ExpressionType arrayType, LambdaExpr expr, Indexable
{
  // Evaluates the lambda at every index of 'bindings' and collects the results into an array-typed ExprEval.
  // When 'arrayType' is supplied, each result is cast to its element type as soon as it is produced.
  // When 'arrayType' is null, every result is retained in 'outEval', the element type is widened across the
  // non-null results, and all results are cast in a second pass once that type is known.
  final int length = bindings.getLength();
  Object[] out = new Object[length];
  final boolean computeArrayType = arrayType == null;
  ExpressionType arrayElementType = arrayType != null
                                    ? (ExpressionType) arrayType.getElementType()
                                    : null;
  // only allocated when the type has to be discovered; never touched otherwise
  final ExprEval<?>[] outEval = computeArrayType ? new ExprEval[length] : null;
  for (int i = 0; i < length; i++) {

    // NOTE(review): the next two statements look like the pre-change implementation rendered next to its
    // replacement; they are left untouched here, confirm whether they belong in the final method.
    ExprEval evaluated = expr.eval(bindings.withIndex(i));
    arrayType = Function.ArrayConstructorFunction.setArrayOutput(arrayType, out, i, evaluated);
    final ExprEval<?> eval = expr.eval(bindings.withIndex(i));
    if (computeArrayType) {
      // Inspect the freshly evaluated result: outEval[i] has not been assigned yet at this point, so reading it
      // here would throw a NullPointerException. Only non-null results contribute to the element type.
      if (eval.value() != null) {
        arrayElementType = ExpressionTypeConversion.leastRestrictiveType(arrayElementType, eval.type());
      }
      // retain every result, null-valued ones included, because the cast pass below reads outEval[i] for all i
      outEval[i] = eval;
    } else {
      out[i] = eval.castTo(arrayElementType).value();
    }
  }
  if (arrayElementType == null) {
    // no result determined an element type: LONG under SQL-compatible null handling, STRING otherwise
    arrayElementType = NullHandling.sqlCompatible() ? ExpressionType.LONG : ExpressionType.STRING;
  }
  if (computeArrayType) {
    arrayType = ExpressionTypeFactory.getInstance().ofArray(arrayElementType);
    for (int i = 0; i < length; i++) {
      out[i] = outEval[i].castTo(arrayElementType).value();
    }
  }
  return ExprEval.ofArray(arrayType, out);
}
Expand Down Expand Up @@ -237,7 +256,7 @@ public ExprEval apply(LambdaExpr lambdaExpr, List<Expr> argsExpr, Expr.ObjectBin
List<List<Object>> product = CartesianList.create(arrayInputs);
CartesianMapLambdaBinding lambdaBinding = new CartesianMapLambdaBinding(elementType, product, lambdaExpr, bindings);
ExpressionType lambdaType = lambdaExpr.getOutputType(lambdaBinding);
return applyMap(ExpressionType.asArrayType(lambdaType), lambdaExpr, lambdaBinding);
return applyMap(lambdaType == null ? null : ExpressionTypeFactory.getInstance().ofArray(lambdaType), lambdaExpr, lambdaBinding);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ public Expr asSingleThreaded(InputBindingInspector inspector)
return new ExprEvalBasedConstantExpr<T>(realEval());
}

/**
 * Builds a constant vector processor for this expression's {@code value} and {@code outputType}, sized to the
 * maximum vector size reported by the inspector.
 */
@Override
public <E> ExprVectorProcessor<E> asVectorProcessor(VectorInputBindingInspector inspector)
{
  final int maxVectorSize = inspector.getMaxVectorSize();
  return VectorProcessors.constant(value, maxVectorSize, outputType);
}
/**
* Constant expression based on a concrete ExprEval.
*
Expand Down Expand Up @@ -415,7 +420,7 @@ protected ExprEval realEval()
/**
 * Builds a constant vector processor for this expression's {@code value}, sized to the maximum vector size
 * reported by the inspector.
 */
@Override
public <T> ExprVectorProcessor<T> asVectorProcessor(VectorInputBindingInspector inspector)
{
// NOTE(review): two return statements appear back to back; this looks like the old two-argument call shown next
// to its replacement, which additionally passes ExpressionType.STRING - confirm which one is current.
return VectorProcessors.constant(value, inspector.getMaxVectorSize());
return VectorProcessors.constant(value, inspector.getMaxVectorSize(), ExpressionType.STRING);
}

@Override
Expand Down Expand Up @@ -459,12 +464,6 @@ protected ExprEval realEval()
return ExprEval.ofArray(outputType, value);
}

/**
 * Unconditionally reports that this expression cannot be vectorized; the inspector is not consulted.
 */
@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return false;
}

@Override
public String stringify()
{
Expand Down Expand Up @@ -547,12 +546,6 @@ protected ExprEval realEval()
return ExprEval.ofComplex(outputType, value);
}

/**
 * Unconditionally reports that this expression cannot be vectorized; the inspector is not consulted.
 */
@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return false;
}

@Override
public String stringify()
{
Expand Down
49 changes: 18 additions & 31 deletions processing/src/main/java/org/apache/druid/math/expr/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -2026,7 +2026,8 @@ public <T> ExprVectorProcessor<T> asVectorProcessor(Expr.VectorInputBindingInspe
{
return CastToTypeVectorProcessor.cast(
args.get(0).asVectorProcessor(inspector),
ExpressionType.fromString(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString()))
ExpressionType.fromString(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())),
inspector.getMaxVectorSize()
);
}
}
Expand Down Expand Up @@ -3357,19 +3358,24 @@ public String name()
/**
 * Evaluates every argument against {@code bindings} and returns the results as a single array-typed
 * {@link ExprEval}.
 *
 * NOTE(review): this body appears to show two implementations interleaved (one using {@code setArrayOutput} and
 * {@code arrayType}, one using {@code outEval} and {@code arrayElementType}), including an early return followed by
 * further statements - confirm which statements belong to the current version.
 */
@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
// this is copied from 'BaseMapFunction.applyMap', need to find a better way to consolidate, or construct arrays,
// or.. something...
final int length = args.size();
Object[] out = new Object[length];

ExpressionType arrayType = null;

ExpressionType arrayElementType = null;
// keeps each argument's evaluation so it can be cast after the element type has been settled
final ExprEval[] outEval = new ExprEval[length];
for (int i = 0; i < length; i++) {
ExprEval<?> evaluated = args.get(i).eval(bindings);
arrayType = setArrayOutput(arrayType, out, i, evaluated);
outEval[i] = args.get(i).eval(bindings);
// only non-null values take part in widening the element type
if (outEval[i].value() != null) {
arrayElementType = ExpressionTypeConversion.leastRestrictiveType(arrayElementType, outEval[i].type());
}
}

return ExprEval.ofArray(arrayType, out);
// no non-null value determined a type: LONG under SQL-compatible null handling, STRING otherwise
if (arrayElementType == null) {
arrayElementType = NullHandling.sqlCompatible() ? ExpressionType.LONG : ExpressionType.STRING;
}
// second pass: cast every retained evaluation to the settled element type
for (int i = 0; i < length; i++) {
out[i] = outEval[i].castTo(arrayElementType).value();
}
return ExprEval.ofArray(ExpressionTypeFactory.getInstance().ofArray(arrayElementType), out);
}

@Override
Expand All @@ -3394,28 +3400,6 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<E
}
return type == null ? null : ExpressionTypeFactory.getInstance().ofArray(type);
}

/**
 * Writes one evaluated element into slot {@code i} of {@code out} and returns the array type in effect.
 *
 * If {@code arrayType} is null it is derived from the type of {@code evaluated}; otherwise it is returned
 * unchanged. A numeric-null evaluation stored into a numeric array becomes {@code null}. An evaluation whose array
 * type differs from the array type in effect is converted with {@link ExprEval#castTo} to the array's element type;
 * any other evaluation is stored as is.
 */
static ExpressionType setArrayOutput(@Nullable ExpressionType arrayType, Object[] out, int i, ExprEval evaluated)
{
  final ExpressionType effectiveType =
      arrayType != null ? arrayType : ExpressionTypeFactory.getInstance().ofArray(evaluated.type());

  final Object element;
  if (effectiveType.getElementType().isNumeric() && evaluated.isNumericNull()) {
    element = null;
  } else if (evaluated.asArrayType().equals(effectiveType)) {
    // already of the right type, keep the value untouched
    element = evaluated.value();
  } else {
    element = evaluated.castTo((ExpressionType) effectiveType.getElementType()).value();
  }
  out[i] = element;
  return effectiveType;
}
}

class ArrayLengthFunction implements Function
Expand Down Expand Up @@ -3954,6 +3938,9 @@ public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
return ExprEval.ofLongBoolean(Arrays.asList(array1).containsAll(Arrays.asList(array2)));
} else {
final Object elem = rhsExpr.castTo((ExpressionType) array1Type.getElementType()).value();
if (elem == null && rhsExpr.value() != null) {
return ExprEval.ofLongBoolean(false);
}
return ExprEval.ofLongBoolean(Arrays.asList(array1).contains(elem));
}
}
Expand Down
141 changes: 13 additions & 128 deletions processing/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

package org.apache.druid.math.expr;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.StringUtils;
Expand All @@ -30,136 +29,13 @@
import javax.annotation.Nullable;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Empty placeholder type that declares no members; see the comment in its body for why it exists.
 */
@SuppressWarnings("unused")
final class FunctionalExpr
{
// phony class to enable maven to track the compilation of this class
}

/**
 * {@link Expr} node for a lambda: an immutable list of argument identifiers together with a body expression.
 * Evaluation, vectorization and output type inference are all delegated to the body.
 */
@SuppressWarnings("ClassName")
class LambdaExpr implements Expr
{
  private final ImmutableList<IdentifierExpr> args;
  private final Expr expr;

  LambdaExpr(List<IdentifierExpr> args, Expr expr)
  {
    this.args = ImmutableList.copyOf(args);
    this.expr = expr;
  }

  @Override
  public String toString()
  {
    return StringUtils.format("(%s -> %s)", args, expr);
  }

  /**
   * Number of argument identifiers declared by this lambda.
   */
  int identifierCount()
  {
    return args.size();
  }

  /**
   * Returns the single argument identifier, or null when the lambda declares none. Fails the state check when the
   * lambda declares two or more arguments.
   */
  @Nullable
  public String getIdentifier()
  {
    final int argCount = args.size();
    Preconditions.checkState(argCount < 2, "LambdaExpr has multiple arguments, use getIdentifiers");
    return argCount == 1 ? args.get(0).toString() : null;
  }

  /**
   * The {@code toString()} form of every argument identifier, in declaration order.
   */
  public List<String> getIdentifiers()
  {
    return args.stream().map(arg -> arg.toString()).collect(Collectors.toList());
  }

  /**
   * The {@code stringify()} form of every argument identifier, in declaration order.
   */
  public List<String> stringifyIdentifiers()
  {
    return args.stream().map(arg -> arg.stringify()).collect(Collectors.toList());
  }

  ImmutableList<IdentifierExpr> getIdentifierExprs()
  {
    return args;
  }

  public Expr getExpr()
  {
    return expr;
  }

  @Override
  public boolean canVectorize(InputBindingInspector inspector)
  {
    // a lambda is vectorizable exactly when its body is
    return expr.canVectorize(inspector);
  }

  @Override
  public <T> ExprVectorProcessor<T> asVectorProcessor(VectorInputBindingInspector inspector)
  {
    return expr.asVectorProcessor(inspector);
  }

  @Override
  public ExprEval eval(ObjectBinding bindings)
  {
    return expr.eval(bindings);
  }

  @Override
  public String stringify()
  {
    final String joinedArgs = ARG_JOINER.join(stringifyIdentifiers());
    return StringUtils.format("(%s) -> %s", joinedArgs, expr.stringify());
  }

  @Override
  public Expr visit(Shuttle shuttle)
  {
    // arguments are visited before the body, then the rebuilt lambda itself is handed to the shuttle
    final List<IdentifierExpr> visitedArgs = args.stream()
                                                 .map(arg -> (IdentifierExpr) shuttle.visit(arg))
                                                 .collect(Collectors.toList());
    final Expr visitedBody = expr.visit(shuttle);
    final LambdaExpr rebuilt = new LambdaExpr(visitedArgs, visitedBody);
    return shuttle.visit(rebuilt);
  }

  @Override
  public BindingAnalysis analyzeInputs()
  {
    // the lambda's own argument names are removed from the body's analysis
    final Set<String> argNames = args.stream().map(IdentifierExpr::toString).collect(Collectors.toSet());
    return expr.analyzeInputs().removeLambdaArguments(argNames);
  }

  @Override
  public ExpressionType getOutputType(InputBindingInspector inspector)
  {
    return expr.getOutputType(inspector);
  }

  @Override
  public boolean equals(Object o)
  {
    if (o == this) {
      return true;
    }
    if (o == null || o.getClass() != getClass()) {
      return false;
    }
    final LambdaExpr other = (LambdaExpr) o;
    if (!Objects.equals(args, other.args)) {
      return false;
    }
    return Objects.equals(expr, other.expr);
  }

  @Override
  public int hashCode()
  {
    return Objects.hash(args, expr);
  }
}

/**
* {@link Expr} node for a {@link Function} call. {@link FunctionExpr} has children {@link Expr} in the form of the
* list of arguments that are passed to the {@link Function} along with the {@link Expr.ObjectBinding} when it is
Expand Down Expand Up @@ -350,15 +226,24 @@ public ExprEval eval(ObjectBinding bindings)
/**
 * Reports whether this apply-function expression can be vectorized for the given inspector.
 *
 * NOTE(review): two return statements appear back to back; the first looks like the earlier all-native check and
 * the second like its replacement, which also accepts the case where the output type is known and the inspector
 * can vectorize the arguments - confirm which one is current.
 */
@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return function.canVectorize(inspector, lambdaExpr, argsExpr) &&
lambdaExpr.canVectorize(inspector) &&
argsExpr.stream().allMatch(expr -> expr.canVectorize(inspector));
return canVectorizeNative(inspector) || (getOutputType(inspector) != null && inspector.canVectorize(argsExpr));
}

/**
 * Builds the vector processor for this apply-function expression.
 *
 * NOTE(review): an unconditional return precedes the if/else; it looks like the earlier implementation shown next
 * to its replacement - confirm which statements are current.
 */
@Override
public <T> ExprVectorProcessor<T> asVectorProcessor(VectorInputBindingInspector inspector)
{
return function.asVectorProcessor(inspector, lambdaExpr, argsExpr);
// use the function's own processor when native vectorization is possible, otherwise the fallback processor
if (canVectorizeNative(inspector)) {
return function.asVectorProcessor(inspector, lambdaExpr, argsExpr);
} else {
return FallbackVectorProcessor.create(function, lambdaExpr, argsExpr, inspector);
}
}

/**
 * True only when the function itself, the lambda, and every argument expression all report that they can be
 * vectorized for the given inspector. The checks run in that order and stop at the first failure.
 */
private boolean canVectorizeNative(InputBindingInspector inspector)
{
  if (!function.canVectorize(inspector, lambdaExpr, argsExpr)) {
    return false;
  }
  if (!lambdaExpr.canVectorize(inspector)) {
    return false;
  }
  for (Expr arg : argsExpr) {
    if (!arg.canVectorize(inspector)) {
      return false;
    }
  }
  return true;
}

@Override
Expand Down
Loading

0 comments on commit cf00b4c

Please sign in to comment.