From 823f620edea2c6bc75703358a6a5ae086c6aa04e Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Wed, 20 Sep 2023 10:44:32 -0700 Subject: [PATCH] Add IS [NOT] DISTINCT FROM to SQL and join matchers. (#14976) * Add IS [NOT] DISTINCT FROM to SQL and join matchers. Changes: 1) Add "isdistinctfrom" and "notdistinctfrom" native expressions. 2) Add "IS [NOT] DISTINCT FROM" to SQL. It uses the new native expressions when generating expressions, and is treated the same as equals and not-equals when generating native filters on literals. 3) Update join matchers to have an "includeNull" parameter that determines whether we are operating in "equals" mode or "is not distinct from" mode. * Main changes: - Add ARRAY handling to "notdistinctfrom" and "isdistinctfrom". - Include null in pushed-down filters when using "notdistinctfrom" in a join. Other changes: - Adjust join filter analyzer to more explicitly use InDimFilter's ValuesSets, relying less on remembering to get it right to avoid copies. * Remove unused "wrap" method. * Fixes. * Remove methods we do not need. * Fix bug with INPUT_REF. --- docs/querying/sql-operators.md | 2 + .../druid/msq/querykit/DataSourcePlan.java | 20 +- .../org/apache/druid/math/expr/Exprs.java | 50 ++++- .../org/apache/druid/math/expr/Function.java | 103 ++++++++++ .../druid/query/filter/InDimFilter.java | 34 +++- .../apache/druid/segment/join/Equality.java | 14 +- .../segment/join/JoinConditionAnalysis.java | 29 +-- .../apache/druid/segment/join/Joinable.java | 20 +- .../segment/join/JoinableFactoryWrapper.java | 4 +- .../join/filter/JoinFilterAnalyzer.java | 4 +- .../JoinFilterColumnCorrelationAnalysis.java | 5 +- .../join/filter/JoinFilterCorrelations.java | 5 +- .../segment/join/lookup/LookupJoinable.java | 41 ++-- .../segment/join/table/IndexedTable.java | 11 +- .../join/table/IndexedTableJoinMatcher.java | 52 ++++- .../join/table/IndexedTableJoinable.java | 31 +-- .../druid/segment/join/table/MapIndex.java | 73 +++++-- .../join/table/RowBasedIndexBuilder.java | 41 ++-- .../join/table/UniqueLongArrayIndex.java | 17 +- .../org/apache/druid/math/expr/EvalTest.java | 104 ++++++++++ .../org/apache/druid/math/expr/ExprsTest.java | 50 +++-- .../druid/query/filter/InDimFilterTest.java | 2 +- ...BaseHashJoinSegmentStorageAdapterTest.java | 19 ++ .../HashJoinSegmentStorageAdapterTest.java | 65 ++++++- .../segment/join/JoinFilterAnalyzerTest.java | 3 +- .../join/lookup/LookupJoinableTest.java | 54 ++++-- .../BroadcastSegmentIndexedTableTest.java | 12 +- .../table/FrameBasedIndexedTableTest.java | 12 +- .../table/IndexedTableJoinMatcherTest.java | 34 ++-- .../join/table/IndexedTableJoinableTest.java | 110 +++++++---- .../join/table/RowBasedIndexBuilderTest.java | 118 ++++++++++-- .../join/table/RowBasedIndexedTableTest.java | 2 +- .../NestedFieldColumnIndexSupplierTest.java | 16 +- .../sql/calcite/expression/Expressions.java | 10 + .../ArrayOverlapOperatorConversion.java | 9 +- .../calcite/planner/DruidOperatorTable.java | 2 + .../druid/sql/calcite/rule/DruidJoinRule.java | 102 ++++++---- .../sql/calcite/CalciteJoinQueryTest.java | 144 ++++++++++++++ .../druid/sql/calcite/CalciteQueryTest.java | 180 ++++++++++++++++-- .../DecoupledPlanningCalciteQueryTest.java | 2 +- 40 files changed, 1288 insertions(+), 318 deletions(-) diff --git a/docs/querying/sql-operators.md b/docs/querying/sql-operators.md index 81c441c03367..da295821ecf9 100644 --- a/docs/querying/sql-operators.md +++ b/docs/querying/sql-operators.md @@ -79,7 +79,9 @@ Also see the [CONCAT function](sql-scalar.md#string-functions). |Operator|Description| |--------|-----------| |`x = y` |Equal to| +|`x IS NOT DISTINCT FROM y`|Equal to, considering `NULL` as a value. Never returns `NULL`.| |`x <> y`|Not equal to| +|`x IS DISTINCT FROM y`|Not equal to, considering `NULL` as a value. Never returns `NULL`.| |`x > y` |Greater than| |`x >= y`|Greater than or equal to| |`x < y` |Less than| diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index 477c3e0e1982..a826f1928e6b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -218,7 +218,7 @@ private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgo JoinAlgorithm deducedJoinAlgorithm; if (JoinAlgorithm.BROADCAST.equals(preferredJoinAlgorithm)) { deducedJoinAlgorithm = JoinAlgorithm.BROADCAST; - } else if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) { + } else if (canUseSortMergeJoin(joinDataSource.getConditionAnalysis())) { deducedJoinAlgorithm = JoinAlgorithm.SORT_MERGE; } else { deducedJoinAlgorithm = JoinAlgorithm.BROADCAST; @@ -237,15 +237,21 @@ private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgo } /** - * Checks if the join condition on two tables "table1" and "table2" is of the form + * Checks if the sortMerge algorithm can execute a particular join condition. + * + * Two checks: + * (1) join condition on two tables "table1" and "table2" is of the form * table1.columnA = table2.columnA && table1.columnB = table2.columnB && .... - * sortMerge algorithm can help these types of join conditions + * + * (2) join condition uses equals, not IS NOT DISTINCT FROM [sortMerge processor does not currently implement + * IS NOT DISTINCT FROM] */ - private static boolean isConditionEqualityOnLeftAndRightColumns(JoinConditionAnalysis joinConditionAnalysis) + private static boolean canUseSortMergeJoin(JoinConditionAnalysis joinConditionAnalysis) { - return joinConditionAnalysis.getEquiConditions() - .stream() - .allMatch(equality -> equality.getLeftExpr().isIdentifier()); + return joinConditionAnalysis + .getEquiConditions() + .stream() + .allMatch(equality -> equality.getLeftExpr().isIdentifier() && !equality.isIncludeNull()); } /** diff --git a/processing/src/main/java/org/apache/druid/math/expr/Exprs.java b/processing/src/main/java/org/apache/druid/math/expr/Exprs.java index e8ad020fe700..72ad9fabf825 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Exprs.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Exprs.java @@ -19,11 +19,13 @@ package org.apache.druid.math.expr; -import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.UOE; +import org.apache.druid.segment.join.Equality; +import org.apache.druid.segment.join.JoinPrefixUtils; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.Stack; @@ -79,16 +81,56 @@ public static List decomposeAnd(final Expr expr) } /** - * Decomposes an equality expr into the left- and right-hand side. + * Decomposes an equality expr into an {@link Equality}. Used by join-related code to identify equi-joins. * * @return decomposed equality, or empty if the input expr was not an equality expr */ - public static Optional> decomposeEquals(final Expr expr) + public static Optional decomposeEquals(final Expr expr, final String rightPrefix) { + final Expr lhs; + final Expr rhs; + final boolean includeNull; + if (expr instanceof BinEqExpr) { - return Optional.of(Pair.of(((BinEqExpr) expr).left, ((BinEqExpr) expr).right)); + lhs = ((BinEqExpr) expr).left; + rhs = ((BinEqExpr) expr).right; + includeNull = false; + } else if (expr instanceof FunctionExpr + && ((FunctionExpr) expr).function instanceof Function.IsNotDistinctFromFunc) { + final List args = ((FunctionExpr) expr).args; + lhs = args.get(0); + rhs = args.get(1); + includeNull = true; + } else { + return Optional.empty(); + } + + if (isLeftExprAndRightColumn(lhs, rhs, rightPrefix)) { + // rhs is a right-hand column; lhs is an expression solely of the left-hand side. + return Optional.of( + new Equality( + lhs, + Objects.requireNonNull(rhs.getBindingIfIdentifier()).substring(rightPrefix.length()), + includeNull + ) + ); + } else if (isLeftExprAndRightColumn(rhs, lhs, rightPrefix)) { + return Optional.of( + new Equality( + rhs, + Objects.requireNonNull(lhs.getBindingIfIdentifier()).substring(rightPrefix.length()), + includeNull + ) + ); } else { return Optional.empty(); } } + + private static boolean isLeftExprAndRightColumn(final Expr a, final Expr b, final String rightPrefix) + { + return a.analyzeInputs().getRequiredBindings().stream().noneMatch(c -> JoinPrefixUtils.isPrefixedBy(c, rightPrefix)) + && b.getBindingIfIdentifier() != null + && JoinPrefixUtils.isPrefixedBy(b.getBindingIfIdentifier(), rightPrefix); + } } diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java index 406ffac1ea7d..cabeb557792c 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Function.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java @@ -19,6 +19,7 @@ package org.apache.druid.math.expr; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; @@ -2225,6 +2226,108 @@ public ExprVectorProcessor asVectorProcessor(Expr.VectorInputBindingInspe } } + /** + * SQL function "x IS NOT DISTINCT FROM y". Very similar to "x = y", i.e. {@link BinEqExpr}, except this function + * never returns null, and this function considers NULL as a value, so NULL itself is not-distinct-from NULL. For + * example: `x == null` returns `null` in SQL-compatible null handling mode, but `notdistinctfrom(x, null)` is + * true if `x` is null. + */ + class IsNotDistinctFromFunc implements Function + { + @Override + public String name() + { + return "notdistinctfrom"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval leftVal = args.get(0).eval(bindings); + final ExprEval rightVal = args.get(1).eval(bindings); + + if (leftVal.value() == null || rightVal.value() == null) { + return ExprEval.ofLongBoolean(leftVal.value() == null && rightVal.value() == null); + } + + // Code copied and adapted from BinaryBooleanOpExprBase and BinEqExpr. + // The code isn't shared due to differences in code structure: BinaryBooleanOpExprBase + BinEqExpr have logic + // interleaved between parent and child class, but we can't use BinaryBooleanOpExprBase as a parent here, because + // (a) this is a function, not an expr; and (b) our logic for handling and returning nulls is different from most + // binary exprs, where null in means null out. + final ExpressionType comparisonType = ExpressionTypeConversion.autoDetect(leftVal, rightVal); + switch (comparisonType.getType()) { + case STRING: + return ExprEval.ofLongBoolean(Objects.equals(leftVal.asString(), rightVal.asString())); + case LONG: + return ExprEval.ofLongBoolean(leftVal.asLong() == rightVal.asLong()); + case ARRAY: + final ExpressionType type = Preconditions.checkNotNull( + ExpressionTypeConversion.leastRestrictiveType(leftVal.type(), rightVal.type()), + "Cannot be null because ExprEval type is not nullable" + ); + return ExprEval.ofLongBoolean( + type.getNullableStrategy().compare(leftVal.castTo(type).asArray(), rightVal.castTo(type).asArray()) == 0 + ); + case DOUBLE: + default: + if (leftVal.isNumericNull() || rightVal.isNumericNull()) { + return ExprEval.ofLongBoolean(leftVal.isNumericNull() && rightVal.isNumericNull()); + } else { + return ExprEval.ofLongBoolean(leftVal.asDouble() == rightVal.asDouble()); + } + } + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 2); + } + + @Nullable + @Override + public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) + { + return ExpressionType.LONG; + } + } + + /** + * SQL function "x IS DISTINCT FROM y". Very similar to "x <> y", i.e. {@link BinNeqExpr}, except this function + * never returns null. + * + * Implemented as a subclass of IsNotDistinctFromFunc to keep the code simple, and because we expect "notdistinctfrom" + * to be more common than "isdistinctfrom" in actual usage. + */ + class IsDistinctFromFunc extends IsNotDistinctFromFunc + { + @Override + public String name() + { + return "isdistinctfrom"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + return ExprEval.ofLongBoolean(!super.apply(args, bindings).asBoolean()); + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 2); + } + + @Nullable + @Override + public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) + { + return ExpressionType.LONG; + } + } + /** * SQL function "IS NOT FALSE". Different from "IS TRUE" in that it returns true for NULL as well. */ diff --git a/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java b/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java index 47c3d78a237b..fcbd6aa49605 100644 --- a/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java +++ b/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java @@ -659,9 +659,9 @@ public ValuesSet() /** * Create a ValuesSet from another Collection. The Collection will be reused if it is a {@link SortedSet} with - * an appropriate comparator. + * the {@link Comparators#naturalNullsFirst()} comparator. */ - public ValuesSet(final Collection values) + private ValuesSet(final Collection values) { if (values instanceof SortedSet && Comparators.naturalNullsFirst() .equals(((SortedSet) values).comparator())) { @@ -672,6 +672,36 @@ public ValuesSet(final Collection values) } } + /** + * Creates an empty ValuesSet. + */ + public static ValuesSet create() + { + return new ValuesSet(new TreeSet<>(Comparators.naturalNullsFirst())); + } + + /** + * Creates a ValuesSet wrapping the provided single value. + * + * @throws IllegalStateException if the provided collection cannot be wrapped since it has the wrong comparator + */ + public static ValuesSet of(@Nullable final String value) + { + final ValuesSet retVal = ValuesSet.create(); + retVal.add(value); + return retVal; + } + + /** + * Creates a ValuesSet copying the provided collection. + */ + public static ValuesSet copyOf(final Collection values) + { + final TreeSet copyOfValues = new TreeSet<>(Comparators.naturalNullsFirst()); + copyOfValues.addAll(values); + return new ValuesSet(copyOfValues); + } + public SortedSet toUtf8() { final TreeSet valuesUtf8 = new TreeSet<>(ByteBufferUtils.utf8Comparator()); diff --git a/processing/src/main/java/org/apache/druid/segment/join/Equality.java b/processing/src/main/java/org/apache/druid/segment/join/Equality.java index 6b839c1f0dc8..3e1e4ea31492 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/Equality.java +++ b/processing/src/main/java/org/apache/druid/segment/join/Equality.java @@ -32,11 +32,13 @@ public class Equality { private final Expr leftExpr; private final String rightColumn; + private final boolean includeNull; - public Equality(final Expr leftExpr, final String rightColumn) + public Equality(final Expr leftExpr, final String rightColumn, final boolean includeNull) { this.leftExpr = leftExpr; this.rightColumn = rightColumn; + this.includeNull = includeNull; } public Expr getLeftExpr() @@ -49,12 +51,22 @@ public String getRightColumn() return rightColumn; } + /** + * Whether null is treated as a value that can be equal to itself. True for conditions using "IS NOT DISTINCT FROM", + * false for conditions using regular equals. + */ + public boolean isIncludeNull() + { + return includeNull; + } + @Override public String toString() { return "Equality{" + "leftExpr=" + leftExpr + ", rightColumn='" + rightColumn + '\'' + + ", includeNull=" + includeNull + '}'; } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java index 2a33da22d131..77d474720201 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java +++ b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java @@ -20,7 +20,6 @@ package org.apache.druid.segment.join; import com.google.common.base.Preconditions; -import org.apache.druid.java.util.common.Pair; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.Exprs; @@ -121,40 +120,18 @@ public static JoinConditionAnalysis forExpression( final List exprs = Exprs.decomposeAnd(conditionExpr); for (Expr childExpr : exprs) { - final Optional> maybeDecomposed = Exprs.decomposeEquals(childExpr); + final Optional maybeEquality = Exprs.decomposeEquals(childExpr, rightPrefix); - if (!maybeDecomposed.isPresent()) { + if (!maybeEquality.isPresent()) { nonEquiConditions.add(childExpr); } else { - final Pair decomposed = maybeDecomposed.get(); - final Expr lhs = Objects.requireNonNull(decomposed.lhs); - final Expr rhs = Objects.requireNonNull(decomposed.rhs); - - if (isLeftExprAndRightColumn(lhs, rhs, rightPrefix)) { - // rhs is a right-hand column; lhs is an expression solely of the left-hand side. - equiConditions.add( - new Equality(lhs, Objects.requireNonNull(rhs.getBindingIfIdentifier()).substring(rightPrefix.length())) - ); - } else if (isLeftExprAndRightColumn(rhs, lhs, rightPrefix)) { - equiConditions.add( - new Equality(rhs, Objects.requireNonNull(lhs.getBindingIfIdentifier()).substring(rightPrefix.length())) - ); - } else { - nonEquiConditions.add(childExpr); - } + equiConditions.add(maybeEquality.get()); } } return new JoinConditionAnalysis(condition, rightPrefix, equiConditions, nonEquiConditions); } - private static boolean isLeftExprAndRightColumn(final Expr a, final Expr b, final String rightPrefix) - { - return a.analyzeInputs().getRequiredBindings().stream().noneMatch(c -> JoinPrefixUtils.isPrefixedBy(c, rightPrefix)) - && b.getBindingIfIdentifier() != null - && JoinPrefixUtils.isPrefixedBy(b.getBindingIfIdentifier(), rightPrefix); - } - /** * Return the condition expression. */ diff --git a/processing/src/main/java/org/apache/druid/segment/join/Joinable.java b/processing/src/main/java/org/apache/druid/segment/join/Joinable.java index b4f35fbc4b31..dab102d44932 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/Joinable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/Joinable.java @@ -20,6 +20,7 @@ package org.apache.druid.segment.join; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ReferenceCountedObject; import org.apache.druid.segment.column.ColumnCapabilities; @@ -86,26 +87,29 @@ JoinMatcher makeJoinMatcher( ); /** - * Returns all non-null values from a particular column along with a flag to tell if they are all unique in the column. - * If the non-null values are greater than "maxNumValues" or if the column doesn't exists or doesn't supports this + * Returns all matchable values from a particular column along with a flag to tell if they are all unique in the column. + * If the matchable values are greater than "maxNumValues" or if the column doesn't exists or doesn't supports this * operation, returns an object with empty set for column values and false for uniqueness flag. - * The uniqueness flag will only be true if we've collected all non-null values in the column and found that they're + * The uniqueness flag will only be true if we've collected all matchable values in the column and found that they're * all unique. In all other cases it will be false. * - * The returned set may be passed to {@link org.apache.druid.query.filter.InDimFilter}. For efficiency, + * The returned set may be passed to {@link InDimFilter}. For efficiency, * implementations should prefer creating the returned set with * {@code new TreeSet(Comparators.naturalNullsFirst()}}. This avoids a copy in the filter's constructor. * * @param columnName name of the column - * @param maxNumValues maximum number of values to return + * @param includeNull whether null should be considered a matchable value. If true, this method returns all values + * that are present in the column. If false, this method returns all non-null values. + * @param maxNumValues maximum number of values to return. If exceeded, returns an empty set with the "allUnique" + * flag set to false. */ - ColumnValuesWithUniqueFlag getNonNullColumnValues(String columnName, int maxNumValues); + ColumnValuesWithUniqueFlag getMatchableColumnValues(String columnName, boolean includeNull, int maxNumValues); /** * Searches a column from this Joinable for a particular value, finds rows that match, * and returns values of a second column for those rows. * - * The returned set may be passed to {@link org.apache.druid.query.filter.InDimFilter}. For efficiency, + * The returned set may be passed to {@link InDimFilter}. For efficiency, * implementations should prefer creating the returned set with * {@code new TreeSet(Comparators.naturalNullsFirst()}}. This avoids a copy in the filter's constructor. * @@ -121,7 +125,7 @@ JoinMatcher makeJoinMatcher( * * In case either the search or retrieval column names are not found, this will return absent. */ - Optional> getCorrelatedColumnValues( + Optional getCorrelatedColumnValues( String searchColumnName, String searchColumnValue, String retrievalColumnName, diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java index df4d14cf621c..0bebc5aab592 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java +++ b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java @@ -170,7 +170,9 @@ static JoinClauseToFilterConversion convertJoinToFilter( } Joinable.ColumnValuesWithUniqueFlag columnValuesWithUniqueFlag = - clause.getJoinable().getNonNullColumnValues(condition.getRightColumn(), maxNumFilterValues); + clause.getJoinable() + .getMatchableColumnValues(condition.getRightColumn(), condition.isIncludeNull(), maxNumFilterValues); + // For an empty values set, isAllUnique flag will be true only if the column had no non-null values. if (columnValuesWithUniqueFlag.getColumnValues().isEmpty()) { if (columnValuesWithUniqueFlag.isAllUnique()) { diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java index ddbaa34d4061..a4c06e79826c 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java @@ -452,7 +452,7 @@ private static JoinFilterAnalysis rewriteSelectorFilter( for (JoinFilterColumnCorrelationAnalysis correlationAnalysis : correlationAnalyses) { if (correlationAnalysis.supportsPushDown()) { - Optional> correlatedValues = correlationAnalysis.getCorrelatedValuesMap().get( + Optional correlatedValues = correlationAnalysis.getCorrelatedValuesMap().get( Pair.of(filteringColumn, filteringValue) ); @@ -460,7 +460,7 @@ private static JoinFilterAnalysis rewriteSelectorFilter( return JoinFilterAnalysis.createNoPushdownFilterAnalysis(selectorFilter); } - Set newFilterValues = correlatedValues.get(); + InDimFilter.ValuesSet newFilterValues = correlatedValues.get(); // in nothing => match nothing if (newFilterValues.isEmpty()) { return new JoinFilterAnalysis( diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterColumnCorrelationAnalysis.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterColumnCorrelationAnalysis.java index 6071f404e499..8c8ec795d187 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterColumnCorrelationAnalysis.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterColumnCorrelationAnalysis.java @@ -21,6 +21,7 @@ import org.apache.druid.java.util.common.Pair; import org.apache.druid.math.expr.Expr; +import org.apache.druid.query.filter.InDimFilter; import javax.annotation.Nonnull; import java.util.ArrayList; @@ -43,7 +44,7 @@ public class JoinFilterColumnCorrelationAnalysis private final String joinColumn; @Nonnull private final List baseColumns; @Nonnull private final List baseExpressions; - private final Map, Optional>> correlatedValuesMap; + private final Map, Optional> correlatedValuesMap; public JoinFilterColumnCorrelationAnalysis( String joinColumn, @@ -75,7 +76,7 @@ public List getBaseExpressions() return baseExpressions; } - public Map, Optional>> getCorrelatedValuesMap() + public Map, Optional> getCorrelatedValuesMap() { return correlatedValuesMap; } diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java index ed9fe0756251..39c5188984ef 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java @@ -22,6 +22,7 @@ import org.apache.druid.java.util.common.Pair; import org.apache.druid.math.expr.Expr; import org.apache.druid.query.filter.Filter; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.segment.join.Equality; import org.apache.druid.segment.join.JoinConditionAnalysis; import org.apache.druid.segment.join.JoinableClause; @@ -155,7 +156,7 @@ public static JoinFilterCorrelations computeJoinFilterCorrelations( correlationForPrefix.getValue().getCorrelatedValuesMap().computeIfAbsent( Pair.of(rhsRewriteCandidate.getRhsColumn(), rhsRewriteCandidate.getValueForRewrite()), (rhsVal) -> { - Optional> correlatedValues = getCorrelatedValuesForPushDown( + Optional correlatedValues = getCorrelatedValuesForPushDown( rhsRewriteCandidate.getRhsColumn(), rhsRewriteCandidate.getValueForRewrite(), correlationForPrefix.getValue().getJoinColumn(), @@ -244,7 +245,7 @@ private static List eliminateCorrelationDup * @return A list of values of the correlatedJoinColumn that appear in rows where filterColumn = filterValue * Returns absent if we cannot determine the correlated values. */ - private static Optional> getCorrelatedValuesForPushDown( + private static Optional getCorrelatedValuesForPushDown( String filterColumn, String filterValue, String correlatedJoinColumn, diff --git a/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java b/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java index 813d412735c4..d74817227b07 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java @@ -24,6 +24,7 @@ import com.google.common.collect.Sets; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.lookup.LookupExtractor; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.column.ColumnCapabilities; @@ -96,32 +97,38 @@ public JoinMatcher makeJoinMatcher( } @Override - public ColumnValuesWithUniqueFlag getNonNullColumnValues(String columnName, int maxNumValues) + public ColumnValuesWithUniqueFlag getMatchableColumnValues(String columnName, boolean includeNull, int maxNumValues) { if (LookupColumnSelectorFactory.KEY_COLUMN.equals(columnName) && extractor.canGetKeySet()) { final Set keys = extractor.keySet(); - final Set nullEquivalentValues = new HashSet<>(); - nullEquivalentValues.add(null); - if (NullHandling.replaceWithDefault()) { - nullEquivalentValues.add(NullHandling.defaultStringValue()); + final Set nonMatchingValues; + + if (includeNull) { + nonMatchingValues = Collections.emptySet(); + } else { + nonMatchingValues = new HashSet<>(); + nonMatchingValues.add(null); + if (NullHandling.replaceWithDefault()) { + nonMatchingValues.add(NullHandling.defaultStringValue()); + } } // size() of Sets.difference is slow; avoid it. - int nonNullKeys = keys.size(); + int matchingKeys = keys.size(); - for (String value : nullEquivalentValues) { + for (String value : nonMatchingValues) { if (keys.contains(value)) { - nonNullKeys--; + matchingKeys--; } } - if (nonNullKeys > maxNumValues) { + if (matchingKeys > maxNumValues) { return new ColumnValuesWithUniqueFlag(ImmutableSet.of(), false); - } else if (nonNullKeys == keys.size()) { + } else if (matchingKeys == keys.size()) { return new ColumnValuesWithUniqueFlag(keys, true); } else { - return new ColumnValuesWithUniqueFlag(Sets.difference(keys, nullEquivalentValues), true); + return new ColumnValuesWithUniqueFlag(Sets.difference(keys, nonMatchingValues), true); } } else { return new ColumnValuesWithUniqueFlag(ImmutableSet.of(), false); @@ -129,7 +136,7 @@ public ColumnValuesWithUniqueFlag getNonNullColumnValues(String columnName, int } @Override - public Optional> getCorrelatedColumnValues( + public Optional getCorrelatedColumnValues( String searchColumnName, String searchColumnValue, String retrievalColumnName, @@ -140,13 +147,13 @@ public Optional> getCorrelatedColumnValues( if (!ALL_COLUMNS.contains(searchColumnName) || !ALL_COLUMNS.contains(retrievalColumnName)) { return Optional.empty(); } - Set correlatedValues; + InDimFilter.ValuesSet correlatedValues; if (LookupColumnSelectorFactory.KEY_COLUMN.equals(searchColumnName)) { if (LookupColumnSelectorFactory.KEY_COLUMN.equals(retrievalColumnName)) { - correlatedValues = ImmutableSet.of(searchColumnValue); + correlatedValues = InDimFilter.ValuesSet.of(searchColumnValue); } else { // This should not happen in practice because the column to be joined on must be a key. - correlatedValues = Collections.singleton(extractor.apply(searchColumnValue)); + correlatedValues = InDimFilter.ValuesSet.of(extractor.apply(searchColumnValue)); } } else { if (!allowNonKeyColumnSearch) { @@ -154,11 +161,11 @@ public Optional> getCorrelatedColumnValues( } if (LookupColumnSelectorFactory.VALUE_COLUMN.equals(retrievalColumnName)) { // This should not happen in practice because the column to be joined on must be a key. - correlatedValues = ImmutableSet.of(searchColumnValue); + correlatedValues = InDimFilter.ValuesSet.of(searchColumnValue); } else { // Lookup extractor unapply only provides a list of strings, so we can't respect // maxCorrelationSetSize easily. This should be handled eventually. - correlatedValues = ImmutableSet.copyOf(extractor.unapply(searchColumnValue)); + correlatedValues = InDimFilter.ValuesSet.copyOf(extractor.unapply(searchColumnValue)); } } return Optional.of(correlatedValues); diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java index d88d81a87f40..221fa67fc0f7 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java @@ -95,6 +95,7 @@ default ColumnSelectorFactory makeColumnSelectorFactory(ReadableOffset offset, b * see {@link org.apache.druid.segment.join.JoinableFactory#computeJoinCacheKey} * * @return the byte array for cache key + * * @throws {@link IAE} if caching is not supported */ default byte[] computeCacheKey() @@ -125,8 +126,10 @@ interface Index /** * Returns whether keys are unique in this index. If this returns true, then {@link #find(Object)} will only ever * return a zero- or one-element list. + * + * @param includeNull whether null is considered a valid key */ - boolean areKeysUnique(); + boolean areKeysUnique(boolean includeNull); /** * Returns the list of row numbers corresponding to "key" in this index. @@ -134,14 +137,14 @@ interface Index * If "key" is some type other than the natural type {@link #keyType()}, it will be converted before checking * the index. */ - IntSortedSet find(Object key); + IntSortedSet find(@Nullable Object key); /** * Returns the row number corresponding to "key" in this index, or {@link #NOT_FOUND} if the key does not exist * in the index. * - * It is only valid to call this method if {@link #keyType()} is {@link ValueType#LONG} and {@link #areKeysUnique()} - * returns true. + * It is only valid to call this method if {@link #keyType()} is {@link ValueType#LONG} and + * {@link #areKeysUnique(boolean)} returns true. * * @throws UnsupportedOperationException if preconditions are not met */ diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcher.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcher.java index a433a0ae5522..f96e4260f8e3 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcher.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcher.java @@ -125,7 +125,7 @@ public class IndexedTableJoinMatcher implements JoinMatcher .map(pair -> makeConditionMatcher(pair.lhs, leftSelectorFactory, pair.rhs)) .collect(Collectors.toList()); - this.singleRowMatching = indexes.stream().allMatch(pair -> pair.lhs.areKeysUnique()); + this.singleRowMatching = indexes.stream().allMatch(pair -> pair.lhs.areKeysUnique(pair.rhs.isIncludeNull())); } else { throw new IAE( "Cannot build hash-join matcher on non-equi-join condition: %s", @@ -169,7 +169,7 @@ private static ConditionMatcher makeConditionMatcher( return ColumnProcessors.makeProcessor( condition.getLeftExpr(), index.keyType(), - new ConditionMatcherFactory(index), + new ConditionMatcherFactory(index, condition.isIncludeNull()), selectorFactory ); } @@ -374,21 +374,23 @@ static class ConditionMatcherFactory implements ColumnProcessorFactory (int) dimension id -> (IntSortedSet) row numbers @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") // updated via computeIfAbsent private final LruLoadingHashMap dimensionCaches; - ConditionMatcherFactory(IndexedTable.Index index) + ConditionMatcherFactory(IndexedTable.Index index, boolean includeNull) { this.keyType = index.keyType(); this.index = index; + this.includeNull = includeNull; this.dimensionCaches = new LruLoadingHashMap<>( MAX_NUM_CACHE, selector -> { int cardinality = selector.getValueCardinality(); - IntFunction loader = dimensionId -> getRowNumbers(selector, dimensionId); + IntFunction loader = dimensionId -> getRowNumbers(selector.lookupName(dimensionId)); return cardinality <= CACHE_MAX_SIZE ? new Int2IntSortedSetLookupTable(cardinality, loader) : new Int2IntSortedSetLruCache(CACHE_MAX_SIZE, loader); @@ -396,10 +398,13 @@ static class ConditionMatcherFactory implements ColumnProcessorFactory index.find(selector.getFloat()); + } else if (includeNull) { + return () -> selector.isNull() ? index.find(null) : index.find(selector.getFloat()); } else { return () -> selector.isNull() ? IntSortedSets.EMPTY_SET : index.find(selector.getFloat()); } @@ -475,6 +482,8 @@ public ConditionMatcher makeDoubleProcessor(BaseDoubleColumnValueSelector select { if (NullHandling.replaceWithDefault()) { return () -> index.find(selector.getDouble()); + } else if (includeNull) { + return () -> selector.isNull() ? index.find(null) : index.find(selector.getDouble()); } else { return () -> selector.isNull() ? IntSortedSets.EMPTY_SET : index.find(selector.getDouble()); } @@ -487,6 +496,8 @@ public ConditionMatcher makeLongProcessor(BaseLongColumnValueSelector selector) return makePrimitiveLongMatcher(selector); } else if (NullHandling.replaceWithDefault()) { return () -> index.find(selector.getLong()); + } else if (includeNull) { + return () -> selector.isNull() ? index.find(null) : index.find(selector.getLong()); } else { return () -> selector.isNull() ? IntSortedSets.EMPTY_SET : index.find(selector.getLong()); } @@ -543,6 +554,27 @@ public IntSortedSet match() return index.find(selector.getLong()); } }; + } else if (includeNull) { + return new ConditionMatcher() + { + @Override + public int matchSingleRow() + { + if (selector.isNull()) { + final IntSortedSet rowNumbers = index.find(null); + + return rowNumbers == null ? NO_CONDITION_MATCH : rowNumbers.firstInt(); + } else { + return index.findUniqueLong(selector.getLong()); + } + } + + @Override + public IntSortedSet match() + { + return selector.isNull() ? index.find(null) : index.find(selector.getLong()); + } + }; } else { return new ConditionMatcher() { diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java index cf7ced874360..4e9c5b5b3524 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java @@ -23,8 +23,8 @@ import it.unimi.dsi.fastutil.ints.IntBidirectionalIterator; import it.unimi.dsi.fastutil.ints.IntSortedSet; import org.apache.druid.common.config.NullHandling; -import org.apache.druid.java.util.common.guava.Comparators; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.column.ColumnCapabilities; @@ -38,8 +38,6 @@ import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.Set; -import java.util.TreeSet; public class IndexedTableJoinable implements Joinable { @@ -94,35 +92,34 @@ public JoinMatcher makeJoinMatcher( } @Override - public ColumnValuesWithUniqueFlag getNonNullColumnValues(String columnName, final int maxNumValues) + public ColumnValuesWithUniqueFlag getMatchableColumnValues(String columnName, boolean includeNull, int maxNumValues) { final int columnPosition = table.rowSignature().indexOf(columnName); + final InDimFilter.ValuesSet matchableValues = InDimFilter.ValuesSet.create(); if (columnPosition < 0) { - return new ColumnValuesWithUniqueFlag(ImmutableSet.of(), false); + return new ColumnValuesWithUniqueFlag(matchableValues /* empty set */, false); } try (final IndexedTable.Reader reader = table.columnReader(columnPosition)) { - // Use a SortedSet so InDimFilter doesn't need to create its own - final Set allValues = createValuesSet(); boolean allUnique = true; for (int i = 0; i < table.numRows(); i++) { final String s = DimensionHandlerUtils.convertObjectToString(reader.read(i)); - if (!NullHandling.isNullOrEquivalent(s)) { - if (!allValues.add(s)) { + if (includeNull || !NullHandling.isNullOrEquivalent(s)) { + if (!matchableValues.add(s)) { // Duplicate found allUnique = false; } - if (allValues.size() > maxNumValues) { + if (matchableValues.size() > maxNumValues) { return new ColumnValuesWithUniqueFlag(ImmutableSet.of(), false); } } } - return new ColumnValuesWithUniqueFlag(allValues, allUnique); + return new ColumnValuesWithUniqueFlag(matchableValues, allUnique); } catch (IOException e) { throw new RuntimeException(e); @@ -130,7 +127,7 @@ public ColumnValuesWithUniqueFlag getNonNullColumnValues(String columnName, fina } @Override - public Optional> getCorrelatedColumnValues( + public Optional getCorrelatedColumnValues( String searchColumnName, String searchColumnValue, String retrievalColumnName, @@ -145,7 +142,7 @@ public Optional> getCorrelatedColumnValues( return Optional.empty(); } try (final Closer closer = Closer.create()) { - Set correlatedValues = createValuesSet(); + InDimFilter.ValuesSet correlatedValues = InDimFilter.ValuesSet.create(); if (table.keyColumns().contains(searchColumnName)) { IndexedTable.Index index = table.columnIndex(filterColumnPosition); IndexedTable.Reader reader = table.columnReader(correlatedColumnPosition); @@ -195,12 +192,4 @@ public Optional acquireReferences() { return table.acquireReferences(); } - - /** - * Create a Set that InDimFilter will accept without incurring a copy. - */ - private static Set createValuesSet() - { - return new TreeSet<>(Comparators.naturalNullsFirst()); - } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/MapIndex.java b/processing/src/main/java/org/apache/druid/segment/join/table/MapIndex.java index 25464deffb4f..973d9951a0fc 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/MapIndex.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/MapIndex.java @@ -26,6 +26,7 @@ import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.column.ColumnType; +import javax.annotation.Nullable; import java.util.Map; /** @@ -33,25 +34,53 @@ */ public class MapIndex implements IndexedTable.Index { + /** + * Type of keys in {@link #index}. + */ private final ColumnType keyType; + + /** + * Index of all nonnull keys -> rows with those keys. + */ private final Map index; - private final boolean keysUnique; + + /** + * Rows containing a null key. + */ + @Nullable + private final IntSortedSet nullIndex; + + /** + * Whether nonnull keys are unique, i.e. everything in {@link #index} has exactly 1 element. + */ + private final boolean nonNullKeysUnique; + + /** + * Whether {@link #index} is a {@link Long2ObjectMap}. + */ private final boolean isLong2ObjectMap; /** * Creates a new instance based on a particular map. * - * @param keyType type of keys in "index" - * @param index a map of keys to matching row numbers - * @param keysUnique whether the keys are unique (if true: all IntLists in the index must be exactly 1 element) + * @param keyType type of keys in "index" + * @param index a map of keys to matching row numbers + * @param nonNullKeysUnique whether nonnull keys are unique (if true: all IntLists in the index must be exactly 1 + * element, except possibly the one corresponding to null) * * @see RowBasedIndexBuilder#build() the main caller */ - MapIndex(final ColumnType keyType, final Map index, final boolean keysUnique) + MapIndex( + final ColumnType keyType, + final Map index, + final IntSortedSet nullIndex, + final boolean nonNullKeysUnique + ) { this.keyType = Preconditions.checkNotNull(keyType, "keyType"); this.index = Preconditions.checkNotNull(index, "index"); - this.keysUnique = keysUnique; + this.nullIndex = nullIndex; + this.nonNullKeysUnique = nonNullKeysUnique; this.isLong2ObjectMap = index instanceof Long2ObjectMap; } @@ -62,23 +91,35 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(final boolean includeNull) { - return keysUnique; + if (includeNull) { + return nonNullKeysUnique && find(null).size() < 2; + } else { + return nonNullKeysUnique; + } } @Override - public IntSortedSet find(Object key) + public IntSortedSet find(@Nullable Object key) { - final Object convertedKey = DimensionHandlerUtils.convertObjectToType(key, keyType, false); + final IntSortedSet found; + + if (key == null) { + found = nullIndex; + } else { + final Object convertedKey = DimensionHandlerUtils.convertObjectToType(key, keyType, false); - if (convertedKey != null) { - final IntSortedSet found = index.get(convertedKey); - if (found != null) { - return found; + if (convertedKey != null) { + found = index.get(convertedKey); } else { - return IntSortedSets.EMPTY_SET; + // Don't look up null in the index, since this convertedKey is null because it's a failed cast, not a true null. + found = null; } + } + + if (found != null) { + return found; } else { return IntSortedSets.EMPTY_SET; } @@ -87,7 +128,7 @@ public IntSortedSet find(Object key) @Override public int findUniqueLong(long key) { - if (isLong2ObjectMap && keysUnique) { + if (isLong2ObjectMap && nonNullKeysUnique) { final IntSortedSet rows = ((Long2ObjectMap) (Map) index).get(key); assert rows == null || rows.size() == 1; return rows != null ? rows.firstInt() : NOT_FOUND; diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexBuilder.java b/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexBuilder.java index dc4b618dadae..5574f607e83d 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexBuilder.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexBuilder.java @@ -50,9 +50,10 @@ public class RowBasedIndexBuilder private static final long INT_ARRAY_SMALL_SIZE_OK = 250_000; private int currentRow = 0; - private int nullKeys = 0; + private int nonNullKeys = 0; private final ColumnType keyType; private final Map index; + private IntSortedSet nullIndex; private long minLongKey = Long.MAX_VALUE; private long maxLongKey = Long.MIN_VALUE; @@ -79,22 +80,30 @@ public RowBasedIndexBuilder(ColumnType keyType) */ public RowBasedIndexBuilder add(@Nullable final Object key) { - final Object castKey = DimensionHandlerUtils.convertObjectToType(key, keyType); + if (key == null) { + // Use "nullIndex" instead of "index" because "index" may be specialized as Long2ObjectMap, which cannot + // accept null keys. + if (nullIndex == null) { + nullIndex = new IntAVLTreeSet(); + } - if (castKey != null) { - final IntSortedSet rowNums = index.computeIfAbsent(castKey, k -> new IntAVLTreeSet()); - rowNums.add(currentRow); + nullIndex.add(currentRow); + } else { + final Object castKey = DimensionHandlerUtils.convertObjectToType(key, keyType); - // Track min, max long value so we can decide later on if it's appropriate to use an array-backed implementation. - if (keyType.is(ValueType.LONG) && (long) castKey < minLongKey) { - minLongKey = (long) castKey; - } + if (castKey != null) { + index.computeIfAbsent(castKey, k -> new IntAVLTreeSet()).add(currentRow); + nonNullKeys++; - if (keyType.is(ValueType.LONG) && (long) castKey > maxLongKey) { - maxLongKey = (long) castKey; + // Track min, max long value so we can decide later on if it's appropriate to use an array-backed implementation. + if (keyType.is(ValueType.LONG) && (long) castKey < minLongKey) { + minLongKey = (long) castKey; + } + + if (keyType.is(ValueType.LONG) && (long) castKey > maxLongKey) { + maxLongKey = (long) castKey; + } } - } else { - nullKeys++; } currentRow++; @@ -107,9 +116,9 @@ public RowBasedIndexBuilder add(@Nullable final Object key) */ public IndexedTable.Index build() { - final boolean keysUnique = index.size() == currentRow - nullKeys; + final boolean nonNullKeysUnique = index.size() == nonNullKeys; - if (keyType.is(ValueType.LONG) && keysUnique && index.size() > 0) { + if (keyType.is(ValueType.LONG) && nonNullKeysUnique && !index.isEmpty() && nullIndex == null) { // May be a good candidate for UniqueLongArrayIndex. Check the range of values as compared to min and max. long range; @@ -155,6 +164,6 @@ public IndexedTable.Index build() } } - return new MapIndex(keyType, index, keysUnique); + return new MapIndex(keyType, index, nullIndex, nonNullKeysUnique); } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/UniqueLongArrayIndex.java b/processing/src/main/java/org/apache/druid/segment/join/table/UniqueLongArrayIndex.java index 5c5fd959de33..034ff03f9001 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/UniqueLongArrayIndex.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/UniqueLongArrayIndex.java @@ -24,13 +24,19 @@ import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.column.ColumnType; +import javax.annotation.Nullable; + /** * An {@link IndexedTable.Index} backed by an int array. * - * This is for long-typed keys whose values all fall in a "reasonable" range. + * This is for nonnull long-typed keys whose values all fall in a "reasonable" range. Built by + * {@link RowBasedIndexBuilder#build()} when these conditions are met. */ public class UniqueLongArrayIndex implements IndexedTable.Index { + /** + * Array index is the key, value is the row number. + */ private final int[] index; private final long minKey; @@ -55,14 +61,19 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(final boolean includeNull) { return true; } @Override - public IntSortedSet find(Object key) + public IntSortedSet find(@Nullable Object key) { + if (key == null) { + // This index class never contains null keys. + return IntSortedSets.EMPTY_SET; + } + final Long longKey = DimensionHandlerUtils.convertObjectToLong(key); if (longKey != null) { diff --git a/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java b/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java index c49959aa1ae3..2f68840955fa 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java @@ -99,6 +99,11 @@ public void testDoubleEval() Assert.assertTrue(evalDouble("2.0 == 2.0", bindings) > 0.0); Assert.assertTrue(evalDouble("2.0 != 1.0", bindings) > 0.0); + Assert.assertEquals(1L, evalLong("notdistinctfrom(2.0, 2.0)", bindings)); + Assert.assertEquals(1L, evalLong("isdistinctfrom(2.0, 1.0)", bindings)); + Assert.assertEquals(0L, evalLong("notdistinctfrom(2.0, 1.0)", bindings)); + Assert.assertEquals(0L, evalLong("isdistinctfrom(2.0, 2.0)", bindings)); + Assert.assertEquals(0L, evalLong("istrue(0.0)", bindings)); Assert.assertEquals(1L, evalLong("isfalse(0.0)", bindings)); Assert.assertEquals(1L, evalLong("nottrue(0.0)", bindings)); @@ -131,6 +136,11 @@ public void testDoubleEval() Assert.assertEquals(1L, evalLong("2.0 == 2.0", bindings)); Assert.assertEquals(1L, evalLong("2.0 != 1.0", bindings)); + Assert.assertEquals(1L, evalLong("notdistinctfrom(2.0, 2.0)", bindings)); + Assert.assertEquals(1L, evalLong("isdistinctfrom(2.0, 1.0)", bindings)); + Assert.assertEquals(0L, evalLong("notdistinctfrom(2.0, 1.0)", bindings)); + Assert.assertEquals(0L, evalLong("isdistinctfrom(2.0, 2.0)", bindings)); + Assert.assertEquals(0L, evalLong("istrue(0.0)", bindings)); Assert.assertEquals(1L, evalLong("isfalse(0.0)", bindings)); Assert.assertEquals(1L, evalLong("nottrue(0.0)", bindings)); @@ -186,6 +196,8 @@ public void testLongEval() Assert.assertTrue(evalLong("9223372036854775807 <= 9223372036854775807", bindings) > 0); Assert.assertTrue(evalLong("9223372036854775807 == 9223372036854775807", bindings) > 0); Assert.assertTrue(evalLong("9223372036854775807 != 9223372036854775806", bindings) > 0); + Assert.assertTrue(evalLong("notdistinctfrom(9223372036854775807, 9223372036854775807)", bindings) > 0); + Assert.assertTrue(evalLong("isdistinctfrom(9223372036854775807, 9223372036854775806)", bindings) > 0); assertEquals(9223372036854775807L, evalLong("9223372036854775806 + 1", bindings)); assertEquals(9223372036854775806L, evalLong("9223372036854775807 - 1", bindings)); @@ -221,6 +233,92 @@ public void testLongEval() assertEquals("x", eval("nvl(if(x == 9223372036854775806, '', 'x'), 'NULL')", bindings).asString()); } + @Test + public void testIsNotDistinctFrom() + { + assertEquals( + 1L, + new Function.IsNotDistinctFromFunc() + .apply( + ImmutableList.of( + new NullLongExpr(), + new NullLongExpr() + ), + InputBindings.nilBindings() + ) + .value() + ); + + assertEquals( + 0L, + new Function.IsNotDistinctFromFunc() + .apply( + ImmutableList.of( + new LongExpr(0L), + new NullLongExpr() + ), + InputBindings.nilBindings() + ) + .value() + ); + + assertEquals( + 1L, + new Function.IsNotDistinctFromFunc() + .apply( + ImmutableList.of( + new LongExpr(0L), + new LongExpr(0L) + ), + InputBindings.nilBindings() + ) + .value() + ); + } + + @Test + public void testIsDistinctFrom() + { + assertEquals( + 0L, + new Function.IsDistinctFromFunc() + .apply( + ImmutableList.of( + new NullLongExpr(), + new NullLongExpr() + ), + InputBindings.nilBindings() + ) + .value() + ); + + assertEquals( + 1L, + new Function.IsDistinctFromFunc() + .apply( + ImmutableList.of( + new LongExpr(0L), + new NullLongExpr() + ), + InputBindings.nilBindings() + ) + .value() + ); + + assertEquals( + 0L, + new Function.IsDistinctFromFunc() + .apply( + ImmutableList.of( + new LongExpr(0L), + new LongExpr(0L) + ), + InputBindings.nilBindings() + ) + .value() + ); + } + @Test public void testIsFalse() { @@ -1151,6 +1249,8 @@ public void testArrayComparison() Assert.assertEquals(1L, eval("['a','b',null,'c'] >= stringArray", bindings).value()); Assert.assertEquals(1L, eval("['a','b',null,'c'] == stringArray", bindings).value()); Assert.assertEquals(0L, eval("['a','b',null,'c'] != stringArray", bindings).value()); + Assert.assertEquals(1L, eval("notdistinctfrom(['a','b',null,'c'], stringArray)", bindings).value()); + Assert.assertEquals(0L, eval("isdistinctfrom(['a','b',null,'c'], stringArray)", bindings).value()); Assert.assertEquals(1L, eval("['a','b',null,'c'] <= stringArray", bindings).value()); Assert.assertEquals(0L, eval("['a','b',null,'c'] < stringArray", bindings).value()); @@ -1158,6 +1258,8 @@ public void testArrayComparison() Assert.assertEquals(1L, eval("[1,null,2,3] >= longArray", bindings).value()); Assert.assertEquals(1L, eval("[1,null,2,3] == longArray", bindings).value()); Assert.assertEquals(0L, eval("[1,null,2,3] != longArray", bindings).value()); + Assert.assertEquals(1L, eval("notdistinctfrom([1,null,2,3], longArray)", bindings).value()); + Assert.assertEquals(0L, eval("isdistinctfrom([1,null,2,3], longArray)", bindings).value()); Assert.assertEquals(1L, eval("[1,null,2,3] <= longArray", bindings).value()); Assert.assertEquals(0L, eval("[1,null,2,3] < longArray", bindings).value()); @@ -1165,6 +1267,8 @@ public void testArrayComparison() Assert.assertEquals(1L, eval("[1.1,2.2,3.3,null] >= doubleArray", bindings).value()); Assert.assertEquals(1L, eval("[1.1,2.2,3.3,null] == doubleArray", bindings).value()); Assert.assertEquals(0L, eval("[1.1,2.2,3.3,null] != doubleArray", bindings).value()); + Assert.assertEquals(1L, eval("notdistinctfrom([1.1,2.2,3.3,null], doubleArray)", bindings).value()); + Assert.assertEquals(0L, eval("isdistinctfrom([1.1,2.2,3.3,null], doubleArray)", bindings).value()); Assert.assertEquals(1L, eval("[1.1,2.2,3.3,null] <= doubleArray", bindings).value()); Assert.assertEquals(0L, eval("[1.1,2.2,3.3,null] < doubleArray", bindings).value()); } diff --git a/processing/src/test/java/org/apache/druid/math/expr/ExprsTest.java b/processing/src/test/java/org/apache/druid/math/expr/ExprsTest.java index a5b6844d5a23..aeb23d257d50 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/ExprsTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/ExprsTest.java @@ -20,8 +20,9 @@ package org.apache.druid.math.expr; import com.google.common.collect.ImmutableList; -import org.apache.druid.java.util.common.Pair; +import org.apache.druid.segment.join.Equality; import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; import org.junit.Assert; import org.junit.Test; @@ -73,27 +74,52 @@ public void test_decomposeAnd_basic() @Test public void test_decomposeEquals_notAnEquals() { - final Optional> optionalPair = Exprs.decomposeEquals(new IdentifierExpr("foo")); - Assert.assertFalse(optionalPair.isPresent()); + final Optional result = Exprs.decomposeEquals(new IdentifierExpr("foo"), "j."); + Assert.assertFalse(result.isPresent()); } @Test public void test_decomposeEquals_basic() { - final Optional> optionalPair = Exprs.decomposeEquals( + final Optional result = Exprs.decomposeEquals( new BinEqExpr( "==", new IdentifierExpr("foo"), - new IdentifierExpr("bar") - ) + new IdentifierExpr("j.bar") + ), + "j." + ); + + Assert.assertTrue(result.isPresent()); + + final Equality equality = result.get(); + MatcherAssert.assertThat(equality.getLeftExpr(), CoreMatchers.instanceOf(IdentifierExpr.class)); + Assert.assertEquals("foo", ((IdentifierExpr) equality.getLeftExpr()).getIdentifier()); + Assert.assertEquals("bar", equality.getRightColumn()); + Assert.assertFalse(equality.isIncludeNull()); + } + + @Test + public void test_decomposeEquals_notDistinctFrom() + { + final Optional result = Exprs.decomposeEquals( + new FunctionExpr( + new Function.IsNotDistinctFromFunc(), + "notdistinctfrom", + ImmutableList.of( + new IdentifierExpr("foo"), + new IdentifierExpr("j.bar") + ) + ), + "j." ); - Assert.assertTrue(optionalPair.isPresent()); + Assert.assertTrue(result.isPresent()); - final Pair pair = optionalPair.get(); - Assert.assertThat(pair.lhs, CoreMatchers.instanceOf(IdentifierExpr.class)); - Assert.assertThat(pair.rhs, CoreMatchers.instanceOf(IdentifierExpr.class)); - Assert.assertEquals("foo", ((IdentifierExpr) pair.lhs).getIdentifier()); - Assert.assertEquals("bar", ((IdentifierExpr) pair.rhs).getIdentifier()); + final Equality equality = result.get(); + MatcherAssert.assertThat(equality.getLeftExpr(), CoreMatchers.instanceOf(IdentifierExpr.class)); + Assert.assertEquals("foo", ((IdentifierExpr) equality.getLeftExpr()).getIdentifier()); + Assert.assertEquals("bar", equality.getRightColumn()); + Assert.assertTrue(equality.isIncludeNull()); } } diff --git a/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java b/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java index 41f2480621da..9508de7bcac6 100644 --- a/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java +++ b/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java @@ -92,7 +92,7 @@ public void testGetValuesWithValuesSetOfNonEmptyStringsUseTheGivenSet() @Test public void testGetValuesWithValuesSetIncludingEmptyString() { - final InDimFilter.ValuesSet values = new InDimFilter.ValuesSet(ImmutableSet.of("v1", "", "v3")); + final InDimFilter.ValuesSet values = InDimFilter.ValuesSet.copyOf(ImmutableSet.of("v1", "", "v3")); final InDimFilter filter = new InDimFilter("dim", values); if (NullHandling.replaceWithDefault()) { Assert.assertSame(values, filter.getValues()); diff --git a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java index d1de10fdaa54..e5ab7d1be8aa 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java @@ -183,6 +183,25 @@ protected JoinableClause factToRegion(final JoinType joinType) ); } + protected JoinableClause factToRegionIncludeNull(final JoinType joinType) + { + return new JoinableClause( + FACT_TO_REGION_PREFIX, + new IndexedTableJoinable(regionsTable), + joinType, + JoinConditionAnalysis.forExpression( + StringUtils.format( + "notdistinctfrom(\"%sregionIsoCode\", regionIsoCode) && " + + "notdistinctfrom(\"%scountryIsoCode\", countryIsoCode)", + FACT_TO_REGION_PREFIX, + FACT_TO_REGION_PREFIX + ), + FACT_TO_REGION_PREFIX, + ExprMacroTable.nil() + ) + ); + } + protected JoinableClause regionToCountry(final JoinType joinType) { return new JoinableClause( diff --git a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java index 5f7a10b9705e..20d032aba381 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java @@ -1269,6 +1269,69 @@ public void test_makeCursors_factToRegionToCountryLeft() ); } + @Test + public void test_makeCursors_factToRegionToCountryInnerIncludeNull() + { + List joinableClauses = ImmutableList.of( + factToRegionIncludeNull(JoinType.INNER), + regionToCountry(JoinType.LEFT) + ); + JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( + null, + joinableClauses, + VirtualColumns.EMPTY + ); + JoinTestHelper.verifyCursors( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + joinableClauses, + joinFilterPreAnalysis + ).makeCursors( + null, + Intervals.ETERNITY, + VirtualColumns.EMPTY, + Granularities.ALL, + false, + null + ), + ImmutableList.of( + "page", + FACT_TO_REGION_PREFIX + "regionName", + REGION_TO_COUNTRY_PREFIX + "countryName" + ), + ImmutableList.of( + new Object[]{"Talk:Oswald Tilghman", "Nulland", null}, + new Object[]{"Rallicula", "Nulland", null}, + new Object[]{"Peremptory norm", "New South Wales", "Australia"}, + new Object[]{"Apamea abruzzorum", "Nulland", null}, + new Object[]{"Atractus flammigerus", "Nulland", null}, + new Object[]{"Agama mossambica", "Nulland", null}, + new Object[]{"Mathis Bolly", "Mexico City", "Mexico"}, + new Object[]{"유희왕 GX", "Seoul", "Republic of Korea"}, + new Object[]{"青野武", "Tōkyō", "Japan"}, + new Object[]{"Golpe de Estado en Chile de 1973", "Santiago Metropolitan", "Chile"}, + new Object[]{"President of India", "California", "United States"}, + new Object[]{"Diskussion:Sebastian Schulz", "Hesse", "Germany"}, + new Object[]{"Saison 9 de Secret Story", "Val d'Oise", "France"}, + new Object[]{"Glasgow", "Kingston upon Hull", "United Kingdom"}, + new Object[]{"Didier Leclair", "Ontario", "Canada"}, + new Object[]{"Les Argonautes", "Quebec", "Canada"}, + new Object[]{"Otjiwarongo Airport", "California", "United States"}, + new Object[]{"Sarah Michelle Gellar", "Ontario", "Canada"}, + new Object[]{"DirecTV", "North Carolina", "United States"}, + new Object[]{"Carlo Curti", "California", "United States"}, + new Object[]{"Giusy Ferreri discography", "Provincia di Varese", "Italy"}, + new Object[]{"Roma-Bangkok", "Provincia di Varese", "Italy"}, + new Object[]{"Wendigo", "Departamento de San Salvador", "El Salvador"}, + new Object[]{"Алиса в Зазеркалье", "Finnmark Fylke", "Norway"}, + new Object[]{"Gabinete Ministerial de Rafael Correa", "Provincia del Guayas", "Ecuador"}, + new Object[]{"Old Anatolian Turkish", "Virginia", "United States"}, + new Object[]{"Cream Soda", "Ainigriv", "States United"}, + new Object[]{"History of Fourems", "Fourems Province", "Fourems"} + ) + ); + } + @Test public void test_makeCursors_factToCountryAlwaysTrue() { @@ -1850,7 +1913,7 @@ public void test_makeCursors_errorOnNonKeyBasedJoin() { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage("Cannot build hash-join matcher on non-key-based condition: " - + "Equality{leftExpr=x, rightColumn='countryName'}"); + + "Equality{leftExpr=x, rightColumn='countryName', includeNull=false}"); List joinableClauses = ImmutableList.of( new JoinableClause( FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX, diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java index b6fd9f4f0e0f..1b7f250f8479 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java @@ -2043,7 +2043,8 @@ public void test_filterPushDown_factToRegionThreeRHSColumnsAllDirectAndFilterOnR // filter rewrites. expectedException.expect(IAE.class); expectedException.expectMessage( - "Cannot build hash-join matcher on non-key-based condition: Equality{leftExpr=user, rightColumn='regionName'}" + "Cannot build hash-join matcher on non-key-based condition: " + + "Equality{leftExpr=user, rightColumn='regionName', includeNull=false}" ); JoinTestHelper.verifyCursors( diff --git a/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java b/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java index ce1dc7fc8b49..0fa492211f07 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.lookup.LookupExtractor; import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ValueType; @@ -36,6 +37,7 @@ import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -129,7 +131,7 @@ public void getColumnCapabilitiesForUnknownColumnShouldReturnNull() @Test public void getCorrelatedColummnValuesMissingSearchColumnShouldReturnEmptySet() { - Optional> correlatedValues = + Optional correlatedValues = target.getCorrelatedColumnValues( UNKNOWN_COLUMN, SEARCH_KEY_VALUE, @@ -144,7 +146,7 @@ public void getCorrelatedColummnValuesMissingSearchColumnShouldReturnEmptySet() @Test public void getCorrelatedColummnValuesMissingRetrievalColumnShouldReturnEmptySet() { - Optional> correlatedValues = + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, SEARCH_KEY_VALUE, @@ -159,7 +161,7 @@ public void getCorrelatedColummnValuesMissingRetrievalColumnShouldReturnEmptySet @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnShouldReturnSearchValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, SEARCH_KEY_VALUE, LookupColumnSelectorFactory.KEY_COLUMN, @@ -172,7 +174,7 @@ public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnShouldRetur @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveValueColumnShouldReturnExtractedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, SEARCH_KEY_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, @@ -185,7 +187,7 @@ public void getCorrelatedColumnValuesForSearchKeyAndRetrieveValueColumnShouldRet @Test public void getCorrelatedColumnValuesForSearchKeyMissingAndRetrieveValueColumnShouldReturnExtractedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, SEARCH_KEY_NULL_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, @@ -198,7 +200,7 @@ public void getCorrelatedColumnValuesForSearchKeyMissingAndRetrieveValueColumnSh @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnAndNonKeyColumnSearchDisabledShouldReturnSearchValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, @@ -219,7 +221,7 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnAndNonK @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnShouldReturnSearchValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, @@ -232,7 +234,7 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnShouldR @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnShouldReturnUnAppliedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.KEY_COLUMN, @@ -250,7 +252,7 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnShouldRet */ public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnWithMaxLimitSetShouldHonorMaxLimit() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.KEY_COLUMN, @@ -263,7 +265,7 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnWithMaxLi @Test public void getCorrelatedColumnValuesForSearchUnknownValueAndRetrieveKeyColumnShouldReturnNoCorrelatedValues() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_UNKNOWN, LookupColumnSelectorFactory.KEY_COLUMN, @@ -274,10 +276,11 @@ public void getCorrelatedColumnValuesForSearchUnknownValueAndRetrieveKeyColumnSh } @Test - public void getNonNullColumnValuesIfAllUniqueForValueColumnShouldReturnEmpty() + public void getMatchableColumnValuesIfAllUniqueForValueColumnShouldReturnEmpty() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues( + final Joinable.ColumnValuesWithUniqueFlag values = target.getMatchableColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, + false, Integer.MAX_VALUE ); @@ -285,24 +288,41 @@ public void getNonNullColumnValuesIfAllUniqueForValueColumnShouldReturnEmpty() } @Test - public void getNonNullColumnValuesIfAllUniqueForKeyColumnShouldReturnValues() + public void getMatchableColumnValuesIfAllUniqueForKeyColumnShouldReturnValues() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues( + final Joinable.ColumnValuesWithUniqueFlag values = target.getMatchableColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, + false, Integer.MAX_VALUE ); Assert.assertEquals( - NullHandling.replaceWithDefault() ? ImmutableSet.of("foo", "bar") : ImmutableSet.of("foo", "bar", ""), + NullHandling.sqlCompatible() ? ImmutableSet.of("foo", "bar", "") : ImmutableSet.of("foo", "bar"), values.getColumnValues() ); } @Test - public void getNonNullColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty() + public void getMatchableColumnValuesWithIncludeNullIfAllUniqueForKeyColumnShouldReturnValues() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues( + final Joinable.ColumnValuesWithUniqueFlag values = target.getMatchableColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, + true, + Integer.MAX_VALUE + ); + + Assert.assertEquals( + InDimFilter.ValuesSet.copyOf(Arrays.asList("foo", "bar", "", null)), + values.getColumnValues() + ); + } + + @Test + public void getMatchableColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty() + { + final Joinable.ColumnValuesWithUniqueFlag values = target.getMatchableColumnValues( + LookupColumnSelectorFactory.KEY_COLUMN, + false, 1 ); diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/BroadcastSegmentIndexedTableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/BroadcastSegmentIndexedTableTest.java index 22ff1f3c5c27..2d929f079bfd 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/BroadcastSegmentIndexedTableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/BroadcastSegmentIndexedTableTest.java @@ -266,14 +266,10 @@ private void checkIndexAndReader(String columnName, Object[] vals, Object[] nonm // lets try a few values out for (Object val : vals) { final IntSortedSet valIndex = valueIndex.find(val); - if (val == null) { - Assert.assertEquals(0, valIndex.size()); - } else { - Assert.assertTrue(valIndex.size() > 0); - final IntBidirectionalIterator rowIterator = valIndex.iterator(); - while (rowIterator.hasNext()) { - Assert.assertEquals(val, reader.read(rowIterator.nextInt())); - } + Assert.assertTrue(valIndex.size() > 0); + final IntBidirectionalIterator rowIterator = valIndex.iterator(); + while (rowIterator.hasNext()) { + Assert.assertEquals(val, reader.read(rowIterator.nextInt())); } } for (Object val : nonmatchingVals) { diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/FrameBasedIndexedTableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/FrameBasedIndexedTableTest.java index a0fbd4fcc1d5..ed59f5f80652 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/FrameBasedIndexedTableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/FrameBasedIndexedTableTest.java @@ -325,14 +325,10 @@ private void checkIndexAndReader(String columnName, Object[] vals, Object[] nonm for (Object val : vals) { final IntSortedSet valIndex = valueIndex.find(val); - if (val == null) { - Assert.assertEquals(0, valIndex.size()); - } else { - Assert.assertTrue(valIndex.size() > 0); - final IntBidirectionalIterator rowIterator = valIndex.iterator(); - while (rowIterator.hasNext()) { - Assert.assertEquals(val, reader.read(rowIterator.nextInt())); - } + Assert.assertTrue(valIndex.size() > 0); + final IntBidirectionalIterator rowIterator = valIndex.iterator(); + while (rowIterator.hasNext()) { + Assert.assertEquals(val, reader.read(rowIterator.nextInt())); } } for (Object val : nonmatchingVals) { diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcherTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcherTest.java index 57b3896648bf..b1a47355f541 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcherTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcherTest.java @@ -88,7 +88,7 @@ public void tearDown() throws Exception public void testMatchToUniqueLongIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertEquals(ImmutableList.of(2), ImmutableList.copyOf(processor.match())); @@ -98,7 +98,7 @@ public void testMatchToUniqueLongIndex() public void testMatchSingleRowToUniqueLongIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertEquals(2, processor.matchSingleRow()); @@ -108,7 +108,7 @@ public void testMatchSingleRowToUniqueLongIndex() public void testMatchToNonUniqueLongIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longAlwaysOneTwoThreeIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longAlwaysOneTwoThreeIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertEquals(ImmutableList.of(1, 2, 3), ImmutableList.copyOf(processor.match())); @@ -118,7 +118,7 @@ public void testMatchToNonUniqueLongIndex() public void testMatchSingleRowToNonUniqueLongIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longAlwaysOneTwoThreeIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longAlwaysOneTwoThreeIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertThrows(UnsupportedOperationException.class, processor::matchSingleRow); @@ -128,7 +128,7 @@ public void testMatchSingleRowToNonUniqueLongIndex() public void testMatchToUniqueStringIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(certainStringToThreeIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(certainStringToThreeIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertEquals(ImmutableList.of(3), ImmutableList.copyOf(processor.match())); @@ -138,7 +138,7 @@ public void testMatchToUniqueStringIndex() public void testMatchSingleRowToUniqueStringIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(certainStringToThreeIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(certainStringToThreeIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertEquals(3, processor.matchSingleRow()); @@ -170,7 +170,7 @@ public void tearDown() throws Exception public void testMatch() { final IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeComplexProcessor(selector); @@ -182,7 +182,7 @@ public void testMatch() public void testMatchSingleRow() { final IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeComplexProcessor(selector); @@ -212,7 +212,7 @@ public void testMatchMultiValuedRowCardinalityUnknownShouldThrowException() thro .getValueCardinality(); IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex(), false); IndexedTableJoinMatcher.ConditionMatcher dimensionProcessor = conditionMatcherFactory.makeDimensionProcessor( dimensionSelector, false @@ -233,7 +233,7 @@ public void testMatchMultiValuedRowCardinalityKnownShouldThrowException() throws Mockito.doReturn(3).when(dimensionSelector).getValueCardinality(); IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex(), false); IndexedTableJoinMatcher.ConditionMatcher dimensionProcessor = conditionMatcherFactory.makeDimensionProcessor( dimensionSelector, false @@ -256,7 +256,7 @@ public void testMatchEmptyRowCardinalityUnknown() throws Exception .getValueCardinality(); IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex(), false); IndexedTableJoinMatcher.ConditionMatcher dimensionProcessor = conditionMatcherFactory.makeDimensionProcessor( dimensionSelector, false @@ -278,7 +278,7 @@ public void testMatchEmptyRowCardinalityKnown() throws Exception Mockito.doReturn(0).when(dimensionSelector).getValueCardinality(); IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex(), false); IndexedTableJoinMatcher.ConditionMatcher dimensionProcessor = conditionMatcherFactory.makeDimensionProcessor( dimensionSelector, false @@ -324,7 +324,7 @@ public void getsCorrectResultWhenSelectorCardinalityHigh() private static IndexedTableJoinMatcher.ConditionMatcher makeConditionMatcher(int valueCardinality) { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex(), false); return conditionMatcherFactory.makeDimensionProcessor( new TestDimensionSelector(KEY, valueCardinality), false @@ -503,7 +503,7 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(boolean includeNull) { return false; } @@ -533,7 +533,7 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(boolean includeNull) { return true; } @@ -567,7 +567,7 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(boolean includeNull) { return true; } @@ -603,7 +603,7 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(boolean includeNull) { return false; } diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java index 93f2d5df13eb..09ddec72b6bc 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java @@ -28,6 +28,7 @@ import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.ConstantDimensionSelector; @@ -43,9 +44,9 @@ import org.junit.Before; import org.junit.Test; +import java.util.Arrays; import java.util.Collections; import java.util.Optional; -import java.util.Set; public class IndexedTableJoinableTest { @@ -89,7 +90,8 @@ public ColumnCapabilities getColumnCapabilities(String columnName) ImmutableList.of( new Object[]{"foo", 1L, 1L}, new Object[]{"bar", 2L, 1L}, - new Object[]{"baz", null, 1L} + new Object[]{"baz", null, 1L}, + new Object[]{null, 3L, 1L} ), RowSignature.builder() .add(KEY_COLUMN, ColumnType.STRING) @@ -187,7 +189,7 @@ public void makeJoinMatcherWithDimensionSelectorOnString() .makeDimensionSelector(DefaultDimensionSpec.of("str")); // getValueCardinality - Assert.assertEquals(4, selector.getValueCardinality()); + Assert.assertEquals(5, selector.getValueCardinality()); // nameLookupPossibleInAdvance Assert.assertTrue(selector.nameLookupPossibleInAdvance()); @@ -197,6 +199,7 @@ public void makeJoinMatcherWithDimensionSelectorOnString() Assert.assertEquals("bar", selector.lookupName(1)); Assert.assertEquals("baz", selector.lookupName(2)); Assert.assertNull(selector.lookupName(3)); + Assert.assertNull(selector.lookupName(4)); // lookupId Assert.assertNull(selector.idLookup()); @@ -205,13 +208,14 @@ public void makeJoinMatcherWithDimensionSelectorOnString() @Test public void getCorrelatedColummnValuesMissingSearchColumnShouldReturnEmpty() { - Optional> correlatedValues = + Optional correlatedValues = target.getCorrelatedColumnValues( UNKNOWN_COLUMN, "foo", VALUE_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @@ -219,13 +223,14 @@ public void getCorrelatedColummnValuesMissingSearchColumnShouldReturnEmpty() @Test public void getCorrelatedColummnValuesMissingRetrievalColumnShouldReturnEmpty() { - Optional> correlatedValues = + Optional correlatedValues = target.getCorrelatedColumnValues( KEY_COLUMN, "foo", UNKNOWN_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @@ -233,149 +238,179 @@ public void getCorrelatedColummnValuesMissingRetrievalColumnShouldReturnEmpty() @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnShouldReturnSearchValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( KEY_COLUMN, SEARCH_KEY_VALUE, KEY_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_KEY_VALUE)), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnAboveLimitShouldReturnEmpty() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( KEY_COLUMN, SEARCH_KEY_VALUE, KEY_COLUMN, 0, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveValueColumnShouldReturnExtractedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( KEY_COLUMN, SEARCH_KEY_VALUE, VALUE_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_VALUE_VALUE)), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchKeyMissingAndRetrieveValueColumnShouldReturnExtractedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( KEY_COLUMN, SEARCH_KEY_NULL_VALUE, VALUE_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.of(Collections.singleton(null)), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnAndNonKeyColumnSearchDisabledShouldReturnEmpty() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_VALUE, VALUE_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_VALUE, KEY_COLUMN, 10, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnShouldReturnSearchValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_VALUE, VALUE_COLUMN, MAX_CORRELATION_SET_SIZE, - true); + true + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_VALUE_VALUE)), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnShouldReturnUnAppliedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_VALUE, KEY_COLUMN, 10, - true); + true + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_KEY_VALUE)), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnWithMaxLimitSetShouldHonorMaxLimit() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_VALUE, KEY_COLUMN, 0, - true); + true + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchUnknownValueAndRetrieveKeyColumnShouldReturnNoCorrelatedValues() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_UNKNOWN, KEY_COLUMN, 10, - true); + true + ); Assert.assertEquals(Optional.of(ImmutableSet.of()), correlatedValues); } @Test - public void getNonNullColumnValuesIfAllUniqueForValueColumnShouldReturnValues() + public void getMatchableColumnValuesIfAllUniqueForValueColumnShouldReturnValues() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues(VALUE_COLUMN, Integer.MAX_VALUE); + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues(VALUE_COLUMN, false, Integer.MAX_VALUE); - Assert.assertEquals(ImmutableSet.of("1", "2"), values.getColumnValues()); + Assert.assertEquals(ImmutableSet.of("1", "2", "3"), values.getColumnValues()); } @Test - public void getNonNullColumnValuesIfAllUniqueForNonexistentColumnShouldReturnEmpty() + public void getMatchableColumnValuesIfAllUniqueForNonexistentColumnShouldReturnEmpty() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues("nonexistent", Integer.MAX_VALUE); + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues("nonexistent", false, Integer.MAX_VALUE); Assert.assertEquals(ImmutableSet.of(), values.getColumnValues()); } @Test - public void getNonNullColumnValuesIfAllUniqueForKeyColumnShouldReturnValues() + public void getMatchableColumnValuesIfAllUniqueForKeyColumnShouldReturnValues() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues(KEY_COLUMN, Integer.MAX_VALUE); + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues(KEY_COLUMN, false, Integer.MAX_VALUE); Assert.assertEquals( ImmutableSet.of("foo", "bar", "baz"), values.getColumnValues() ); + + Assert.assertTrue(values.isAllUnique()); + } + + @Test + public void getMatchableColumnValuesWithIncludeNullIfAllUniqueForKeyColumnShouldReturnValues() + { + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues(KEY_COLUMN, true, Integer.MAX_VALUE); + + Assert.assertEquals( + InDimFilter.ValuesSet.copyOf(Arrays.asList(null, "foo", "bar", "baz")), + values.getColumnValues() + ); + + Assert.assertTrue(values.isAllUnique()); } @Test - public void getNonNullColumnValuesIfAllUniqueForAllSameColumnShouldReturnEmpty() + public void getMatchableColumnValuesIfAllUniqueForAllSameColumnShouldReturnEmpty() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues(ALL_SAME_COLUMN, Integer.MAX_VALUE); + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues(ALL_SAME_COLUMN, false, Integer.MAX_VALUE); Assert.assertEquals( ImmutableSet.of("1"), @@ -385,9 +420,10 @@ public void getNonNullColumnValuesIfAllUniqueForAllSameColumnShouldReturnEmpty() } @Test - public void getNonNullColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty() + public void getMatchableColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues(KEY_COLUMN, 1); + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues(KEY_COLUMN, false, 1); Assert.assertEquals(ImmutableSet.of(), values.getColumnValues()); } diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexBuilderTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexBuilderTest.java index d399b971dafa..d6cd74f55451 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexBuilderTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexBuilderTest.java @@ -23,6 +23,7 @@ import it.unimi.dsi.fastutil.ints.IntSortedSet; import org.apache.druid.segment.column.ColumnType; import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -35,6 +36,35 @@ public class RowBasedIndexBuilderTest @Test public void test_stringKey_uniqueKeys() + { + final RowBasedIndexBuilder builder = + new RowBasedIndexBuilder(ColumnType.STRING) + .add("abc") + .add("") + .add("1") + .add("def"); + + final IndexedTable.Index index = builder.build(); + + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + Assert.assertEquals(ColumnType.STRING, index.keyType()); + Assert.assertTrue(index.areKeysUnique(false)); + Assert.assertTrue(index.areKeysUnique(true)); + + Assert.assertEquals(intSet(0), index.find("abc")); + Assert.assertEquals(intSet(1), index.find("")); + Assert.assertEquals(intSet(2), index.find(1L)); + Assert.assertEquals(intSet(2), index.find("1")); + Assert.assertEquals(intSet(3), index.find("def")); + Assert.assertEquals(intSet(), index.find(null)); + Assert.assertEquals(intSet(), index.find("nonexistent")); + + expectedException.expect(UnsupportedOperationException.class); + index.findUniqueLong(0L); + } + + @Test + public void test_stringKey_uniqueKeysWithNull() { final RowBasedIndexBuilder builder = new RowBasedIndexBuilder(ColumnType.STRING) @@ -46,16 +76,48 @@ public void test_stringKey_uniqueKeys() final IndexedTable.Index index = builder.build(); - Assert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); Assert.assertEquals(ColumnType.STRING, index.keyType()); - Assert.assertTrue(index.areKeysUnique()); + Assert.assertTrue(index.areKeysUnique(false)); + Assert.assertTrue(index.areKeysUnique(true)); Assert.assertEquals(intSet(0), index.find("abc")); Assert.assertEquals(intSet(1), index.find("")); Assert.assertEquals(intSet(3), index.find(1L)); Assert.assertEquals(intSet(3), index.find("1")); Assert.assertEquals(intSet(4), index.find("def")); - Assert.assertEquals(intSet(), index.find(null)); + Assert.assertEquals(intSet(2), index.find(null)); + Assert.assertEquals(intSet(), index.find("nonexistent")); + + expectedException.expect(UnsupportedOperationException.class); + index.findUniqueLong(0L); + } + + @Test + public void test_stringKey_duplicateNullKey() + { + final RowBasedIndexBuilder builder = + new RowBasedIndexBuilder(ColumnType.STRING) + .add("abc") + .add("") + .add(null) + .add("1") + .add(null) + .add("def"); + + final IndexedTable.Index index = builder.build(); + + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + Assert.assertEquals(ColumnType.STRING, index.keyType()); + Assert.assertTrue(index.areKeysUnique(false)); + Assert.assertFalse(index.areKeysUnique(true)); + + Assert.assertEquals(intSet(0), index.find("abc")); + Assert.assertEquals(intSet(1), index.find("")); + Assert.assertEquals(intSet(3), index.find(1L)); + Assert.assertEquals(intSet(3), index.find("1")); + Assert.assertEquals(intSet(5), index.find("def")); + Assert.assertEquals(intSet(2, 4), index.find(null)); Assert.assertEquals(intSet(), index.find("nonexistent")); expectedException.expect(UnsupportedOperationException.class); @@ -76,16 +138,17 @@ public void test_stringKey_duplicateKeys() final IndexedTable.Index index = builder.build(); - Assert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); Assert.assertEquals(ColumnType.STRING, index.keyType()); - Assert.assertFalse(index.areKeysUnique()); + Assert.assertFalse(index.areKeysUnique(false)); + Assert.assertFalse(index.areKeysUnique(true)); Assert.assertEquals(intSet(0, 3), index.find("abc")); Assert.assertEquals(intSet(1), index.find("")); Assert.assertEquals(intSet(4), index.find(1L)); Assert.assertEquals(intSet(4), index.find("1")); Assert.assertEquals(intSet(5), index.find("def")); - Assert.assertEquals(intSet(), index.find(null)); + Assert.assertEquals(intSet(2), index.find(null)); Assert.assertEquals(intSet(), index.find("nonexistent")); expectedException.expect(UnsupportedOperationException.class); @@ -103,14 +166,44 @@ public void test_longKey_uniqueKeys() final IndexedTable.Index index = builder.build(); - Assert.assertThat(index, CoreMatchers.instanceOf(UniqueLongArrayIndex.class)); + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(UniqueLongArrayIndex.class)); Assert.assertEquals(ColumnType.LONG, index.keyType()); - Assert.assertTrue(index.areKeysUnique()); + Assert.assertTrue(index.areKeysUnique(false)); Assert.assertEquals(intSet(0), index.find(1L)); Assert.assertEquals(intSet(1), index.find(5L)); Assert.assertEquals(intSet(2), index.find(2L)); Assert.assertEquals(intSet(), index.find(3L)); + Assert.assertEquals(intSet(), index.find(null)); + + Assert.assertEquals(0, index.findUniqueLong(1L)); + Assert.assertEquals(1, index.findUniqueLong(5L)); + Assert.assertEquals(2, index.findUniqueLong(2L)); + Assert.assertEquals(IndexedTable.Index.NOT_FOUND, index.findUniqueLong(3L)); + } + + @Test + public void test_longKey_uniqueKeysWithNull() + { + final RowBasedIndexBuilder builder = + new RowBasedIndexBuilder(ColumnType.LONG) + .add(1) + .add(5) + .add(2) + .add(null); + + final IndexedTable.Index index = builder.build(); + + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + Assert.assertEquals(ColumnType.LONG, index.keyType()); + Assert.assertTrue(index.areKeysUnique(false)); + Assert.assertTrue(index.areKeysUnique(true)); + + Assert.assertEquals(intSet(0), index.find(1L)); + Assert.assertEquals(intSet(1), index.find(5L)); + Assert.assertEquals(intSet(2), index.find(2L)); + Assert.assertEquals(intSet(), index.find(3L)); + Assert.assertEquals(intSet(3), index.find(null)); Assert.assertEquals(0, index.findUniqueLong(1L)); Assert.assertEquals(1, index.findUniqueLong(5L)); @@ -129,14 +222,15 @@ public void test_longKey_uniqueKeys_farApart() final IndexedTable.Index index = builder.build(); - Assert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); Assert.assertEquals(ColumnType.LONG, index.keyType()); - Assert.assertTrue(index.areKeysUnique()); + Assert.assertTrue(index.areKeysUnique(false)); Assert.assertEquals(intSet(0), index.find(1L)); Assert.assertEquals(intSet(1), index.find(10_000_000L)); Assert.assertEquals(intSet(2), index.find(2L)); Assert.assertEquals(intSet(), index.find(3L)); + Assert.assertEquals(intSet(), index.find(null)); Assert.assertEquals(0, index.findUniqueLong(1L)); Assert.assertEquals(1, index.findUniqueLong(10_000_000L)); @@ -156,9 +250,9 @@ public void test_longKey_duplicateKeys() final IndexedTable.Index index = builder.build(); - Assert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); Assert.assertEquals(ColumnType.LONG, index.keyType()); - Assert.assertFalse(index.areKeysUnique()); + Assert.assertFalse(index.areKeysUnique(false)); Assert.assertEquals(intSet(0, 2), index.find("1")); Assert.assertEquals(intSet(0, 2), index.find(1)); diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java index 789bac28ea9a..aef371bcf556 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java @@ -131,7 +131,7 @@ public void test_columnIndex_regionsRegionIsoCode() { final IndexedTable.Index index = regionsTable.columnIndex(INDEX_REGIONS_REGION_ISO_CODE); - Assert.assertEquals(ImmutableSet.of(), index.find(null)); + Assert.assertEquals(ImmutableSet.of(21), index.find(null)); Assert.assertEquals(ImmutableSet.of(0), index.find("11")); Assert.assertEquals(ImmutableSet.of(1), index.find(13)); Assert.assertEquals(ImmutableSet.of(12), index.find("QC")); diff --git a/processing/src/test/java/org/apache/druid/segment/nested/NestedFieldColumnIndexSupplierTest.java b/processing/src/test/java/org/apache/druid/segment/nested/NestedFieldColumnIndexSupplierTest.java index fffb9068a3ef..a691f4470a98 100644 --- a/processing/src/test/java/org/apache/druid/segment/nested/NestedFieldColumnIndexSupplierTest.java +++ b/processing/src/test/java/org/apache/druid/segment/nested/NestedFieldColumnIndexSupplierTest.java @@ -442,7 +442,7 @@ public void testSingleTypeStringColumnPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("b", "z")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("b", "z")) ); // 10 rows @@ -607,7 +607,7 @@ public void testSingleValueStringWithNullPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("b", "z")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("b", "z")) ); // 10 rows @@ -728,7 +728,7 @@ public void testSingleTypeLongColumnPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("1", "3")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("1", "3")) ); // 10 rows @@ -880,7 +880,7 @@ public void testSingleValueLongWithNullPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("3", "100")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("3", "100")) ); // 10 rows @@ -1025,7 +1025,7 @@ public void testSingleTypeDoubleColumnPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("1.2", "3.3", "5.0")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("1.2", "3.3", "5.0")) ); // 10 rows @@ -1162,7 +1162,7 @@ public void testSingleValueDoubleWithNullPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("1.2", "3.3")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("1.2", "3.3")) ); // 10 rows @@ -1277,7 +1277,7 @@ public void testVariantPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("b", "z", "9.9", "300")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("b", "z", "9.9", "300")) ); // 10 rows @@ -1485,7 +1485,7 @@ public double skipValuePredicateIndexScale() // circuit early and return nothing DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("0")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("0")) ); Assert.assertNull(singleTypeStringSupplier.as(DruidPredicateIndexes.class).forPredicate(predicateFactory)); Assert.assertNull(singleTypeLongSupplier.as(DruidPredicateIndexes.class).forPredicate(predicateFactory)); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java index ef572c8b6219..438c666227e6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java @@ -552,6 +552,8 @@ private static DimFilter toSimpleLeafFilter( return kind == SqlKind.IS_NOT_NULL ? new NotDimFilter(equalFilter) : equalFilter; } else if (kind == SqlKind.EQUALS || kind == SqlKind.NOT_EQUALS + || kind == SqlKind.IS_NOT_DISTINCT_FROM + || kind == SqlKind.IS_DISTINCT_FROM || kind == SqlKind.GREATER_THAN || kind == SqlKind.GREATER_THAN_OR_EQUAL || kind == SqlKind.LESS_THAN @@ -577,6 +579,8 @@ private static DimFilter toSimpleLeafFilter( switch (kind) { case EQUALS: case NOT_EQUALS: + case IS_NOT_DISTINCT_FROM: + case IS_DISTINCT_FROM: flippedKind = kind; break; case GREATER_THAN: @@ -688,9 +692,13 @@ private static DimFilter toSimpleLeafFilter( // Always use BoundDimFilters, to simplify filter optimization later (it helps to remember the comparator). switch (flippedKind) { case EQUALS: + case IS_NOT_DISTINCT_FROM: + // OK to treat EQUALS, IS_NOT_DISTINCT_FROM the same since we know stringVal is nonnull. filter = Bounds.equalTo(boundRefKey, stringVal); break; case NOT_EQUALS: + case IS_DISTINCT_FROM: + // OK to treat NOT_EQUALS, IS_DISTINCT_FROM the same since we know stringVal is nonnull. filter = new NotDimFilter(Bounds.equalTo(boundRefKey, stringVal)); break; case GREATER_THAN: @@ -724,9 +732,11 @@ private static DimFilter toSimpleLeafFilter( // Always use RangeFilter, to simplify filter optimization later switch (flippedKind) { case EQUALS: + case IS_NOT_DISTINCT_FROM: filter = Ranges.equalTo(rangeRefKey, val); break; case NOT_EQUALS: + case IS_DISTINCT_FROM: filter = new NotDimFilter(Ranges.equalTo(rangeRefKey, val)); break; case GREATER_THAN: diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java index 8e09ea0c7340..24fac69d11d2 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java @@ -41,9 +41,7 @@ import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; -import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; public class ArrayOverlapOperatorConversion extends BaseExpressionDimFilterOperatorConversion { @@ -138,9 +136,14 @@ public DimFilter toDruidFilter( ); } } else { + final InDimFilter.ValuesSet valuesSet = InDimFilter.ValuesSet.create(); + for (final Object arrayElement : arrayElements) { + valuesSet.add(Evals.asString(arrayElement)); + } + return new InDimFilter( simpleExtractionExpr.getSimpleExtraction().getColumn(), - new InDimFilter.ValuesSet(Arrays.stream(arrayElements).map(Evals::asString).collect(Collectors.toList())), + valuesSet, simpleExtractionExpr.getSimpleExtraction().getExtractionFn(), null ); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index 16748b0b6ab5..e392ad8a47bf 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -371,6 +371,8 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new UnaryPrefixOperatorConversion(SqlStdOperatorTable.UNARY_MINUS, "-")) .add(new UnaryFunctionOperatorConversion(SqlStdOperatorTable.IS_NULL, "isnull")) .add(new UnaryFunctionOperatorConversion(SqlStdOperatorTable.IS_NOT_NULL, "notnull")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_DISTINCT_FROM, "isdistinctfrom")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, "notdistinctfrom")) .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_FALSE, "isfalse")) .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_TRUE, "istrue")) .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_NOT_FALSE, "notfalse")) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java index 94c13f6a94c0..6dc8ff00531b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java @@ -38,11 +38,12 @@ import org.apache.calcite.rex.RexSlot; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.ImmutableBitSet; -import org.apache.druid.java.util.common.Pair; +import org.apache.druid.error.DruidException; import org.apache.druid.query.LookupDataSource; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.DruidJoinQueryRel; @@ -53,7 +54,6 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; -import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.Stack; @@ -242,7 +242,7 @@ private Optional analyzeCondition( ) { final List subConditions = decomposeAnd(condition); - final List> equalitySubConditions = new ArrayList<>(); + final List equalitySubConditions = new ArrayList<>(); final List literalSubConditions = new ArrayList<>(); final int numLeftFields = leftRowType.getFieldCount(); final Set rightColumns = new HashSet<>(); @@ -271,10 +271,12 @@ private Optional analyzeCondition( RexNode firstOperand; RexNode secondOperand; + SqlKind comparisonKind; if (subCondition.isA(SqlKind.INPUT_REF)) { firstOperand = rexBuilder.makeLiteral(true); secondOperand = subCondition; + comparisonKind = SqlKind.EQUALS; if (!SqlTypeName.BOOLEAN_TYPES.contains(secondOperand.getType().getSqlTypeName())) { plannerContext.setPlanningError( @@ -285,11 +287,12 @@ private Optional analyzeCondition( return Optional.empty(); } - } else if (subCondition.isA(SqlKind.EQUALS)) { + } else if (subCondition.isA(SqlKind.EQUALS) || subCondition.isA(SqlKind.IS_NOT_DISTINCT_FROM)) { final List operands = ((RexCall) subCondition).getOperands(); Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%s]", operands.size()); firstOperand = operands.get(0); secondOperand = operands.get(1); + comparisonKind = subCondition.getKind(); } else { // If it's not EQUALS or a BOOLEAN input ref, it's not supported. plannerContext.setPlanningError( @@ -300,11 +303,11 @@ private Optional analyzeCondition( } if (isLeftExpression(firstOperand, numLeftFields) && isRightInputRef(secondOperand, numLeftFields)) { - equalitySubConditions.add(Pair.of(firstOperand, (RexInputRef) secondOperand)); + equalitySubConditions.add(new RexEquality(firstOperand, (RexInputRef) secondOperand, comparisonKind)); rightColumns.add((RexInputRef) secondOperand); } else if (isRightInputRef(firstOperand, numLeftFields) && isLeftExpression(secondOperand, numLeftFields)) { - equalitySubConditions.add(Pair.of(secondOperand, (RexInputRef) firstOperand)); + equalitySubConditions.add(new RexEquality(secondOperand, (RexInputRef) firstOperand, subCondition.getKind())); rightColumns.add((RexInputRef) firstOperand); } else { // Cannot handle this condition. @@ -336,7 +339,8 @@ && isLeftExpression(secondOperand, numLeftFields)) { numLeftFields, equalitySubConditions, literalSubConditions - )); + ) + ); } @VisibleForTesting @@ -375,7 +379,6 @@ private static boolean isRightInputRef(final RexNode rexNode, final int numLeftF return rexNode.isA(SqlKind.INPUT_REF) && ((RexInputRef) rexNode).getIndex() >= numLeftFields; } - @VisibleForTesting static class ConditionAnalysis { /** @@ -387,17 +390,16 @@ static class ConditionAnalysis /** * Each equality subcondition is an equality of the form f(LeftRel) = g(RightRel). */ - private final List> equalitySubConditions; + private final List equalitySubConditions; /** * Each literal subcondition is... a literal. */ private final List literalSubConditions; - ConditionAnalysis( int numLeftFields, - List> equalitySubConditions, + List equalitySubConditions, List literalSubConditions ) { @@ -417,9 +419,10 @@ public ConditionAnalysis pushThroughLeftProject(final Project leftProject) equalitySubConditions .stream() .map( - equality -> Pair.of( - RelOptUtil.pushPastProject(equality.lhs, leftProject), - (RexInputRef) RexUtil.shift(equality.rhs, rhsShift) + equality -> new RexEquality( + RelOptUtil.pushPastProject(equality.left, leftProject), + (RexInputRef) RexUtil.shift(equality.right, rhsShift), + equality.kind ) ) .collect(Collectors.toList()), @@ -436,15 +439,16 @@ public ConditionAnalysis pushThroughRightProject(final Project rightProject) equalitySubConditions .stream() .map( - equality -> Pair.of( - equality.lhs, + equality -> new RexEquality( + equality.left, (RexInputRef) RexUtil.shift( RelOptUtil.pushPastProject( - RexUtil.shift(equality.rhs, -numLeftFields), + RexUtil.shift(equality.right, -numLeftFields), rightProject ), numLeftFields - ) + ), + equality.kind ) ) .collect(Collectors.toList()), @@ -454,8 +458,8 @@ public ConditionAnalysis pushThroughRightProject(final Project rightProject) public boolean onlyUsesMappingsFromRightProject(final Project rightProject) { - for (Pair equality : equalitySubConditions) { - final int rightIndex = equality.rhs.getIndex() - numLeftFields; + for (final RexEquality equality : equalitySubConditions) { + final int rightIndex = equality.right.getIndex() - numLeftFields; if (!rightProject.getProjects().get(rightIndex).isA(SqlKind.INPUT_REF)) { return false; @@ -473,7 +477,7 @@ public RexNode getCondition(final RexBuilder rexBuilder) literalSubConditions, equalitySubConditions .stream() - .map(equality -> rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, equality.lhs, equality.rhs)) + .map(equality -> equality.makeCall(rexBuilder)) .collect(Collectors.toList()) ), false @@ -481,31 +485,55 @@ public RexNode getCondition(final RexBuilder rexBuilder) } @Override - public boolean equals(Object o) + public String toString() { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ConditionAnalysis that = (ConditionAnalysis) o; - return Objects.equals(equalitySubConditions, that.equalitySubConditions) && - Objects.equals(literalSubConditions, that.literalSubConditions); + return "ConditionAnalysis{" + + "numLeftFields=" + numLeftFields + + ", equalitySubConditions=" + equalitySubConditions + + ", literalSubConditions=" + literalSubConditions + + '}'; } + } - @Override - public int hashCode() + /** + * Like {@link org.apache.druid.segment.join.Equality} but uses {@link RexNode} instead of + * {@link org.apache.druid.math.expr.Expr}. + */ + static class RexEquality + { + private final RexNode left; + private final RexInputRef right; + private final SqlKind kind; + + public RexEquality(RexNode left, RexInputRef right, SqlKind kind) + { + this.left = left; + this.right = right; + this.kind = kind; + } + + public RexNode makeCall(final RexBuilder builder) { - return Objects.hash(equalitySubConditions, literalSubConditions); + final SqlOperator operator; + + if (kind == SqlKind.EQUALS) { + operator = SqlStdOperatorTable.EQUALS; + } else if (kind == SqlKind.IS_NOT_DISTINCT_FROM) { + operator = SqlStdOperatorTable.IS_NOT_DISTINCT_FROM; + } else { + throw DruidException.defensive("Unexpected operator kind[%s]", kind); + } + + return builder.makeCall(operator, left, right); } @Override public String toString() { - return "ConditionAnalysis{" + - "equalitySubConditions=" + equalitySubConditions + - ", literalSubConditions=" + literalSubConditions + + return "RexEquality{" + + "left=" + left + + ", right=" + right + + ", kind=" + kind + '}'; } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index d0d7935a334d..a0e96a876379 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -3619,6 +3619,105 @@ public void testLeftJoinWithNotNullFilter(Map queryContext) ); } + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testInnerJoin(Map queryContext) + { + testQuery( + "SELECT s.dim1, t.dim1\n" + + "FROM foo as s\n" + + "INNER JOIN foo as t " + + "ON s.dim1 = t.dim1", + queryContext, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource(newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns(ImmutableList.of("dim1")) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build()), + "j0.", + "(\"dim1\" == \"j0.dim1\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("dim1", "j0.dim1") + .context(queryContext) + .build() + ), + sortIfSortBased( + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{"", ""}, + new Object[]{"10.1", "10.1"}, + new Object[]{"2", "2"}, + new Object[]{"1", "1"}, + new Object[]{"def", "def"}, + new Object[]{"abc", "abc"} + ) + : ImmutableList.of( + new Object[]{"10.1", "10.1"}, + new Object[]{"2", "2"}, + new Object[]{"1", "1"}, + new Object[]{"def", "def"}, + new Object[]{"abc", "abc"} + ), + 0 + ) + ); + } + + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testJoinWithExplicitIsNotDistinctFromCondition(Map queryContext) + { + // Like "testInnerJoin", but uses IS NOT DISTINCT FROM instead of equals. + + testQuery( + "SELECT s.dim1, t.dim1\n" + + "FROM foo as s\n" + + "INNER JOIN foo as t " + + "ON s.dim1 IS NOT DISTINCT FROM t.dim1", + queryContext, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource(newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns(ImmutableList.of("dim1")) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build()), + "j0.", + "notdistinctfrom(\"dim1\",\"j0.dim1\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("dim1", "j0.dim1") + .context(queryContext) + .build() + ), + ImmutableList.of( + new Object[]{"", ""}, + new Object[]{"10.1", "10.1"}, + new Object[]{"2", "2"}, + new Object[]{"1", "1"}, + new Object[]{"def", "def"}, + new Object[]{"abc", "abc"} + ) + ); + } + @Test @Parameters(source = QueryContextForJoinProvider.class) public void testInnerJoinSubqueryWithSelectorFilter(Map queryContext) @@ -4416,6 +4515,51 @@ public void testCountDistinctOfLookupUsingJoinOperator(Map query ); } + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testJoinWithImplicitIsNotDistinctFromCondition(Map queryContext) + { + // Like "testInnerJoin", but uses an implied is-not-distinct-from instead of equals. + cannotVectorize(); + + testQuery( + "SELECT x.m1, y.m1\n" + + "FROM foo x INNER JOIN foo y ON (x.m1 = y.m1) OR (x.m1 IS NULL AND y.m1 IS NULL)", + queryContext, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("m1") + .context(queryContext) + .build() + ), + "j0.", + "notdistinctfrom(\"m1\",\"j0.m1\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("j0.m1", "m1") + .context(queryContext) + .build() + ), + ImmutableList.of( + new Object[]{1.0f, 1.0f}, + new Object[]{2.0f, 2.0f}, + new Object[]{3.0f, 3.0f}, + new Object[]{4.0f, 4.0f}, + new Object[]{5.0f, 5.0f}, + new Object[]{6.0f, 6.0f} + ) + ); + } + @Test @Parameters(source = QueryContextForJoinProvider.class) public void testJoinWithNonEquiCondition(Map queryContext) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 042b17368278..6516358b1b59 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -84,6 +84,7 @@ import org.apache.druid.query.filter.EqualityFilter; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; +import org.apache.druid.query.filter.NotDimFilter; import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.filter.RangeFilter; import org.apache.druid.query.filter.RegexDimFilter; @@ -5687,25 +5688,34 @@ public void testCountStarWithBoundFilterSimplifyOr() } @Test - public void testUnplannableTwoExactCountDistincts() + public void testUnplannableExactCountDistinctOnSketch() { - // Requires GROUPING SETS + GROUPING to be translated by AggregateExpandDistinctAggregatesRule. - + // COUNT DISTINCT on a sketch cannot be exact. assertQueryIsUnplannable( PLANNER_CONFIG_NO_HLL, - "SELECT dim2, COUNT(distinct dim1), COUNT(distinct dim2) FROM druid.foo GROUP BY dim2", - "SQL query requires 'IS NOT DISTINCT FROM' operator that is not supported." + "SELECT COUNT(distinct unique_dim1) FROM druid.foo", + "SQL requires a group-by on a column of type COMPLEX that is unsupported." ); } @Test - public void testUnplannableExactCountDistinctOnSketch() + public void testIsNotDistinctFromLiteral() { - // COUNT DISTINCT on a sketch cannot be exact. - assertQueryIsUnplannable( - PLANNER_CONFIG_NO_HLL, - "SELECT COUNT(distinct unique_dim1) FROM druid.foo", - "SQL requires a group-by on a column of type COMPLEX that is unsupported." + testQuery( + "SELECT COUNT(*) FROM druid.foo WHERE (dim1 >= 'a' and dim1 < 'b') OR dim1 IS NOT DISTINCT FROM 'ab'", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(range("dim1", ColumnType.STRING, "a", "b", false, true)) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{1L} + ) ); } @@ -6726,7 +6736,117 @@ public void testExactCountDistinctWithGroupingAndOtherAggregators() } @Test - public void testMultipleExactCountDistinctWithGroupingAndOtherAggregators() + public void testMultipleExactCountDistinctWithGroupingAndOtherAggregatorsUsingJoin() + { + // When HLL is disabled, do multiple exact count distincts through joins of nested queries. + + testQuery( + PLANNER_CONFIG_NO_HLL, + "SELECT dim2, COUNT(*), COUNT(distinct dim1), COUNT(distinct cnt) FROM druid.foo GROUP BY dim2", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + join( + new QueryDataSource( + GroupByQuery + .builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setGranularity(Granularities.ALL) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions(new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING)) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .build() + ), + new QueryDataSource( + GroupByQuery + .builder() + .setDataSource( + new QueryDataSource( + GroupByQuery + .builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setGranularity(Granularities.ALL) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions( + new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING), + new DefaultDimensionSpec("dim1", "d1", ColumnType.STRING) + ) + .build() + ) + ) + .setGranularity(Granularities.ALL) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions(new DefaultDimensionSpec("d0", "_d0", ColumnType.STRING)) + .setAggregatorSpecs( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + new NotDimFilter(isNull("d1", null)) + ) + ) + .build() + ), + "j0.", + "notdistinctfrom(\"d0\",\"j0._d0\")", + JoinType.INNER + ), + new QueryDataSource( + GroupByQuery.builder() + .setGranularity(Granularities.ALL) + .setDataSource( + new QueryDataSource( + GroupByQuery + .builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setGranularity(Granularities.ALL) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions( + new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING), + new DefaultDimensionSpec("cnt", "d1", ColumnType.LONG) + ) + .build() + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions(new DefaultDimensionSpec("d0", "_d0", ColumnType.STRING)) + .setAggregatorSpecs( + NullHandling.sqlCompatible() + ? new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + new NotDimFilter(isNull("d1", null)) + ) + : new CountAggregatorFactory("a0") + ) + .build() + ), + "_j0.", + "notdistinctfrom(\"d0\",\"_j0._d0\")", + JoinType.INNER + ) + ) + .context(QUERY_CONTEXT_DEFAULT) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns(ImmutableList.of("_j0.a0", "a0", "d0", "j0.a0")) + .build() + ), + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{null, 2L, 2L, 1L}, + new Object[]{"", 1L, 1L, 1L}, + new Object[]{"a", 2L, 2L, 1L}, + new Object[]{"abc", 1L, 1L, 1L} + ) + : ImmutableList.of( + new Object[]{"", 3L, 3L, 1L}, + new Object[]{"a", 2L, 1L, 1L}, + new Object[]{"abc", 1L, 1L, 1L} + ) + ); + } + + @Test + public void testMultipleExactCountDistinctWithGroupingUsingGroupingSets() { notMsqCompatible(); requireMergeBuffers(4); @@ -12803,6 +12923,42 @@ public void testLookupWithNull() ); } + @Test + public void testLookupWithIsNotDistinctFromNull() + { + List expected; + if (useDefault) { + expected = ImmutableList.builder().add( + new Object[]{NULL_STRING, NULL_STRING}, + new Object[]{NULL_STRING, NULL_STRING}, + new Object[]{NULL_STRING, NULL_STRING} + ).build(); + } else { + expected = ImmutableList.builder().add( + new Object[]{NULL_STRING, NULL_STRING}, + new Object[]{NULL_STRING, NULL_STRING} + ).build(); + } + testQuery( + "SELECT dim2 ,lookup(dim2,'lookyloo') from foo where dim2 is not distinct from null", + ImmutableList.of( + new Druids.ScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + expressionVirtualColumn("v0", "null", ColumnType.STRING) + ) + .columns("v0") + .legacy(false) + .filters(isNull("dim2")) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + expected + ); + } + @Test public void testRoundFunc() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/DecoupledPlanningCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/DecoupledPlanningCalciteQueryTest.java index cd421ef02b1c..67a1c0ccd43f 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/DecoupledPlanningCalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/DecoupledPlanningCalciteQueryTest.java @@ -186,7 +186,7 @@ public void testGroupByWithSortOnPostAggregationNoTopNContext() @Override @Ignore - public void testUnplannableTwoExactCountDistincts() + public void testMultipleExactCountDistinctWithGroupingAndOtherAggregatorsUsingJoin() { }