diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/lookup/SqlReverseLookupBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/lookup/SqlReverseLookupBenchmark.java index bd4dcd4d9ad8..9db22a2ec59d 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/lookup/SqlReverseLookupBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/lookup/SqlReverseLookupBenchmark.java @@ -157,4 +157,27 @@ public void planNotEquals(Blackhole blackhole) blackhole.consume(plannerResult); } } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.MILLISECONDS) + public void planEqualsInsideAndOutsideCase(Blackhole blackhole) + { + final String sql = StringUtils.format( + "SELECT COUNT(*) FROM foo\n" + + "WHERE\n" + + " CASE WHEN LOOKUP(dimZipf, 'benchmark-lookup', 'N/A') = '%s'\n" + + " THEN NULL\n" + + " ELSE LOOKUP(dimZipf, 'benchmark-lookup', 'N/A')\n" + + " END IN ('%s', '%s', '%s')", + LookupBenchmarkUtil.makeKeyOrValue(0), + LookupBenchmarkUtil.makeKeyOrValue(1), + LookupBenchmarkUtil.makeKeyOrValue(2), + LookupBenchmarkUtil.makeKeyOrValue(3) + ); + try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, ImmutableMap.of())) { + final PlannerResult plannerResult = planner.plan(); + blackhole.consume(plannerResult); + } + } } diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java index 4322e7273881..ce01324116f9 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java @@ -26,6 +26,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.data.input.impl.DimensionSchema; import org.apache.druid.data.input.impl.DimensionsSpec; +import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.io.Closer; @@ -36,6 +37,7 @@ import org.apache.druid.segment.AutoTypeColumnSchema; import org.apache.druid.segment.IndexSpec; import org.apache.druid.segment.QueryableIndex; +import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.generator.GeneratorBasicSchemas; import org.apache.druid.segment.generator.GeneratorSchemaInfo; import org.apache.druid.segment.generator.SegmentGenerator; @@ -204,7 +206,7 @@ public void setup() throws JsonProcessingException ); String prefix = ("explain plan for select long1 from foo where long1 in "); - final String sql = createQuery(prefix, inClauseLiteralsCount); + final String sql = createQuery(prefix, inClauseLiteralsCount, ValueType.LONG); final Sequence resultSequence = getPlan(sql, null); final Object[] planResult = resultSequence.toList().get(0); @@ -222,12 +224,13 @@ public void tearDown() throws Exception closer.close(); } + @Benchmark @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MILLISECONDS) public void queryInSql(Blackhole blackhole) { String prefix = "explain plan for select long1 from foo where long1 in "; - final String sql = createQuery(prefix, inClauseLiteralsCount); + final String sql = createQuery(prefix, inClauseLiteralsCount, ValueType.LONG); getPlan(sql, blackhole); } @@ -238,7 +241,7 @@ public void queryEqualOrInSql(Blackhole blackhole) { String prefix = "explain plan for select long1 from foo where string1 = '7' or long1 in "; - final String sql = createQuery(prefix, inClauseLiteralsCount); + final String sql = createQuery(prefix, inClauseLiteralsCount, ValueType.LONG); getPlan(sql, blackhole); } @@ -250,28 +253,74 @@ public void queryMultiEqualOrInSql(Blackhole blackhole) { String prefix = "explain plan for select long1 from foo where string1 = '7' or string1 = '8' or long1 in "; - final String sql = createQuery(prefix, inClauseLiteralsCount); + final String sql = createQuery(prefix, inClauseLiteralsCount, ValueType.LONG); getPlan(sql, blackhole); } @Benchmark @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.MILLISECONDS) - public void queryJoinEqualOrInSql(Blackhole blackhole) + public void queryStringFunctionInSql(Blackhole blackhole) { + String prefix = + "explain plan for select count(*) from foo where long1 = 8 or lower(string1) in "; + final String sql = createQuery(prefix, inClauseLiteralsCount, ValueType.STRING); + getPlan(sql, blackhole); + } + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.MILLISECONDS) + public void queryStringFunctionIsNotNullAndNotInSql(Blackhole blackhole) + { + String prefix = + "explain plan for select count(*) from foo where long1 = 8 and lower(string1) is not null and lower(string1) not in "; + final String sql = createQuery(prefix, inClauseLiteralsCount, ValueType.STRING); + getPlan(sql, blackhole); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.MILLISECONDS) + public void queryStringFunctionIsNullOrInSql(Blackhole blackhole) + { + String prefix = + "explain plan for select count(*) from foo where long1 = 8 and (lower(string1) is null or lower(string1) in "; + final String sql = createQuery(prefix, inClauseLiteralsCount, ValueType.STRING) + ')'; + getPlan(sql, blackhole); + } + + @Benchmark + @BenchmarkMode(Mode.AverageTime) + @OutputTimeUnit(TimeUnit.MILLISECONDS) + public void queryJoinEqualOrInSql(Blackhole blackhole) + { String prefix = "explain plan for select foo.long1, fooright.string1 from foo inner join foo as fooright on foo.string1 = fooright.string1 where fooright.string1 = '7' or foo.long1 in "; - final String sql = createQuery(prefix, inClauseLiteralsCount); + final String sql = createQuery(prefix, inClauseLiteralsCount, ValueType.LONG); getPlan(sql, blackhole); } - private String createQuery(String prefix, int inClauseLiteralsCount) + private String createQuery(String prefix, int inClauseLiteralsCount, ValueType type) { StringBuilder sqlBuilder = new StringBuilder(); sqlBuilder.append(prefix).append('('); - IntStream.range(1, inClauseLiteralsCount - 1).forEach(i -> sqlBuilder.append(i).append(",")); - sqlBuilder.append(inClauseLiteralsCount).append(")"); + IntStream.range(1, inClauseLiteralsCount + 1).forEach( + i -> { + if (i > 1) { + sqlBuilder.append(','); + } + + if (type == ValueType.LONG) { + sqlBuilder.append(i); + } else if (type == ValueType.STRING) { + sqlBuilder.append("'").append(i).append("'"); + } else { + throw new ISE("Cannot generate IN with type[%s]", type); + } + } + ); + sqlBuilder.append(")"); return sqlBuilder.toString(); } diff --git a/docs/querying/sql-query-context.md b/docs/querying/sql-query-context.md index f8b1576a913b..dc73c9e1ab3f 100644 --- a/docs/querying/sql-query-context.md +++ b/docs/querying/sql-query-context.md @@ -52,7 +52,9 @@ Configure Druid SQL query planning using the parameters in the table below. |`sqlPullUpLookup`|Whether to consider the [pull-up rewrite](lookups.md#pull-up) of the `LOOKUP` function during SQL planning.|true| |`enableJoinLeftTableScanDirect`|`false`|This flag applies to queries which have joins. For joins, where left child is a simple scan with a filter, by default, druid will run the scan as a query and the join the results to the right child on broker. Setting this flag to true overrides that behavior and druid will attempt to push the join to data servers instead. Please note that the flag could be applicable to queries even if there is no explicit join. since queries can internally translated into a join by the SQL planner.| |`maxNumericInFilters`|`-1`|Max limit for the amount of numeric values that can be compared for a string type dimension when the entire SQL WHERE clause of a query translates only to an [OR](../querying/filters.md#or) of [Bound filter](../querying/filters.md#bound-filter). By default, Druid does not restrict the amount of of numeric Bound Filters on String columns, although this situation may block other queries from running. Set this parameter to a smaller value to prevent Druid from running queries that have prohibitively long segment processing times. The optimal limit requires some trial and error; we recommend starting with 100. Users who submit a query that exceeds the limit of `maxNumericInFilters` should instead rewrite their queries to use strings in the `WHERE` clause instead of numbers. For example, `WHERE someString IN (‘123’, ‘456’)`. This value cannot exceed the set system configuration `druid.sql.planner.maxNumericInFilters`. This value is ignored if `druid.sql.planner.maxNumericInFilters` is not set explicitly.| -|`inSubQueryThreshold`|`2147483647`| Threshold for minimum number of values in an IN clause to convert the query to a JOIN operation on an inlined table rather than a predicate. A threshold of 0 forces usage of an inline table in all cases; a threshold of [Integer.MAX_VALUE] forces usage of OR in all cases. | +|`inFunctionThreshold`|`100`| At or beyond this threshold number of values, SQL `IN` is converted to [`SCALAR_IN_ARRAY`](sql-functions.md#scalar_in_array). A threshold of 0 forces this conversion in all cases. A threshold of [Integer.MAX_VALUE] disables this conversion. The converted function is eligible for fewer planning-time optimizations, which speeds up planning, but may prevent certain planning-time optimizations.| +|`inFunctionExprThreshold`|`2`| At or beyond this threshold number of values, SQL `IN` is eligible for execution using the native function `scalar_in_array` rather than an || of `==`, even if the number of values is below `inFunctionThreshold`. This property only affects translation of SQL `IN` to a [native expression](math-expr.md). It does not affect translation of SQL `IN` to a [native filter](filters.md). This property is provided for backwards compatibility purposes, and may be removed in a future release.| +|`inSubQueryThreshold`|`2147483647`| At or beyond this threshold number of values, SQL `IN` is converted to `JOIN` on an inline table. `inFunctionThreshold` takes priority over this setting. A threshold of 0 forces usage of an inline table in all cases where the size of a SQL `IN` is larger than `inFunctionThreshold`. A threshold of `2147483647` disables the rewrite of SQL `IN` to `JOIN`. | ## Setting the query context The query context parameters can be specified as a "context" object in the [JSON API](../api-reference/sql-api.md) or as a [JDBC connection properties object](../api-reference/sql-jdbc.md). diff --git a/processing/src/main/java/org/apache/druid/query/QueryContext.java b/processing/src/main/java/org/apache/druid/query/QueryContext.java index 5c08678b8884..80a860bb273f 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContext.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContext.java @@ -24,10 +24,11 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.query.QueryContexts.Vectorize; +import org.apache.druid.query.filter.InDimFilter; +import org.apache.druid.query.filter.TypedInFilter; import org.apache.druid.segment.QueryableIndexStorageAdapter; import javax.annotation.Nullable; - import java.io.IOException; import java.util.Collections; import java.util.Map; @@ -575,6 +576,35 @@ public int getInSubQueryThreshold(int defaultValue) ); } + /** + * At or above this threshold number of values, when planning SQL queries, use the SQL SCALAR_IN_ARRAY operator rather + * than a stack of SQL ORs. This speeds up planning for large sets of points because it is opaque to various + * expensive optimizations. But, because this does bypass certain optimizations, we only do the transformation above + * a certain threshold. The SCALAR_IN_ARRAY operator is still able to convert to {@link InDimFilter} or + * {@link TypedInFilter}. + */ + public int getInFunctionThreshold() + { + return getInt( + QueryContexts.IN_FUNCTION_THRESHOLD, + QueryContexts.DEFAULT_IN_FUNCTION_THRESHOLD + ); + } + + /** + * At or above this threshold, when converting the SEARCH operator to a native expression, use the "scalar_in_array" + * function rather than a sequence of equals (==) separated by or (||). This is typically a lower threshold + * than {@link #getInFunctionThreshold()}, because it does not prevent any SQL planning optimizations, and it + * speeds up query execution. + */ + public int getInFunctionExprThreshold() + { + return getInt( + QueryContexts.IN_FUNCTION_EXPR_THRESHOLD, + QueryContexts.DEFAULT_IN_FUNCTION_EXPR_THRESHOLD + ); + } + public boolean isTimeBoundaryPlanningEnabled() { return getBoolean( diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java b/processing/src/main/java/org/apache/druid/query/QueryContexts.java index 2ea31d339485..3010b4fa923c 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java @@ -77,6 +77,8 @@ public class QueryContexts public static final String BY_SEGMENT_KEY = "bySegment"; public static final String BROKER_SERVICE_NAME = "brokerService"; public static final String IN_SUB_QUERY_THRESHOLD_KEY = "inSubQueryThreshold"; + public static final String IN_FUNCTION_THRESHOLD = "inFunctionThreshold"; + public static final String IN_FUNCTION_EXPR_THRESHOLD = "inFunctionExprThreshold"; public static final String TIME_BOUNDARY_PLANNING_KEY = "enableTimeBoundaryPlanning"; public static final String POPULATE_CACHE_KEY = "populateCache"; public static final String POPULATE_RESULT_LEVEL_CACHE_KEY = "populateResultLevelCache"; @@ -120,6 +122,8 @@ public class QueryContexts public static final boolean DEFAULT_SECONDARY_PARTITION_PRUNING = true; public static final boolean DEFAULT_ENABLE_DEBUG = false; public static final int DEFAULT_IN_SUB_QUERY_THRESHOLD = Integer.MAX_VALUE; + public static final int DEFAULT_IN_FUNCTION_THRESHOLD = 100; + public static final int DEFAULT_IN_FUNCTION_EXPR_THRESHOLD = 2; public static final boolean DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING = false; public static final boolean DEFAULT_WINDOWING_STRICT_VALIDATION = true; diff --git a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java index 71b477d16c37..c555c2ed4372 100644 --- a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java +++ b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java @@ -41,7 +41,6 @@ import org.junit.Test; import javax.annotation.Nullable; - import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -337,11 +336,35 @@ public void testGetMaxSubqueryBytes() ImmutableMap.of(QueryContexts.MAX_SUBQUERY_BYTES_KEY, "auto") ); assertEquals("auto", context2.getMaxSubqueryMemoryBytes(null)); - + final QueryContext context3 = new QueryContext(ImmutableMap.of()); assertEquals("disabled", context3.getMaxSubqueryMemoryBytes("disabled")); } + @Test + public void testGetInFunctionThreshold() + { + final QueryContext context1 = new QueryContext( + ImmutableMap.of(QueryContexts.IN_FUNCTION_THRESHOLD, Integer.MAX_VALUE) + ); + assertEquals(Integer.MAX_VALUE, context1.getInFunctionThreshold()); + + final QueryContext context2 = QueryContext.empty(); + assertEquals(QueryContexts.DEFAULT_IN_FUNCTION_THRESHOLD, context2.getInFunctionThreshold()); + } + + @Test + public void testGetInFunctionExprThreshold() + { + final QueryContext context1 = new QueryContext( + ImmutableMap.of(QueryContexts.IN_FUNCTION_EXPR_THRESHOLD, Integer.MAX_VALUE) + ); + assertEquals(Integer.MAX_VALUE, context1.getInFunctionExprThreshold()); + + final QueryContext context2 = QueryContext.empty(); + assertEquals(QueryContexts.DEFAULT_IN_FUNCTION_EXPR_THRESHOLD, context2.getInFunctionExprThreshold()); + } + @Test public void testDefaultEnableQueryDebugging() { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SearchOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SearchOperatorConversion.java index 967978c8760e..20202b1a280f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SearchOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SearchOperatorConversion.java @@ -24,6 +24,7 @@ import com.google.common.collect.Range; import com.google.common.collect.RangeSet; import com.google.common.collect.TreeRangeSet; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexLiteral; @@ -48,6 +49,7 @@ import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.List; @@ -82,7 +84,7 @@ public DimFilter toDruidFilter( plannerContext, rowSignature, virtualColumnRegistry, - expandSearch((RexCall) rexNode, REX_BUILDER) + expandSearch((RexCall) rexNode, REX_BUILDER, plannerContext.queryContext().getInFunctionThreshold()) ); } @@ -97,7 +99,7 @@ public DruidExpression toDruidExpression( return Expressions.toDruidExpression( plannerContext, rowSignature, - expandSearch((RexCall) rexNode, REX_BUILDER) + expandSearch((RexCall) rexNode, REX_BUILDER, plannerContext.queryContext().getInFunctionExprThreshold()) ); } @@ -111,7 +113,8 @@ public DruidExpression toDruidExpression( */ public static RexNode expandSearch( final RexCall call, - final RexBuilder rexBuilder + final RexBuilder rexBuilder, + final int scalarInArrayThreshold ) { final RexNode arg = call.operands.get(0); @@ -139,13 +142,10 @@ public static RexNode expandSearch( notInPoints = getPoints(complement); notInRexNode = makeIn( arg, - ImmutableList.copyOf( - Iterables.transform( - notInPoints, - point -> rexBuilder.makeLiteral(point, sargRex.getType(), true, true) - ) - ), + notInPoints, + sargRex.getType(), true, + notInPoints.size() >= scalarInArrayThreshold, rexBuilder ); } @@ -155,13 +155,10 @@ public static RexNode expandSearch( sarg.pointCount == 0 ? Collections.emptyList() : (List) getPoints(sarg.rangeSet); final RexNode inRexNode = makeIn( arg, - ImmutableList.copyOf( - Iterables.transform( - inPoints, - point -> rexBuilder.makeLiteral(point, sargRex.getType(), true, true) - ) - ), + inPoints, + sargRex.getType(), false, + inPoints.size() >= scalarInArrayThreshold, rexBuilder ); if (inRexNode != null) { @@ -225,14 +222,36 @@ public static RexNode expandSearch( return retVal; } + /** + * Make an IN condition for an "arg" matching certain "points", as in "arg IN (points)". + * + * @param arg lhs of the IN + * @param pointObjects rhs of the IN. Must match the "pointType" + * @param pointType type of "pointObjects" + * @param negate true for NOT IN, false for IN + * @param useScalarInArray if true, use {@link ScalarInArrayOperatorConversion#SQL_FUNCTION} when there is more + * than one point; if false, use a stack of ORs + * @param rexBuilder rex builder + * + * @return SQL rex nodes equivalent to the IN filter, or null if "pointObjects" is empty + */ @Nullable public static RexNode makeIn( final RexNode arg, - final List points, + final Collection pointObjects, + final RelDataType pointType, final boolean negate, + final boolean useScalarInArray, final RexBuilder rexBuilder ) { + final List points = ImmutableList.copyOf( + Iterables.transform( + pointObjects, + point -> rexBuilder.makeLiteral(point, pointType, false, false) + ) + ); + if (points.isEmpty()) { return null; } else if (points.size() == 1) { @@ -244,22 +263,33 @@ public static RexNode makeIn( return rexBuilder.makeCall(negate ? SqlStdOperatorTable.NOT_EQUALS : SqlStdOperatorTable.EQUALS, arg, point); } } else { - // x = a || x = b || x = c ... - RexNode retVal = rexBuilder.makeCall( - SqlStdOperatorTable.OR, - ImmutableList.copyOf( - Iterables.transform( - points, - point -> { - if (RexUtil.isNullLiteral(point, true)) { - return rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, arg); - } else { - return rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, arg, point); + RexNode retVal; + + if (useScalarInArray) { + // SCALAR_IN_ARRAY(x, ARRAY[a, b, c]) + retVal = rexBuilder.makeCall( + ScalarInArrayOperatorConversion.SQL_FUNCTION, + arg, + rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, points) + ); + } else { + // x = a || x = b || x = c ... + retVal = rexBuilder.makeCall( + SqlStdOperatorTable.OR, + ImmutableList.copyOf( + Iterables.transform( + points, + point -> { + if (RexUtil.isNullLiteral(point, true)) { + return rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, arg); + } else { + return rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, arg, point); + } } - } - ) - ) - ); + ) + ) + ); + } if (negate) { retVal = rexBuilder.makeCall(SqlStdOperatorTable.NOT, retVal); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java index b4f006ce97ef..d1a47520a90e 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java @@ -60,10 +60,12 @@ import org.apache.druid.error.InvalidSqlInput; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularity; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.Types; import org.apache.druid.segment.column.ValueType; +import org.apache.druid.sql.calcite.expression.builtin.ScalarInArrayOperatorConversion; import org.apache.druid.sql.calcite.parser.DruidSqlIngest; import org.apache.druid.sql.calcite.parser.DruidSqlInsert; import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils; @@ -774,6 +776,59 @@ public void validateCall(SqlCall call, SqlValidatorScope scope) super.validateCall(call, scope); } + @Override + protected SqlNode performUnconditionalRewrites(SqlNode node, final boolean underFrom) + { + if (node != null && (node.getKind() == SqlKind.IN || node.getKind() == SqlKind.NOT_IN)) { + final SqlNode rewritten = rewriteInToScalarInArrayIfNeeded((SqlCall) node, underFrom); + //noinspection ObjectEquality + if (rewritten != node) { + return rewritten; + } + } + + return super.performUnconditionalRewrites(node, underFrom); + } + + /** + * Rewrites "x IN (values)" to "SCALAR_IN_ARRAY(x, values)", if appropriate. Checks the form of the IN and checks + * the value of {@link QueryContext#getInFunctionThreshold()}. + * + * @param call call to {@link SqlKind#IN} or {@link SqlKind#NOT_IN} + * @param underFrom underFrom arg from {@link #performUnconditionalRewrites(SqlNode, boolean)}, used for + * recursive calls + * + * @return rewritten call, or the original call if no rewrite was appropriate + */ + private SqlNode rewriteInToScalarInArrayIfNeeded(final SqlCall call, final boolean underFrom) + { + if (call.getOperandList().size() == 2 && call.getOperandList().get(1) instanceof SqlNodeList) { + // expr IN (values) + final SqlNode exprNode = call.getOperandList().get(0); + final SqlNodeList valuesNode = (SqlNodeList) call.getOperandList().get(1); + + // Confirm valuesNode is big enough to convert to SCALAR_IN_ARRAY, and references only nonnull literals. + // (Can't include NULL literals in the conversion, because SCALAR_IN_ARRAY matches NULLs as if they were regular + // values, whereas IN does not.) + if (valuesNode.size() > plannerContext.queryContext().getInFunctionThreshold() + && valuesNode.stream().allMatch(node -> node.getKind() == SqlKind.LITERAL && !SqlUtil.isNull(node))) { + final SqlCall newCall = ScalarInArrayOperatorConversion.SQL_FUNCTION.createCall( + call.getParserPosition(), + performUnconditionalRewrites(exprNode, underFrom), + SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR.createCall(valuesNode) + ); + + if (call.getKind() == SqlKind.NOT_IN) { + return SqlStdOperatorTable.NOT.createCall(call.getParserPosition(), newCall); + } else { + return newCall; + } + } + } + + return call; + } + public static CalciteContextException buildCalciteContextException(String message, SqlNode call) { return buildCalciteContextException(new CalciteException(message, null), message, call); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/ReverseLookupRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/ReverseLookupRule.java index 95ad2b11334b..30329816f042 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/ReverseLookupRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/ReverseLookupRule.java @@ -46,6 +46,7 @@ import org.apache.druid.query.lookup.LookupExtractor; import org.apache.druid.sql.calcite.expression.builtin.MultiValueStringOperatorConversions; import org.apache.druid.sql.calcite.expression.builtin.QueryLookupOperatorConversion; +import org.apache.druid.sql.calcite.expression.builtin.ScalarInArrayOperatorConversion; import org.apache.druid.sql.calcite.expression.builtin.SearchOperatorConversion; import org.apache.druid.sql.calcite.filtration.CollectComparisons; import org.apache.druid.sql.calcite.planner.Calcites; @@ -275,12 +276,16 @@ private RexNode visitAnd(final RexCall call) } /** - * When we encounter SEARCH, expand it using {@link SearchOperatorConversion#expandSearch(RexCall, RexBuilder)} + * When we encounter SEARCH, expand it using {@link SearchOperatorConversion#expandSearch(RexCall, RexBuilder, int)} * and continue processing what lies beneath. */ private RexNode visitSearch(final RexCall call) { - final RexNode expanded = SearchOperatorConversion.expandSearch(call, rexBuilder); + final RexNode expanded = SearchOperatorConversion.expandSearch( + call, + rexBuilder, + plannerContext.queryContext().getInFunctionThreshold() + ); if (expanded instanceof RexCall) { final RexNode converted = visitCall((RexCall) expanded); @@ -300,10 +305,17 @@ private RexNode visitSearch(final RexCall call) */ private RexNode visitComparison(final RexCall call) { - return CollectionUtils.getOnlyElement( + final RexNode retVal = CollectionUtils.getOnlyElement( new CollectReverseLookups(Collections.singletonList(call), rexBuilder).collect(), ret -> new ISE("Expected to collect single node, got[%s]", ret) ); + + //noinspection ObjectEquality + if (retVal != call) { + return retVal; + } else { + return super.visitCall(call); + } } /** @@ -398,12 +410,13 @@ protected Set getMatchValues(RexCall call) return Collections.singleton(null); } else { // Compute the set of values that this comparison operator matches. - // Note that MV_CONTAINS and MV_OVERLAP match nulls, but other comparison operators do not. + // Note that MV_CONTAINS, MV_OVERLAP, and SCALAR_IN_ARRAY match nulls, but other comparison operators do not. // See "isBinaryComparison" for the set of operators we might encounter here. final RexNode matchLiteral = call.getOperands().get(1); final boolean matchNulls = call.getOperator().equals(MultiValueStringOperatorConversions.CONTAINS.calciteOperator()) - || call.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.calciteOperator()); + || call.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.calciteOperator()) + || call.getOperator().equals(ScalarInArrayOperatorConversion.SQL_FUNCTION); return toStringSet(matchLiteral, matchNulls); } } @@ -559,8 +572,16 @@ private RexNode makeMatchCondition( } else { return SearchOperatorConversion.makeIn( reverseLookupKey.arg, - stringsToRexNodes(reversedMatchValues, rexBuilder), + reversedMatchValues, + rexBuilder.getTypeFactory() + .createTypeWithNullability( + rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), + true + ), reverseLookupKey.negate, + + // Use regular equals, or SCALAR_IN_ARRAY, depending on inFunctionThreshold. + reversedMatchValues.size() >= plannerContext.queryContext().getInFunctionThreshold(), rexBuilder ); } @@ -598,7 +619,8 @@ private static boolean isBinaryComparison(final RexNode rexNode) return call.getKind() == SqlKind.EQUALS || call.getKind() == SqlKind.NOT_EQUALS || call.getOperator().equals(MultiValueStringOperatorConversions.CONTAINS.calciteOperator()) - || call.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.calciteOperator()); + || call.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.calciteOperator()) + || call.getOperator().equals(ScalarInArrayOperatorConversion.SQL_FUNCTION); } else { return false; } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java index 1aa4c89b416b..80d9f1bbf170 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java @@ -169,6 +169,30 @@ public void testFilterInLookupOfConcat() ); } + @Test + public void testFilterScalarInArrayLookupOfConcat() + { + cannotVectorize(); + + testQuery( + buildFilterTestSql("SCALAR_IN_ARRAY(LOOKUP(CONCAT(dim1, 'a', dim2), 'lookyloo'), ARRAY['xa', 'xabc'])"), + QUERY_CONTEXT, + buildFilterTestExpectedQuery( + or( + and( + equality("dim1", "", ColumnType.STRING), + equality("dim2", "", ColumnType.STRING) + ), + and( + equality("dim1", "", ColumnType.STRING), + equality("dim2", "bc", ColumnType.STRING) + ) + ) + ), + ImmutableList.of() + ); + } + @Test public void testFilterConcatOfLookup() { @@ -378,6 +402,40 @@ public void testFilterIn() ); } + @Test + public void testFilterScalarInArray() + { + cannotVectorize(); + + testQuery( + buildFilterTestSql("SCALAR_IN_ARRAY(LOOKUP(dim1, 'lookyloo'), ARRAY['xabc', 'x6', 'nonexistent'])"), + QUERY_CONTEXT, + buildFilterTestExpectedQuery(in("dim1", Arrays.asList("6", "abc"))), + ImmutableList.of(new Object[]{"xabc", 1L}) + ); + } + + @Test + public void testFilterInOverScalarInArrayThreshold() + { + cannotVectorize(); + + // Set inFunctionThreshold = 1 to cause the IN to be converted to SCALAR_IN_ARRAY. + final ImmutableMap queryContext = + ImmutableMap.builder() + .putAll(QUERY_CONTEXT_DEFAULT) + .put(PlannerContext.CTX_SQL_REVERSE_LOOKUP, true) + .put(QueryContexts.IN_FUNCTION_THRESHOLD, 1) + .build(); + + testQuery( + buildFilterTestSql("LOOKUP(dim1, 'lookyloo') IN ('xabc', 'x6', 'nonexistent')"), + queryContext, + buildFilterTestExpectedQuery(in("dim1", Arrays.asList("6", "abc"))), + ImmutableList.of(new Object[]{"xabc", 1L}) + ); + } + @Test public void testFilterInOverMaxSize() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 9b302534ef62..7a3df49eedfa 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -23,6 +23,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; import org.apache.calcite.runtime.CalciteContextException; import org.apache.druid.common.config.NullHandling; import org.apache.druid.error.DruidException; @@ -4039,6 +4041,49 @@ public void testGroupingWithNullInFilter() ); } + @Test + public void testGroupingWithNullPlusNonNullInFilter() + { + msqIncompatible(); + testQuery( + "SELECT COUNT(*) FROM foo WHERE dim1 IN (NULL, 'abc')", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(equality("dim1", "abc", ColumnType.STRING)) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{1L}) + ); + } + + @Test + public void testGroupingWithNotNullPlusNonNullInFilter() + { + msqIncompatible(); + testQuery( + "SELECT COUNT(*) FROM foo WHERE dim1 NOT IN (NULL, 'abc')", + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + InlineDataSource.fromIterable( + ImmutableList.of(new Object[]{0L}), + RowSignature.builder().add("EXPR$0", ColumnType.LONG).build() + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("EXPR$0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{0L}) + ); + } @Test public void testGroupByNothingWithLiterallyFalseFilter() @@ -5557,6 +5602,46 @@ public void testNotInOrIsNullFilter() ); } + @Test + public void testNotInAndIsNotNullFilter() + { + testQuery( + "SELECT dim1, COUNT(*) FROM druid.foo " + + "WHERE dim1 NOT IN ('ghi', 'abc', 'def') AND dim1 IS NOT NULL " + + "GROUP BY dim1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0"))) + .setDimFilter(and( + notNull("dim1"), + not(in("dim1", ColumnType.STRING, ImmutableList.of("abc", "def", "ghi"))) + )) + .setAggregatorSpecs( + aggregators( + new CountAggregatorFactory("a0") + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{"", 1L}, + new Object[]{"1", 1L}, + new Object[]{"10.1", 1L}, + new Object[]{"2", 1L} + ) + : ImmutableList.of( + new Object[]{"1", 1L}, + new Object[]{"10.1", 1L}, + new Object[]{"2", 1L} + ) + ); + } + @Test public void testNotInAndLessThanFilter() { @@ -5631,6 +5716,279 @@ public void testInIsNotTrueAndLessThanFilter() ); } + @Test + public void testInExpression() + { + // Cannot vectorize scalar_in_array expression. + cannotVectorize(); + + testQuery( + "SELECT dim1 IN ('abc', 'def', 'ghi'), COUNT(*)\n" + + "FROM druid.foo\n" + + "GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "scalar_in_array(\"dim1\",array('abc','def','ghi'))", + ColumnType.LONG + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{false, 4L}, + new Object[]{true, 2L} + ) + ); + } + + @Test + public void testInExpressionBelowThreshold() + { + // Cannot vectorize scalar_in_array expression. + cannotVectorize(); + + testQuery( + "SELECT dim1 IN ('abc', 'def', 'ghi'), COUNT(*)\n" + + "FROM druid.foo\n" + + "GROUP BY 1", + QueryContexts.override(QUERY_CONTEXT_DEFAULT, QueryContexts.IN_FUNCTION_EXPR_THRESHOLD, 100), + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "((\"dim1\" == 'abc') || (\"dim1\" == 'def') || (\"dim1\" == 'ghi'))", + ColumnType.LONG + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{false, 4L}, + new Object[]{true, 2L} + ) + ); + } + + @Test + public void testInOrIsNullExpression() + { + // Cannot vectorize scalar_in_array expression. + cannotVectorize(); + + testQuery( + "SELECT dim1 IN ('abc', 'def', 'ghi') OR dim1 IS NULL, COUNT(*)\n" + + "FROM druid.foo\n" + + "GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "(isnull(\"dim1\") || scalar_in_array(\"dim1\",array('abc','def','ghi')))", + ColumnType.LONG + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{false, NullHandling.sqlCompatible() ? 4L : 3L}, + new Object[]{true, NullHandling.sqlCompatible() ? 2L : 3L} + ) + ); + } + + @Test + public void testNotInOrIsNullExpression() + { + // Cannot vectorize scalar_in_array expression. + cannotVectorize(); + + testQuery( + "SELECT NOT (dim1 IN ('abc', 'def', 'ghi') OR dim1 IS NULL), COUNT(*)\n" + + "FROM druid.foo\n" + + "GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "(notnull(\"dim1\") && (! scalar_in_array(\"dim1\",array('abc','def','ghi'))))", + ColumnType.LONG + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{false, NullHandling.sqlCompatible() ? 2L : 3L}, + new Object[]{true, NullHandling.sqlCompatible() ? 4L : 3L} + ) + ); + } + + @Test + public void testNotInAndIsNotNullExpression() + { + // Cannot vectorize scalar_in_array expression. + cannotVectorize(); + + testQuery( + "SELECT dim1 NOT IN ('abc', 'def', 'ghi') AND dim1 IS NOT NULL, COUNT(*)\n" + + "FROM druid.foo\n" + + "GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "(notnull(\"dim1\") && (! scalar_in_array(\"dim1\",array('abc','def','ghi'))))", + ColumnType.LONG + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{false, NullHandling.sqlCompatible() ? 2L : 3L}, + new Object[]{true, NullHandling.sqlCompatible() ? 4L : 3L} + ) + ); + } + + @Test + public void testInOrGreaterThanExpression() + { + // Cannot vectorize scalar_in_array expression. + cannotVectorize(); + + testQuery( + "SELECT dim1 IN ('abc', 'def', 'ghi') OR dim1 > 'zzz', COUNT(*)\n" + + "FROM druid.foo\n" + + "GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "(scalar_in_array(\"dim1\",array('abc','def','ghi')) || (\"dim1\" > 'zzz'))", + ColumnType.LONG + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{false, 4L}, + new Object[]{true, 2L} + ) + ); + } + + @Test + public void testNotInAndLessThanExpression() + { + // Cannot vectorize scalar_in_array expression. + cannotVectorize(); + + testQuery( + "SELECT dim1 NOT IN ('abc', 'def', 'ghi') AND dim1 < 'zzz', COUNT(*)\n" + + "FROM druid.foo\n" + + "GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "((\"dim1\" < 'zzz') && (! scalar_in_array(\"dim1\",array('abc','def','ghi'))))", + ColumnType.LONG + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{false, 2L}, + new Object[]{true, 4L} + ) + ); + } + + @Test + public void testNotInOrEqualToOneOfThemExpression() + { + // Cannot vectorize scalar_in_array expression. + cannotVectorize(); + + testQuery( + "SELECT dim1 NOT IN ('abc', 'def', 'ghi') OR dim1 = 'def', COUNT(*)\n" + + "FROM druid.foo\n" + + "GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "(! scalar_in_array(\"dim1\",array('abc','ghi')))", + ColumnType.LONG + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{false, 1L}, + new Object[]{true, 5L} + ) + ); + } + @Test public void testSqlIsNullToInFilter() { @@ -5685,14 +6043,91 @@ public void testInFilterWith23Elements() final String elementsString = Joiner.on(",").join(elements.stream().map(s -> "'" + s + "'").iterator()); testQuery( - "SELECT dim1, COUNT(*) FROM druid.foo WHERE dim1 IN (" + elementsString + ") GROUP BY dim1", + "SELECT dim1, COUNT(*) FROM druid.foo\n" + + "WHERE dim1 IN (" + elementsString + ") OR dim1 = 'xyz' OR dim1 IS NULL\n" + + "GROUP BY dim1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0"))) + .setDimFilter( + NullHandling.sqlCompatible() + ? or( + in("dim1", ImmutableSet.builder().addAll(elements).add("xyz").build()), + isNull("dim1") + ) + : in( + "dim1", + Lists.newArrayList( + Iterables.concat( + Collections.singleton(null), + elements, + Collections.singleton("xyz") + ) + ) + ) + ) + .setAggregatorSpecs( + aggregators( + new CountAggregatorFactory("a0") + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{"abc", 1L}, + new Object[]{"def", 1L} + ) + : ImmutableList.of( + new Object[]{"", 1L}, + new Object[]{"abc", 1L}, + new Object[]{"def", 1L} + ) + ); + } + + @Test + public void testInFilterWith23Elements_overScalarInArrayThreshold() + { + final List elements = new ArrayList<>(); + elements.add("abc"); + elements.add("def"); + elements.add("ghi"); + for (int i = 0; i < 20; i++) { + elements.add("dummy" + i); + } + + final String elementsString = Joiner.on(",").join(elements.stream().map(s -> "'" + s + "'").iterator()); + + testQuery( + "SELECT dim1, COUNT(*) FROM druid.foo\n" + + "WHERE dim1 IN (" + elementsString + ") OR dim1 = 'xyz' OR dim1 IS NULL\n" + + "GROUP BY dim1", + QueryContexts.override(QUERY_CONTEXT_DEFAULT, QueryContexts.IN_FUNCTION_THRESHOLD, 20), ImmutableList.of( GroupByQuery.builder() .setDataSource(CalciteTests.DATASOURCE1) .setInterval(querySegmentSpec(Filtration.eternity())) .setGranularity(Granularities.ALL) .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0"))) - .setDimFilter(in("dim1", elements)) + .setDimFilter( + // [dim1 = xyz] is not combined into the IN filter, because SCALAR_IN_ARRAY was used, + // and it is opaque to most optimizations. (That's its main purpose.) + NullHandling.sqlCompatible() + ? or( + in("dim1", elements), + isNull("dim1"), + equality("dim1", "xyz", ColumnType.STRING) + ) + : or( + in("dim1", Arrays.asList(null, "xyz")), + in("dim1", elements) + ) + ) .setAggregatorSpecs( aggregators( new CountAggregatorFactory("a0") @@ -5701,7 +6136,80 @@ public void testInFilterWith23Elements() .setContext(QUERY_CONTEXT_DEFAULT) .build() ), + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{"abc", 1L}, + new Object[]{"def", 1L} + ) + : ImmutableList.of( + new Object[]{"", 1L}, + new Object[]{"abc", 1L}, + new Object[]{"def", 1L} + ) + ); + } + + @Test + public void testInFilterWith23Elements_overBothScalarInArrayAndInSubQueryThresholds() + { + // Verify that when an IN filter surpasses both inFunctionThreshold and inSubQueryThreshold, the + // inFunctionThreshold takes priority. + final List elements = new ArrayList<>(); + elements.add("abc"); + elements.add("def"); + elements.add("ghi"); + for (int i = 0; i < 20; i++) { + elements.add("dummy" + i); + } + + final String elementsString = Joiner.on(",").join(elements.stream().map(s -> "'" + s + "'").iterator()); + + testQuery( + "SELECT dim1, COUNT(*) FROM druid.foo\n" + + "WHERE dim1 IN (" + elementsString + ") OR dim1 = 'xyz' OR dim1 IS NULL\n" + + "GROUP BY dim1", + QueryContexts.override( + QUERY_CONTEXT_DEFAULT, + ImmutableMap.of( + QueryContexts.IN_FUNCTION_THRESHOLD, 20, + QueryContexts.IN_SUB_QUERY_THRESHOLD_KEY, 20 + ) + ), ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0"))) + .setDimFilter( + // [dim1 = xyz] is not combined into the IN filter, because SCALAR_IN_ARRAY was used, + // and it is opaque to most optimizations. (That's its main purpose.) + NullHandling.sqlCompatible() + ? or( + in("dim1", elements), + isNull("dim1"), + equality("dim1", "xyz", ColumnType.STRING) + ) + : or( + in("dim1", Arrays.asList(null, "xyz")), + in("dim1", elements) + ) + ) + .setAggregatorSpecs( + aggregators( + new CountAggregatorFactory("a0") + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{"abc", 1L}, + new Object[]{"def", 1L} + ) + : ImmutableList.of( + new Object[]{"", 1L}, new Object[]{"abc", 1L}, new Object[]{"def", 1L} )