From f9439970c926ec832e3f14cc235a46ec86784499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20L=C3=A9aut=C3=A9?= Date: Fri, 6 Oct 2023 12:45:07 -0700 Subject: [PATCH 01/14] run build and unit tests using Java 21 (#15088) * run build and unit test using Java 21 * run static checks with Java 21 * use setup-java for unit tests, since Java 21 is not built-in * skip maven cache from setup-java * add comments to explain cache behavior --- .github/workflows/reusable-unit-tests.yml | 10 ++++++++-- .github/workflows/standard-its.yml | 2 ++ .github/workflows/static-checks.yml | 2 +- .../unit-and-integration-tests-unified.yml | 15 ++++++++++----- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/.github/workflows/reusable-unit-tests.yml b/.github/workflows/reusable-unit-tests.yml index 34d992c397c2..06a48362c404 100644 --- a/.github/workflows/reusable-unit-tests.yml +++ b/.github/workflows/reusable-unit-tests.yml @@ -59,9 +59,15 @@ jobs: with: fetch-depth: 0 - - name: setup jdk${{ inputs.jdk }} - run: echo "JAVA_HOME=$JAVA_HOME_${{ inputs.jdk }}_X64" >> $GITHUB_ENV + # skip the "cache: maven" step from setup-java. We explicitly use a + # different cache key since we cannot reuse it across commits. + - uses: actions/setup-java@v3 + with: + distribution: 'zulu' + java-version: ${{ inputs.jdk }} + # the build step produces SNAPSHOT artifacts into the local maven repository, + # we include github.sha in the cache key to make it specific to that build/jdk - name: Restore Maven repository id: maven-restore uses: actions/cache/restore@v3 diff --git a/.github/workflows/standard-its.yml b/.github/workflows/standard-its.yml index 2648dc0993b6..ae78a1f2a836 100644 --- a/.github/workflows/standard-its.yml +++ b/.github/workflows/standard-its.yml @@ -153,6 +153,8 @@ jobs: - name: Setup java run: export JAVA_HOME=$JAVA_HOME_8_X64 + # the build step produces SNAPSHOT artifacts into the local maven repository, + # we include github.sha in the cache key to make it specific to that build/jdk - name: Restore Maven repository id: maven-restore uses: actions/cache/restore@v3 diff --git a/.github/workflows/static-checks.yml b/.github/workflows/static-checks.yml index 4b1c4db0c68d..49cd516f5cae 100644 --- a/.github/workflows/static-checks.yml +++ b/.github/workflows/static-checks.yml @@ -41,7 +41,7 @@ jobs: strategy: fail-fast: false matrix: - java: [ '8', '11', '17' ] + java: [ '8', '11', '17', '21' ] runs-on: ubuntu-latest steps: - name: checkout branch diff --git a/.github/workflows/unit-and-integration-tests-unified.yml b/.github/workflows/unit-and-integration-tests-unified.yml index 6ff6c8bd6500..d834143695d3 100644 --- a/.github/workflows/unit-and-integration-tests-unified.yml +++ b/.github/workflows/unit-and-integration-tests-unified.yml @@ -56,16 +56,21 @@ jobs: strategy: fail-fast: false matrix: - jdk: [ '8', '11', '17' ] + jdk: [ '8', '11', '17', '21' ] runs-on: ubuntu-latest steps: - name: Checkout branch uses: actions/checkout@v3 - - name: setup jdk${{ matrix.jdk }} - run: | - echo "JAVA_HOME=$JAVA_HOME_${{ matrix.jdk }}_X64" >> $GITHUB_ENV + # skip the "cache: maven" step from setup-java. We explicitly use a + # different cache key since we cannot reuse it across commits. + - uses: actions/setup-java@v3 + with: + distribution: 'zulu' + java-version: ${{ matrix.jdk }} + # the build step produces SNAPSHOT artifacts into the local maven repository, + # we include github.sha in the cache key to make it specific to that build/jdk - name: Cache Maven m2 repository id: maven uses: actions/cache@v3 @@ -112,7 +117,7 @@ jobs: strategy: fail-fast: false matrix: - jdk: [ 11, 17 ] + jdk: [ 11, 17, 21 ] name: "unit tests (jdk${{ matrix.jdk }}, sql-compat=true)" uses: ./.github/workflows/unit-tests.yml needs: unit-tests From 57ab8e13dc66bae6757c922d72b555ad83c4e76d Mon Sep 17 00:00:00 2001 From: Soumyava <93540295+somu-imply@users.noreply.github.com> Date: Fri, 6 Oct 2023 19:23:12 -0700 Subject: [PATCH 02/14] Updating plans when using joins with unnest on the left (#15075) * Updating plans when using joins with unnest on the left * Correcting segment map function for hashJoin * The changes done here are not reflected into MSQ yet so these tests might not run in MSQ * native tests * Self joins with unnest data source * Making this pass * Addressing comments by adding explanation and new test --- .../apache/druid/query/JoinDataSource.java | 64 +++- .../druid/query/JoinDataSourceTest.java | 48 +++ .../druid/sql/calcite/rel/DruidRels.java | 2 +- .../sql/calcite/CalciteJoinQueryTest.java | 355 ++++++++++++++++++ 4 files changed, 454 insertions(+), 15 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/JoinDataSource.java b/processing/src/main/java/org/apache/druid/query/JoinDataSource.java index 77734edf0252..220f18a94855 100644 --- a/processing/src/main/java/org/apache/druid/query/JoinDataSource.java +++ b/processing/src/main/java/org/apache/druid/query/JoinDataSource.java @@ -476,10 +476,25 @@ private Function createSegmentMapFunctionInt .orElse(null) ) ); - + final Function baseMapFn; + // A join data source is not concrete + // And isConcrete() of an unnest datasource delegates to its base + // Hence, in the case of a Join -> Unnest -> Join + // if we just use isConcrete on the left + // the segment map function for the unnest would never get called + // This calls us to delegate to the segmentMapFunction of the left + // only when it is not a JoinDataSource + if (left instanceof JoinDataSource) { + baseMapFn = Function.identity(); + } else { + baseMapFn = left.createSegmentMapFunction( + query, + cpuTimeAccumulator + ); + } return baseSegment -> new HashJoinSegment( - baseSegment, + baseMapFn.apply(baseSegment), baseFilterToUse, GuavaUtils.firstNonNull(clausesToUse, ImmutableList.of()), joinFilterPreAnalysis @@ -501,18 +516,39 @@ private static Triple> flattenJoi DimFilter currentDimFilter = null; final List preJoinableClauses = new ArrayList<>(); - while (current instanceof JoinDataSource) { - final JoinDataSource joinDataSource = (JoinDataSource) current; - current = joinDataSource.getLeft(); - currentDimFilter = validateLeftFilter(current, joinDataSource.getLeftFilter()); - preJoinableClauses.add( - new PreJoinableClause( - joinDataSource.getRightPrefix(), - joinDataSource.getRight(), - joinDataSource.getJoinType(), - joinDataSource.getConditionAnalysis() - ) - ); + // There can be queries like + // Join of Unnest of Join of Unnest of Filter + // so these checks are needed to be ORed + // to get the base + // This method is called to get the analysis for the join data source + // Since the analysis of an UnnestDS or FilteredDS always delegates to its base + // To obtain the base data source underneath a Join + // we also iterate through the base of the FilterDS and UnnestDS in its path + // the base of which can be a concrete data source + // This also means that an addition of a new datasource + // Will need an instanceof check here + // A future work should look into if the flattenJoin + // can be refactored to omit these instanceof checks + while (current instanceof JoinDataSource || current instanceof UnnestDataSource || current instanceof FilteredDataSource) { + if (current instanceof JoinDataSource) { + final JoinDataSource joinDataSource = (JoinDataSource) current; + current = joinDataSource.getLeft(); + currentDimFilter = validateLeftFilter(current, joinDataSource.getLeftFilter()); + preJoinableClauses.add( + new PreJoinableClause( + joinDataSource.getRightPrefix(), + joinDataSource.getRight(), + joinDataSource.getJoinType(), + joinDataSource.getConditionAnalysis() + ) + ); + } else if (current instanceof UnnestDataSource) { + final UnnestDataSource unnestDataSource = (UnnestDataSource) current; + current = unnestDataSource.getBase(); + } else { + final FilteredDataSource filteredDataSource = (FilteredDataSource) current; + current = filteredDataSource.getBase(); + } } // Join clauses were added in the order we saw them while traversing down, but we need to apply them in the diff --git a/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java b/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java index b23c0b92dbbd..b821bc49c4e7 100644 --- a/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java +++ b/processing/src/test/java/org/apache/druid/query/JoinDataSourceTest.java @@ -29,11 +29,14 @@ import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.TrueDimFilter; +import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.join.JoinConditionAnalysis; import org.apache.druid.segment.join.JoinType; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.segment.join.NoopJoinableFactory; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.easymock.Mock; import org.junit.Assert; import org.junit.Rule; @@ -433,6 +436,51 @@ public void test_computeJoinDataSourceCacheKey_keyChangesWithPrefix() Assert.assertFalse(Arrays.equals(cacheKey1, cacheKey2)); } + @Test + public void testGetAnalysisWithUnnestDS() + { + JoinDataSource dataSource = JoinDataSource.create( + UnnestDataSource.create( + new TableDataSource("table1"), + new ExpressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()), + null + ), + new TableDataSource("table2"), + "j.", + "x == \"j.x\"", + JoinType.LEFT, + null, + ExprMacroTable.nil(), + null + ); + DataSourceAnalysis analysis = dataSource.getAnalysis(); + Assert.assertEquals("table1", analysis.getBaseDataSource().getTableNames().iterator().next()); + } + + @Test + public void testGetAnalysisWithFilteredDS() + { + JoinDataSource dataSource = JoinDataSource.create( + UnnestDataSource.create( + FilteredDataSource.create( + new TableDataSource("table1"), + TrueDimFilter.instance() + ), + new ExpressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING, ExprMacroTable.nil()), + null + ), + new TableDataSource("table2"), + "j.", + "x == \"j.x\"", + JoinType.LEFT, + null, + ExprMacroTable.nil(), + null + ); + DataSourceAnalysis analysis = dataSource.getAnalysis(); + Assert.assertEquals("table1", analysis.getBaseDataSource().getTableNames().iterator().next()); + } + @Test public void test_computeJoinDataSourceCacheKey_keyChangesWithBaseFilter() { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java index c35c872544f1..1627329c75e4 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java @@ -66,7 +66,7 @@ public static boolean isScanOrMapping(final DruidRel druidRel, final boolean */ public static boolean isScanOrProject(final DruidRel druidRel, final boolean canBeJoinOrUnion) { - if (druidRel instanceof DruidQueryRel || (canBeJoinOrUnion && (druidRel instanceof DruidJoinQueryRel + if (druidRel instanceof DruidQueryRel || (canBeJoinOrUnion && (druidRel instanceof DruidJoinQueryRel || druidRel instanceof DruidCorrelateUnnestRel || druidRel instanceof DruidUnionDataSourceRel))) { final PartialDruidQuery partialQuery = druidRel.getPartialDruidQuery(); final PartialDruidQuery.Stage stage = partialQuery.stage(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index d1300ff19b24..e8a728339605 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -38,6 +38,7 @@ import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.DataSource; import org.apache.druid.query.Druids; +import org.apache.druid.query.FilteredDataSource; import org.apache.druid.query.GlobalTableDataSource; import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.JoinDataSource; @@ -49,6 +50,7 @@ import org.apache.druid.query.QueryException; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.UnionDataSource; +import org.apache.druid.query.UnnestDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; @@ -64,6 +66,7 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.query.extraction.SubstringDimExtractionFn; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.ResultRow; @@ -5914,4 +5917,356 @@ public void testJoinWithInputRefCondition() ) ); } + + @Test + public void testJoinsWithUnnestOnLeft() + { + // Segment map function of MSQ needs some work + // To handle these nested cases + // Remove this when that's handled + msqIncompatible(); + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "with t1 as (\n" + + "select * from foo, unnest(MV_TO_ARRAY(\"dim3\")) as u(d3)\n" + + ")\n" + + "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN numfoo as t2\n" + + "ON t1.d3 = t2.\"dim2\"", + context, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE1), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + ), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE3) + .columns("dim2") + .legacy(false) + .context(context) + .build() + ), + "_j0.", + "(\"j0.unnest\" == \"_j0.dim2\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("_j0.dim2", "dim3", "j0.unnest") + .context(context) + .build() + ), + useDefault ? + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"", "", ""} + ) + ); + } + + @Test + public void testJoinsWithUnnestOverFilteredDSOnLeft() + { + // Segment map function of MSQ needs some work + // To handle these nested cases + // Remove this when that's handled + msqIncompatible(); + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "with t1 as (\n" + + "select * from foo, unnest(MV_TO_ARRAY(\"dim3\")) as u(d3) where dim2='a'\n" + + ")\n" + + "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN numfoo as t2\n" + + "ON t1.d3 = t2.\"dim2\"", + context, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + UnnestDataSource.create( + FilteredDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE1), + equality("dim2", "a", ColumnType.STRING) + ), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + ), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE3) + .columns("dim2") + .legacy(false) + .context(context) + .build() + ), + "_j0.", + "(\"j0.unnest\" == \"_j0.dim2\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("_j0.dim2", "dim3", "j0.unnest") + .context(context) + .build() + ), + useDefault ? + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"", "", ""} + ) + ); + } + + @Test + public void testJoinsWithUnnestOverJoin() + { + // Segment map function of MSQ needs some work + // To handle these nested cases + // Remove this when that's handled + msqIncompatible(); + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "with t1 as (\n" + + "select * from (SELECT * from foo JOIN (select dim2 as t from foo where dim2 IN ('a','b','ab','abc')) ON dim2=t), " + + " unnest(MV_TO_ARRAY(\"dim3\")) as u(d3) \n" + + ")\n" + + "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN numfoo as t2\n" + + "ON t1.d3 = t2.\"dim2\"", + context, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + UnnestDataSource.create( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE1) + .filters(new InDimFilter("dim2", ImmutableList.of("a", "b", "ab", "abc"), null)) + .legacy(false) + .context(context) + .columns("dim2") + .build() + ), + "j0.", + "(\"dim2\" == \"j0.dim2\")", + JoinType.INNER + ), + expressionVirtualColumn("_j0.unnest", "\"dim3\"", ColumnType.STRING), + null + ), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE3) + .columns("dim2") + .legacy(false) + .context(context) + .build() + ), + "__j0.", + "(\"_j0.unnest\" == \"__j0.dim2\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("__j0.dim2", "_j0.unnest", "dim3") + .context(context) + .build() + ), + useDefault ? + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"", "", ""}, + new Object[]{"", "", ""}, + new Object[]{"", "", ""}, + new Object[]{"", "", ""} + ) + ); + } + + @Test + public void testSelfJoinsWithUnnestOnLeftAndRight() + { + // Segment map function of MSQ needs some work + // To handle these nested cases + // Remove this when that's handled + msqIncompatible(); + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "with t1 as (\n" + + "select * from foo, unnest(MV_TO_ARRAY(\"dim3\")) as u(d3)\n" + + ")\n" + + "select t1.dim3, t1.d3, t2.dim2 from t1 JOIN t1 as t2\n" + + "ON t1.d3 = t2.d3", + context, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE1), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + ), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(UnnestDataSource.create( + new TableDataSource(CalciteTests.DATASOURCE1), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + )) + .columns("dim2", "j0.unnest") + .legacy(false) + .context(context) + .build() + ), + "_j0.", + "(\"j0.unnest\" == \"_j0.j0.unnest\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("_j0.dim2", "dim3", "j0.unnest") + .context(context) + .build() + ), + useDefault ? + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "b", "a"}, + new Object[]{"[\"a\",\"b\"]", "b", ""}, + new Object[]{"[\"b\",\"c\"]", "b", "a"}, + new Object[]{"[\"b\",\"c\"]", "b", ""}, + new Object[]{"[\"b\",\"c\"]", "c", ""}, + new Object[]{"d", "d", ""} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a"}, + new Object[]{"[\"a\",\"b\"]", "b", "a"}, + new Object[]{"[\"a\",\"b\"]", "b", null}, + new Object[]{"[\"b\",\"c\"]", "b", "a"}, + new Object[]{"[\"b\",\"c\"]", "b", null}, + new Object[]{"[\"b\",\"c\"]", "c", null}, + new Object[]{"d", "d", ""}, + new Object[]{"", "", "a"} + ) + ); + } + + @Test + public void testJoinsOverUnnestOverFilterDSOverJoin() + { + // Segment map function of MSQ needs some work + // To handle these nested cases + // Remove this when that's handled + msqIncompatible(); + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "with t1 as (\n" + + "select * from (SELECT * from foo JOIN (select dim2 as t from foo where dim2 IN ('a','b','ab','abc')) ON dim2=t),\n" + + "unnest(MV_TO_ARRAY(\"dim3\")) as u(d3) where m1 IN (1,4) and d3='a'\n" + + ")\n" + + "select t1.dim3, t1.d3, t2.dim2, t1.m1 from t1 JOIN numfoo as t2\n" + + "ON t1.d3 = t2.\"dim2\"", + context, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + UnnestDataSource.create( + FilteredDataSource.create( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE1) + .columns("dim2") + .filters(new InDimFilter( + "dim2", + ImmutableList.of("a", "ab", "abc", "b"), + null + )) + .legacy(false) + .context(context) + .build() + ), + "j0.", + "(\"dim2\" == \"j0.dim2\")", + JoinType.INNER + ), + useDefault ? + new InDimFilter("m1", ImmutableList.of("1", "4"), null) : + or( + equality("m1", 1.0, ColumnType.FLOAT), + equality("m1", 4.0, ColumnType.FLOAT) + ) + ), + expressionVirtualColumn("_j0.unnest", "\"dim3\"", ColumnType.STRING), + equality("_j0.unnest", "a", ColumnType.STRING) + ), + new QueryDataSource( + newScanQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .dataSource(CalciteTests.DATASOURCE3) + .columns("dim2") + .legacy(false) + .context(context) + .build() + ), + "__j0.", + "(\"_j0.unnest\" == \"__j0.dim2\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("__j0.dim2", "_j0.unnest", "dim3", "m1") + .context(context) + .build() + ), + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f}, + new Object[]{"[\"a\",\"b\"]", "a", "a", 1.0f} + ) + ); + } } From 7b869fd37a90484bd2f4e1685d356018b5cfd829 Mon Sep 17 00:00:00 2001 From: Zoltan Haindrich Date: Sat, 7 Oct 2023 14:31:09 +0200 Subject: [PATCH 03/14] Change type of AVG aggregates to double (#15089) The sql standard is not very restrictive regarding this: If AVG is specified and DT is exact numeric, then the declared type of the result is an implemen- tation-defined exact numeric type with precision not less than the precision of DT and scale not less than the scale of DT. so; using the same type is also ok (without patch); however the avg of 0 and 1 is 0 right now because of the retention of the integer typ Postgres,MySql and Oracle and Drill seem to increase precision ; mssql returns 0 http://sqlfiddle.com/#!9/6f7248/1 I think we should also increase precision as its already calculated more precisely --- .../hll/sql/HllSketchSqlAggregatorTest.java | 6 ++-- .../sql/ThetaSketchSqlAggregatorTest.java | 6 ++-- .../sql/BaseVarianceSqlAggregator.java | 22 +++++++------- .../sql/VarianceSqlAggregatorTest.java | 20 ++++++------- .../sql/calcite/planner/DruidTypeSystem.java | 8 +---- .../calcite/CalciteCorrelatedQueryTest.java | 13 ++++---- .../calcite/CalciteParameterQueryTest.java | 2 +- .../druid/sql/calcite/CalciteQueryTest.java | 30 +++++++++---------- .../sql/calcite/CalciteSubqueryTest.java | 7 +++-- 9 files changed, 56 insertions(+), 58 deletions(-) diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index 349f1a57d1c0..20ea97aec3e5 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -369,7 +369,7 @@ public void testAvgDailyCountDistinctHllSketch() final List expectedResults = ImmutableList.of( new Object[]{ - 1L + 1.0 } ); @@ -429,11 +429,11 @@ public void testAvgDailyCountDistinctHllSketch() .setAggregatorSpecs( NullHandling.replaceWithDefault() ? Arrays.asList( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new CountAggregatorFactory("_a0:count") ) : Arrays.asList( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new FilteredAggregatorFactory( new CountAggregatorFactory("_a0:count"), notNull("a0") diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java index 3946ce558b19..3a079e064783 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java @@ -278,7 +278,7 @@ public void testAvgDailyCountDistinctThetaSketch() final List expectedResults = ImmutableList.of( new Object[]{ - 1L + 1.0 } ); @@ -334,11 +334,11 @@ public void testAvgDailyCountDistinctThetaSketch() .setAggregatorSpecs( NullHandling.replaceWithDefault() ? Arrays.asList( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new CountAggregatorFactory("_a0:count") ) : Arrays.asList( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new FilteredAggregatorFactory( new CountAggregatorFactory("_a0:count"), notNull("a0") diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java index 0b1562eb83d1..ee8c469c3b86 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java @@ -30,6 +30,7 @@ import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -60,17 +61,17 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator private static final String STDDEV_NAME = "STDDEV"; private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(VARIANCE_NAME); + buildSqlVarianceAggFunction(VARIANCE_NAME); private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.VAR_POP.name()); + buildSqlVarianceAggFunction(SqlKind.VAR_POP.name()); private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name()); + buildSqlVarianceAggFunction(SqlKind.VAR_SAMP.name()); private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(STDDEV_NAME); + buildSqlVarianceAggFunction(STDDEV_NAME); private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name()); + buildSqlVarianceAggFunction(SqlKind.STDDEV_POP.name()); private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name()); + buildSqlVarianceAggFunction(SqlKind.STDDEV_SAMP.name()); @Nullable @Override @@ -160,14 +161,15 @@ public Aggregation toDruidAggregation( } /** - * Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction} - * but with an operand type that accepts variance aggregator objects in addition to numeric inputs. + * Creates a {@link SqlAggFunction} + * + * It accepts variance aggregator objects in addition to numeric inputs. */ - private static SqlAggFunction buildSqlAvgAggFunction(String name) + private static SqlAggFunction buildSqlVarianceAggFunction(String name) { return OperatorConversions .aggregatorBuilder(name) - .returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION) + .returnTypeInference(ReturnTypes.explicit(SqlTypeName.DOUBLE)) .operandTypeChecker( OperandTypes.or( OperandTypes.NUMERIC, diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java index fe68b2737ef3..e45a93784967 100644 --- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java @@ -171,8 +171,8 @@ public void testVarPop() final List expectedResults = ImmutableList.of( new Object[]{ holder1.getVariance(true), - holder2.getVariance(true).doubleValue(), - holder3.getVariance(true).longValue() + holder2.getVariance(true), + holder3.getVariance(true) } ); testQuery( @@ -219,7 +219,7 @@ public void testVarSamp() new Object[] { holder1.getVariance(false), holder2.getVariance(false).doubleValue(), - holder3.getVariance(false).longValue(), + holder3.getVariance(false), } ); testQuery( @@ -266,7 +266,7 @@ public void testStdDevPop() new Object[] { Math.sqrt(holder1.getVariance(true)), Math.sqrt(holder2.getVariance(true)), - (long) Math.sqrt(holder3.getVariance(true)), + Math.sqrt(holder3.getVariance(true)), } ); @@ -321,7 +321,7 @@ public void testStdDevSamp() new Object[]{ Math.sqrt(holder1.getVariance(false)), Math.sqrt(holder2.getVariance(false)), - (long) Math.sqrt(holder3.getVariance(false)), + Math.sqrt(holder3.getVariance(false)), } ); @@ -374,7 +374,7 @@ public void testStdDevWithVirtualColumns() new Object[]{ Math.sqrt(holder1.getVariance(false)), Math.sqrt(holder2.getVariance(false)), - (long) Math.sqrt(holder3.getVariance(false)), + Math.sqrt(holder3.getVariance(false)), } ); @@ -543,7 +543,7 @@ public void testEmptyTimeseriesResults() ), ImmutableList.of( NullHandling.replaceWithDefault() - ? new Object[]{0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L} + ? new Object[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0} : new Object[]{null, null, null, null, null, null, null, null} ) ); @@ -623,7 +623,7 @@ public void testGroupByAggregatorDefaultValues() ), ImmutableList.of( NullHandling.replaceWithDefault() - ? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L} + ? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0} : new Object[]{"a", null, null, null, null, null, null, null, null} ) ); @@ -688,9 +688,9 @@ public void assertResultsEquals(String sql, List expectedResults, List Assert.assertEquals(expectedResult.length, result.length); for (int j = 0; j < expectedResult.length; j++) { if (expectedResult[j] instanceof Float) { - Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10); + Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-5); } else if (expectedResult[j] instanceof Double) { - Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10); + Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-5); } else { Assert.assertEquals(expectedResult[j], result[j]); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidTypeSystem.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidTypeSystem.java index d3d09f7bdf36..dcba20ee6c46 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidTypeSystem.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidTypeSystem.java @@ -124,13 +124,7 @@ public RelDataType deriveAvgAggType( final RelDataType argumentType ) { - // Widen all averages to 64-bits regardless of the size of the inputs. - - if (SqlTypeName.INT_TYPES.contains(argumentType.getSqlTypeName())) { - return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.BIGINT, argumentType.isNullable()); - } else { - return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, argumentType.isNullable()); - } + return Calcites.createSqlTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, argumentType.isNullable()); } @Override diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java index 89b09872d402..a7a5222d8889 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteCorrelatedQueryTest.java @@ -29,9 +29,10 @@ import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.LongMaxAggregatorFactory; -import org.apache.druid.query.aggregation.LongSumAggregatorFactory; +import org.apache.druid.query.aggregation.any.DoubleAnyAggregatorFactory; import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory; import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory; import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator; @@ -127,7 +128,7 @@ public void testCorrelatedSubquery(Map queryContext) .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) .setDimensions(new DefaultDimensionSpec("d1", "_d0")) .setAggregatorSpecs( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), useDefault ? new CountAggregatorFactory("_a0:count") : new FilteredAggregatorFactory( @@ -158,15 +159,15 @@ public void testCorrelatedSubquery(Map queryContext) ) .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) .setDimensions(new DefaultDimensionSpec("country", "d0")) - .setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0")) + .setAggregatorSpecs(new DoubleAnyAggregatorFactory("a0", "j0._a0")) .setGranularity(new AllGranularity()) .setContext(queryContext) .build() ), ImmutableList.of( - new Object[]{"India", 2L}, - new Object[]{"USA", 1L}, - new Object[]{"canada", 3L} + new Object[]{"India", 2.0}, + new Object[]{"USA", 1.0}, + new Object[]{"canada", 3.0} ) ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java index 72687eb3a196..a1438824b40f 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java @@ -221,7 +221,7 @@ public void testParamsInInformationSchema() + "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?", ImmutableList.of(), ImmutableList.of( - new Object[]{8L, 1249L, 156L, -5L, 1111L} + new Object[]{8L, 1249L, 156.125, -5L, 1111L} ), ImmutableList.of( new SqlParameter(SqlType.VARCHAR, "druid"), diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index da3e3f21b090..4042def2750d 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -374,7 +374,7 @@ public void testAggregatorsOnInformationSchemaColumns() + "WHERE TABLE_SCHEMA = 'druid' AND TABLE_NAME = 'foo'", ImmutableList.of(), ImmutableList.of( - new Object[]{8L, 1249L, 156L, -5L, 1111L} + new Object[]{8L, 1249L, 156.125, -5L, 1111L} ) ); } @@ -4942,7 +4942,7 @@ public void testSimpleAggregations() new CountAggregatorFactory("a1"), notNull("dim1") ), - new LongSumAggregatorFactory("a2:sum", "cnt"), + new DoubleSumAggregatorFactory("a2:sum", "cnt"), new CountAggregatorFactory("a2:count"), new LongSumAggregatorFactory("a3", "cnt"), new LongMinAggregatorFactory("a4", "cnt"), @@ -4964,7 +4964,7 @@ public void testSimpleAggregations() new CountAggregatorFactory("a2"), notNull("dim1") ), - new LongSumAggregatorFactory("a3:sum", "cnt"), + new DoubleSumAggregatorFactory("a3:sum", "cnt"), new FilteredAggregatorFactory( new CountAggregatorFactory("a3:count"), notNull("cnt") @@ -5014,10 +5014,10 @@ public void testSimpleAggregations() ), NullHandling.replaceWithDefault() ? ImmutableList.of( - new Object[]{6L, 6L, 5L, 1L, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)} + new Object[]{6L, 6L, 5L, 1.0, 6L, 8L, 3L, 6L, ((1 + 1.7) / 6)} ) : ImmutableList.of( - new Object[]{6L, 6L, 6L, 1L, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)} + new Object[]{6L, 6L, 6L, 1.0, 6L, 8L, 4L, 3L, ((1 + 1.7) / 3)} ) ); } @@ -7429,11 +7429,11 @@ public void testAvgDailyCountDistinct() .setAggregatorSpecs( useDefault ? aggregators( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new CountAggregatorFactory("_a0:count") ) : aggregators( - new LongSumAggregatorFactory("_a0:sum", "a0"), + new DoubleSumAggregatorFactory("_a0:sum", "a0"), new FilteredAggregatorFactory( new CountAggregatorFactory("_a0:count"), notNull("a0") @@ -7455,7 +7455,7 @@ public void testAvgDailyCountDistinct() .setContext(QUERY_CONTEXT_DEFAULT) .build() ), - ImmutableList.of(new Object[]{1L}) + ImmutableList.of(new Object[]{1.0}) ); } @@ -9641,7 +9641,7 @@ public void testTimeseriesEmptyResultsAggregatorDefaultValues() new LongSumAggregatorFactory("a6", "l1"), new LongMaxAggregatorFactory("a7", "l1"), new LongMinAggregatorFactory("a8", "l1"), - new LongSumAggregatorFactory("a9:sum", "l1"), + new DoubleSumAggregatorFactory("a9:sum", "l1"), useDefault ? new CountAggregatorFactory("a9:count") : new FilteredAggregatorFactory( @@ -9690,7 +9690,7 @@ public void testTimeseriesEmptyResultsAggregatorDefaultValues() 0L, Long.MIN_VALUE, Long.MAX_VALUE, - 0L, + Double.NaN, Double.NaN } : new Object[]{0L, 0L, 0L, null, null, null, null, null, null, null, null} @@ -9936,7 +9936,7 @@ public void testGroupByAggregatorDefaultValues() equality("dim1", "nonexistent", ColumnType.STRING) ), new FilteredAggregatorFactory( - new LongSumAggregatorFactory("a9:sum", "l1"), + new DoubleSumAggregatorFactory("a9:sum", "l1"), equality("dim1", "nonexistent", ColumnType.STRING) ), useDefault @@ -10005,7 +10005,7 @@ public void testGroupByAggregatorDefaultValues() 0L, Long.MIN_VALUE, Long.MAX_VALUE, - 0L, + Double.NaN, Double.NaN } : new Object[]{"a", 0L, 0L, 0L, null, null, null, null, null, null, null, null} @@ -13147,7 +13147,7 @@ public void testCountAndAverageByConstantVirtualColumn() new CountAggregatorFactory("a0"), notNull("v0") ), - new LongSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE), + new DoubleSumAggregatorFactory("a1:sum", "v1", null, TestExprMacroTable.INSTANCE), new CountAggregatorFactory("a1:count") ); virtualColumns = ImmutableList.of( @@ -13160,7 +13160,7 @@ public void testCountAndAverageByConstantVirtualColumn() new CountAggregatorFactory("a0"), notNull("v0") ), - new LongSumAggregatorFactory("a1:sum", "v1"), + new DoubleSumAggregatorFactory("a1:sum", "v1"), new FilteredAggregatorFactory( new CountAggregatorFactory("a1:count"), notNull("v1") @@ -13204,7 +13204,7 @@ public void testCountAndAverageByConstantVirtualColumn() .build() ), ImmutableList.of( - new Object[]{"ab", 1L, 325323L} + new Object[]{"ab", 1L, 325323.0} ) ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSubqueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSubqueryTest.java index 2ddc674eadda..fb4c61b8cec3 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSubqueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSubqueryTest.java @@ -34,6 +34,7 @@ import org.apache.druid.query.ResourceLimitExceededException; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.LongMaxAggregatorFactory; import org.apache.druid.query.aggregation.LongMinAggregatorFactory; @@ -558,14 +559,14 @@ public void testMinMaxAvgDailyCountWithLimit() aggregators( new LongMaxAggregatorFactory("_a0", "a0"), new LongMinAggregatorFactory("_a1", "a0"), - new LongSumAggregatorFactory("_a2:sum", "a0"), + new DoubleSumAggregatorFactory("_a2:sum", "a0"), new CountAggregatorFactory("_a2:count"), new LongMaxAggregatorFactory("_a3", "d0"), new CountAggregatorFactory("_a4") ) : aggregators( new LongMaxAggregatorFactory("_a0", "a0"), new LongMinAggregatorFactory("_a1", "a0"), - new LongSumAggregatorFactory("_a2:sum", "a0"), + new DoubleSumAggregatorFactory("_a2:sum", "a0"), new FilteredAggregatorFactory( new CountAggregatorFactory("_a2:count"), notNull("a0") @@ -590,7 +591,7 @@ public void testMinMaxAvgDailyCountWithLimit() .setContext(queryContext) .build() ), - ImmutableList.of(new Object[]{1L, 1L, 1L, 978480000L, 6L}) + ImmutableList.of(new Object[]{1L, 1L, 1.0, 978480000L, 6L}) ); } From b5a87fd89bc8480702c90cc1b678d20f01c66426 Mon Sep 17 00:00:00 2001 From: Zoltan Haindrich Date: Sun, 8 Oct 2023 08:44:25 +0200 Subject: [PATCH 04/14] Support constant args in window functions (#15071) Instead of passing the constants around in a new parameter; InputAccessor was introduced to take care of transparently handling the constants - this new class started picking up some copy-paste debris around field accesses; and made them a little bit more readble. --- ...CompressedBigDecimalSqlAggregatorBase.java | 38 ++---- .../TDigestGenerateSketchSqlAggregator.java | 25 +--- .../TDigestSketchQuantileSqlAggregator.java | 32 +----- .../hll/sql/HllSketchBaseSqlAggregator.java | 38 ++---- ...blesSketchApproxQuantileSqlAggregator.java | 32 +----- .../sql/DoublesSketchObjectSqlAggregator.java | 25 +--- .../sql/ThetaSketchBaseSqlAggregator.java | 31 ++--- .../ArrayOfDoublesSketchSqlAggregator.java | 31 ++--- .../bloom/sql/BloomFilterSqlAggregator.java | 24 +--- ...BucketsHistogramQuantileSqlAggregator.java | 53 ++------- .../histogram/sql/QuantileSqlAggregator.java | 37 ++---- .../sql/BaseVarianceSqlAggregator.java | 19 +-- .../sql/calcite/aggregation/Aggregations.java | 15 +-- .../ApproxCountDistinctSqlAggregator.java | 12 +- .../calcite/aggregation/SqlAggregator.java | 46 +++++++- .../builtin/ArrayConcatSqlAggregator.java | 17 +-- .../builtin/ArraySqlAggregator.java | 18 +-- .../aggregation/builtin/AvgSqlAggregator.java | 27 ++--- .../builtin/BitwiseSqlAggregator.java | 12 +- ...iltinApproxCountDistinctSqlAggregator.java | 23 ++-- .../builtin/CountSqlAggregator.java | 30 ++--- .../EarliestLatestAnySqlAggregator.java | 20 +--- .../EarliestLatestBySqlAggregator.java | 17 +-- .../builtin/GroupingSqlAggregator.java | 16 ++- .../builtin/LiteralSqlAggregator.java | 10 +- .../builtin/SimpleSqlAggregator.java | 12 +- .../builtin/StringSqlAggregator.java | 26 +---- .../expression/WindowSqlAggregate.java | 8 +- .../druid/sql/calcite/rel/DruidQuery.java | 6 +- .../druid/sql/calcite/rel/InputAccessor.java | 108 ++++++++++++++++++ .../druid/sql/calcite/rel/Windowing.java | 6 +- .../druid/sql/calcite/rule/GroupByRules.java | 14 +-- .../tests/window/aggregateConstant.sqlTest | 26 +++++ 33 files changed, 357 insertions(+), 497 deletions(-) create mode 100644 sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java create mode 100644 sql/src/test/resources/calcite/tests/window/aggregateConstant.sqlTest diff --git a/extensions-contrib/compressed-bigdecimal/src/main/java/org/apache/druid/compressedbigdecimal/CompressedBigDecimalSqlAggregatorBase.java b/extensions-contrib/compressed-bigdecimal/src/main/java/org/apache/druid/compressedbigdecimal/CompressedBigDecimalSqlAggregatorBase.java index 4a61f0271eeb..a6c23551598e 100644 --- a/extensions-contrib/compressed-bigdecimal/src/main/java/org/apache/druid/compressedbigdecimal/CompressedBigDecimalSqlAggregatorBase.java +++ b/extensions-contrib/compressed-bigdecimal/src/main/java/org/apache/druid/compressedbigdecimal/CompressedBigDecimalSqlAggregatorBase.java @@ -21,8 +21,6 @@ import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -36,12 +34,12 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -71,12 +69,10 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) @@ -88,13 +84,8 @@ public Aggregation toDruidAggregation( // fetch sum column expression DruidExpression sumColumn = Expressions.toDruidExpression( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (sumColumn == null) { @@ -114,12 +105,7 @@ public Aggregation toDruidAggregation( Integer size = null; if (aggregateCall.getArgList().size() >= 2) { - RexNode sizeArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + RexNode sizeArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); size = ((Number) RexLiteral.value(sizeArg)).intValue(); } @@ -128,12 +114,7 @@ public Aggregation toDruidAggregation( Integer scale = null; if (aggregateCall.getArgList().size() >= 3) { - RexNode scaleArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + RexNode scaleArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); scale = ((Number) RexLiteral.value(scaleArg)).intValue(); } @@ -141,12 +122,7 @@ public Aggregation toDruidAggregation( Boolean useStrictNumberParsing = null; if (aggregateCall.getArgList().size() >= 4) { - RexNode useStrictNumberParsingArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(3) - ); + RexNode useStrictNumberParsingArg = inputAccessor.getField(aggregateCall.getArgList().get(3)); useStrictNumberParsing = RexLiteral.booleanValue(useStrictNumberParsingArg); } diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java index ca0a4acc603f..ebb6c7f4b141 100644 --- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java +++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java @@ -20,8 +20,6 @@ package org.apache.druid.query.aggregation.tdigestsketch.sql; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -36,13 +34,12 @@ import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorFactory; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchUtils; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -63,25 +60,18 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { - final RexNode inputOperand = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ); + final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0)); final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), inputOperand ); if (input == null) { @@ -93,12 +83,7 @@ public Aggregation toDruidAggregation( Integer compression = TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION; if (aggregateCall.getArgList().size() > 1) { - RexNode compressionOperand = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + RexNode compressionOperand = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!compressionOperand.isA(SqlKind.LITERAL)) { // compressionOperand must be a literal in order to plan. return null; diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java index 379e889d3835..ee63444f6d71 100644 --- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java +++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java @@ -21,8 +21,6 @@ import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -39,13 +37,12 @@ import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchToQuantilePostAggregator; import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchUtils; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -66,12 +63,10 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) @@ -79,13 +74,8 @@ public Aggregation toDruidAggregation( // This is expected to be a tdigest sketch final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (input == null) { return null; @@ -95,12 +85,7 @@ public Aggregation toDruidAggregation( final String sketchName = StringUtils.format("%s:agg", name); // this is expected to be quantile fraction - final RexNode quantileArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode quantileArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!quantileArg.isA(SqlKind.LITERAL)) { // Quantile must be a literal in order to plan. @@ -110,12 +95,7 @@ public Aggregation toDruidAggregation( final double quantile = ((Number) RexLiteral.value(quantileArg)).floatValue(); Integer compression = TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION; if (aggregateCall.getArgList().size() > 2) { - final RexNode compressionArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + final RexNode compressionArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); compression = ((Number) RexLiteral.value(compressionArg)).intValue(); } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java index c6dd3e7afa02..d221b72ac1c6 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java @@ -20,9 +20,7 @@ package org.apache.druid.query.aggregation.datasketches.hll.sql; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlKind; @@ -36,7 +34,6 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; @@ -44,6 +41,7 @@ import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -66,38 +64,26 @@ protected HllSketchBaseSqlAggregator(boolean finalizeSketch, StringEncoding stri @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { // Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access // for string columns. - final RexNode columnRexNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ); + final RexNode columnRexNode = inputAccessor.getField(aggregateCall.getArgList().get(0)); - final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, rowSignature, columnRexNode); + final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), columnRexNode); if (columnArg == null) { return null; } final int logK; if (aggregateCall.getArgList().size() >= 2) { - final RexNode logKarg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode logKarg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!logKarg.isA(SqlKind.LITERAL)) { // logK must be a literal in order to plan. @@ -111,12 +97,7 @@ public Aggregation toDruidAggregation( final String tgtHllType; if (aggregateCall.getArgList().size() >= 3) { - final RexNode tgtHllTypeArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + final RexNode tgtHllTypeArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); if (!tgtHllTypeArg.isA(SqlKind.LITERAL)) { // tgtHllType must be a literal in order to plan. @@ -132,9 +113,10 @@ public Aggregation toDruidAggregation( final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; if (columnArg.isDirectColumnAccess() - && rowSignature.getColumnType(columnArg.getDirectColumn()) - .map(type -> type.is(ValueType.COMPLEX)) - .orElse(false)) { + && inputAccessor.getInputRowSignature() + .getColumnType(columnArg.getDirectColumn()) + .map(type -> type.is(ValueType.COMPLEX)) + .orElse(false)) { aggregatorFactory = new HllSketchMergeAggregatorFactory( aggregatorName, columnArg.getDirectColumn(), diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java index 6c1b5720af49..08c7a1b123fd 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java @@ -21,8 +21,6 @@ import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -37,14 +35,13 @@ import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchToQuantilePostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -75,25 +72,18 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (input == null) { return null; @@ -101,12 +91,7 @@ public Aggregation toDruidAggregation( final AggregatorFactory aggregatorFactory; final String histogramName = StringUtils.format("%s:agg", name); - final RexNode probabilityArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!probabilityArg.isA(SqlKind.LITERAL)) { // Probability must be a literal in order to plan. @@ -117,12 +102,7 @@ public Aggregation toDruidAggregation( final int k; if (aggregateCall.getArgList().size() >= 3) { - final RexNode resolutionArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); if (!resolutionArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java index 049e1284a911..8331ab720640 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java @@ -21,8 +21,6 @@ import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -35,14 +33,13 @@ import org.apache.druid.query.aggregation.datasketches.SketchQueryContext; import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchAggregatorFactory; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -71,25 +68,18 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (input == null) { return null; @@ -100,12 +90,7 @@ public Aggregation toDruidAggregation( final int k; if (aggregateCall.getArgList().size() >= 2) { - final RexNode resolutionArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!resolutionArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java index 6564b276c971..bf35cd665ae8 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java @@ -20,9 +20,7 @@ package org.apache.druid.query.aggregation.datasketches.theta.sql; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlKind; @@ -34,7 +32,6 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; @@ -42,6 +39,7 @@ import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -60,38 +58,26 @@ protected ThetaSketchBaseSqlAggregator(boolean finalizeSketch) @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { // Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access // for string columns. - final RexNode columnRexNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ); + final RexNode columnRexNode = inputAccessor.getField(aggregateCall.getArgList().get(0)); - final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, rowSignature, columnRexNode); + final DruidExpression columnArg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), columnRexNode); if (columnArg == null) { return null; } final int sketchSize; if (aggregateCall.getArgList().size() >= 2) { - final RexNode sketchSizeArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode sketchSizeArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!sketchSizeArg.isA(SqlKind.LITERAL)) { // logK must be a literal in order to plan. @@ -107,9 +93,10 @@ public Aggregation toDruidAggregation( final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; if (columnArg.isDirectColumnAccess() - && rowSignature.getColumnType(columnArg.getDirectColumn()) - .map(type -> type.is(ValueType.COMPLEX)) - .orElse(false)) { + && inputAccessor.getInputRowSignature() + .getColumnType(columnArg.getDirectColumn()) + .map(type -> type.is(ValueType.COMPLEX)) + .orElse(false)) { aggregatorFactory = new SketchMergeAggregatorFactory( aggregatorName, columnArg.getDirectColumn(), diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/sql/ArrayOfDoublesSketchSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/sql/ArrayOfDoublesSketchSqlAggregator.java index 9d6ddac89a89..a9b1aaa627d0 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/sql/ArrayOfDoublesSketchSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/tuple/sql/ArrayOfDoublesSketchSqlAggregator.java @@ -20,9 +20,7 @@ package org.apache.druid.query.aggregation.datasketches.tuple.sql; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -38,7 +36,6 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; @@ -46,6 +43,7 @@ import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -69,12 +67,10 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) @@ -86,12 +82,7 @@ public Aggregation toDruidAggregation( final int nominalEntries; final int metricExpressionEndIndex; final int lastArgIndex = argList.size() - 1; - final RexNode potentialNominalEntriesArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - argList.get(lastArgIndex) - ); + final RexNode potentialNominalEntriesArg = inputAccessor.getField(argList.get(lastArgIndex)); if (potentialNominalEntriesArg.isA(SqlKind.LITERAL) && RexLiteral.value(potentialNominalEntriesArg) instanceof Number) { @@ -107,16 +98,11 @@ public Aggregation toDruidAggregation( for (int i = 0; i <= metricExpressionEndIndex; i++) { final String fieldName; - final RexNode columnRexNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - argList.get(i) - ); + final RexNode columnRexNode = inputAccessor.getField(argList.get(i)); final DruidExpression columnArg = Expressions.toDruidExpression( plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), columnRexNode ); if (columnArg == null) { @@ -124,9 +110,10 @@ public Aggregation toDruidAggregation( } if (columnArg.isDirectColumnAccess() && - rowSignature.getColumnType(columnArg.getDirectColumn()) - .map(type -> type.is(ValueType.COMPLEX)) - .orElse(false)) { + inputAccessor.getInputRowSignature() + .getColumnType(columnArg.getDirectColumn()) + .map(type -> type.is(ValueType.COMPLEX)) + .orElse(false)) { fieldName = columnArg.getDirectColumn(); } else { final RelDataType dataType = columnRexNode.getType(); diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java index 0ec265595e11..6a1ca49067e7 100644 --- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java +++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java @@ -20,8 +20,6 @@ package org.apache.druid.query.aggregation.bloom.sql; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -38,13 +36,13 @@ import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -65,25 +63,18 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { - final RexNode inputOperand = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ); + final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0)); final DruidExpression input = Expressions.toDruidExpression( plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), inputOperand ); if (input == null) { @@ -92,12 +83,7 @@ public Aggregation toDruidAggregation( final AggregatorFactory aggregatorFactory; final String aggName = StringUtils.format("%s:agg", name); - final RexNode maxNumEntriesOperand = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode maxNumEntriesOperand = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!maxNumEntriesOperand.isA(SqlKind.LITERAL)) { // maxNumEntriesOperand must be a literal in order to plan. diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java index 3f0bd14f8449..fdc61796c4d9 100644 --- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java +++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java @@ -21,8 +21,6 @@ import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -38,13 +36,12 @@ import org.apache.druid.query.aggregation.histogram.FixedBucketsHistogramAggregatorFactory; import org.apache.druid.query.aggregation.histogram.QuantilePostAggregator; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -65,25 +62,18 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (input == null) { return null; @@ -91,12 +81,7 @@ public Aggregation toDruidAggregation( final AggregatorFactory aggregatorFactory; final String histogramName = StringUtils.format("%s:agg", name); - final RexNode probabilityArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!probabilityArg.isA(SqlKind.LITERAL)) { // Probability must be a literal in order to plan. @@ -107,12 +92,7 @@ public Aggregation toDruidAggregation( final int numBuckets; if (aggregateCall.getArgList().size() >= 3) { - final RexNode numBucketsArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + final RexNode numBucketsArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); if (!numBucketsArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. @@ -126,12 +106,7 @@ public Aggregation toDruidAggregation( final double lowerLimit; if (aggregateCall.getArgList().size() >= 4) { - final RexNode lowerLimitArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(3) - ); + final RexNode lowerLimitArg = inputAccessor.getField(aggregateCall.getArgList().get(3)); if (!lowerLimitArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. @@ -145,12 +120,7 @@ public Aggregation toDruidAggregation( final double upperLimit; if (aggregateCall.getArgList().size() >= 5) { - final RexNode upperLimitArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(4) - ); + final RexNode upperLimitArg = inputAccessor.getField(aggregateCall.getArgList().get(4)); if (!upperLimitArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. @@ -164,12 +134,7 @@ public Aggregation toDruidAggregation( final FixedBucketsHistogram.OutlierHandlingMode outlierHandlingMode; if (aggregateCall.getArgList().size() >= 6) { - final RexNode outlierHandlingModeArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(5) - ); + final RexNode outlierHandlingModeArg = inputAccessor.getField(aggregateCall.getArgList().get(5)); if (!outlierHandlingModeArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java index 9ba7604d7095..a3fe8dc5458a 100644 --- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java +++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java @@ -21,8 +21,6 @@ import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -39,14 +37,13 @@ import org.apache.druid.query.aggregation.histogram.ApproximateHistogramFoldingAggregatorFactory; import org.apache.druid.query.aggregation.histogram.QuantilePostAggregator; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -67,25 +64,18 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ) + inputAccessor.getInputRowSignature(), + inputAccessor.getField(aggregateCall.getArgList().get(0)) ); if (input == null) { return null; @@ -93,12 +83,7 @@ public Aggregation toDruidAggregation( final AggregatorFactory aggregatorFactory; final String histogramName = StringUtils.format("%s:agg", name); - final RexNode probabilityArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + final RexNode probabilityArg = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!probabilityArg.isA(SqlKind.LITERAL)) { // Probability must be a literal in order to plan. @@ -109,12 +94,7 @@ public Aggregation toDruidAggregation( final int resolution; if (aggregateCall.getArgList().size() >= 3) { - final RexNode resolutionArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + final RexNode resolutionArg = inputAccessor.getField(aggregateCall.getArgList().get(2)); if (!resolutionArg.isA(SqlKind.LITERAL)) { // Resolution must be a literal in order to plan. @@ -170,7 +150,10 @@ public Aggregation toDruidAggregation( // No existing match found. Create a new one. if (input.isDirectColumnAccess()) { - if (rowSignature.getColumnType(input.getDirectColumn()).map(type -> type.is(ValueType.COMPLEX)).orElse(false)) { + if (inputAccessor.getInputRowSignature() + .getColumnType(input.getDirectColumn()) + .map(type -> type.is(ValueType.COMPLEX)) + .orElse(false)) { aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory( histogramName, input.getDirectColumn(), diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java index ee8c469c3b86..b2ed565d6276 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java @@ -21,9 +21,7 @@ import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; @@ -40,15 +38,14 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.table.RowSignatures; @@ -77,25 +74,19 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { - final RexNode inputOperand = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(0) - ); + final RexNode inputOperand = inputAccessor.getField(aggregateCall.getArgList().get(0)); + final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator( plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), inputOperand ); if (input == null) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregations.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregations.java index 3a3e43dd7b8a..5c06332a9bc4 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregations.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregations.java @@ -20,14 +20,13 @@ package org.apache.druid.sql.calcite.aggregation; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import javax.annotation.Nullable; import java.util.List; @@ -48,28 +47,24 @@ private Aggregations() * * 1) They can take direct field accesses or expressions as inputs. * 2) They cannot implicitly cast strings to numbers when using a direct field access. - * * @param plannerContext SQL planner context - * @param rowSignature input row signature * @param call aggregate call object - * @param project project that should be applied before aggregation; may be null + * @param inputAccessor gives access to input fields and schema * * @return list of expressions corresponding to aggregator arguments, or null if any cannot be translated */ @Nullable public static List getArgumentsForSimpleAggregator( - final RexBuilder rexBuilder, final PlannerContext plannerContext, - final RowSignature rowSignature, final AggregateCall call, - @Nullable final Project project + final InputAccessor inputAccessor ) { final List args = call .getArgList() .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .map(rexNode -> toDruidExpressionForNumericAggregator(plannerContext, rowSignature, rexNode)) + .map(i -> inputAccessor.getField(i)) + .map(rexNode -> toDruidExpressionForNumericAggregator(plannerContext, inputAccessor.getInputRowSignature(), rexNode)) .collect(Collectors.toList()); if (args.stream().noneMatch(Objects::isNull)) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java index 0ff7972657ea..eceb4ebbf800 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/ApproxCountDistinctSqlAggregator.java @@ -20,8 +20,6 @@ package org.apache.druid.sql.calcite.aggregation; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; @@ -30,8 +28,8 @@ import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.Optionality; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -66,24 +64,20 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { return delegate.toDruidAggregation( plannerContext, - rowSignature, virtualColumnRegistry, - rexBuilder, name, aggregateCall, - project, + inputAccessor, existingAggregations, finalizeAggregations ); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/SqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/SqlAggregator.java index d21f6ebb75ad..ec494a2fec45 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/SqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/SqlAggregator.java @@ -25,6 +25,7 @@ import org.apache.calcite.sql.SqlAggFunction; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -42,6 +43,44 @@ public interface SqlAggregator */ SqlAggFunction calciteFunction(); + /** + * Returns a Druid Aggregation corresponding to a SQL {@link AggregateCall}. This method should ignore filters; + * they will be applied to your aggregator in a later step. + * + * @param plannerContext SQL planner context + * @param virtualColumnRegistry re-usable virtual column references + * @param name desired output name of the aggregation + * @param aggregateCall aggregate call object + * @param inputAccessor gives access to input fields and schema + * @param existingAggregations existing aggregations for this query; useful for re-using aggregations. May be safely + * ignored if you do not want to re-use existing aggregations. + * @param finalizeAggregations true if this query should include explicit finalization for all of its + * aggregators, where required. This is set for subqueries where Druid's native query + * layer does not do this automatically. + * @return aggregation, or null if the call cannot be translated + */ + @Nullable + default Aggregation toDruidAggregation( + PlannerContext plannerContext, + VirtualColumnRegistry virtualColumnRegistry, + String name, + AggregateCall aggregateCall, + InputAccessor inputAccessor, + List existingAggregations, + boolean finalizeAggregations + ) + { + return toDruidAggregation(plannerContext, + inputAccessor.getInputRowSignature(), + virtualColumnRegistry, + inputAccessor.getRexBuilder(), + name, + aggregateCall, + inputAccessor.getProject(), + existingAggregations, + finalizeAggregations); + } + /** * Returns a Druid Aggregation corresponding to a SQL {@link AggregateCall}. This method should ignore filters; * they will be applied to your aggregator in a later step. @@ -62,7 +101,7 @@ public interface SqlAggregator * @return aggregation, or null if the call cannot be translated */ @Nullable - Aggregation toDruidAggregation( + default Aggregation toDruidAggregation( PlannerContext plannerContext, RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, @@ -72,5 +111,8 @@ Aggregation toDruidAggregation( Project project, List existingAggregations, boolean finalizeAggregations - ); + ) + { + throw new RuntimeException("unimplemented fallback method!"); + } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java index ed6652181eb4..be21701d1eb8 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java @@ -21,8 +21,6 @@ import com.google.common.collect.ImmutableSet; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -39,18 +37,17 @@ import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.segment.VirtualColumn; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; import java.util.List; -import java.util.stream.Collectors; public class ArrayConcatSqlAggregator implements SqlAggregator { @@ -67,21 +64,15 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { - final List arguments = aggregateCall - .getArgList() - .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .collect(Collectors.toList()); + final List arguments = inputAccessor.getFields(aggregateCall.getArgList()); Integer maxSizeBytes = null; if (arguments.size() > 1) { @@ -92,7 +83,7 @@ public Aggregation toDruidAggregation( } maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); } - final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, arguments.get(0)); + final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0)); final ExprMacroTable macroTable = plannerContext.getPlannerToolbox().exprMacroTable(); final String fieldName; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java index 5136ed3c947b..9af5210905ef 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java @@ -21,9 +21,7 @@ import com.google.common.collect.ImmutableSet; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -41,18 +39,17 @@ import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; import java.util.List; -import java.util.stream.Collectors; public class ArraySqlAggregator implements SqlAggregator { @@ -69,21 +66,16 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) { - final List arguments = aggregateCall - .getArgList() - .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .collect(Collectors.toList()); + final List arguments = + inputAccessor.getFields(aggregateCall.getArgList()); Integer maxSizeBytes = null; if (arguments.size() > 1) { @@ -94,7 +86,7 @@ public Aggregation toDruidAggregation( } maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); } - final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, arguments.get(0)); + final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0)); if (arg == null) { // can't translate argument return null; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java index a938bdca0b84..3814f8d9ad8b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java @@ -22,8 +22,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; @@ -33,14 +31,13 @@ import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; -import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -58,23 +55,19 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { final List arguments = Aggregations.getArgumentsForSimpleAggregator( - rexBuilder, plannerContext, - rowSignature, aggregateCall, - project + inputAccessor ); if (arguments == null) { @@ -85,11 +78,11 @@ public Aggregation toDruidAggregation( final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory( countName, plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), virtualColumnRegistry, - rexBuilder, + inputAccessor.getRexBuilder(), aggregateCall, - project + inputAccessor ); final DruidExpression arg = Iterables.getOnlyElement(arguments); @@ -108,12 +101,8 @@ public Aggregation toDruidAggregation( if (arg.isDirectColumnAccess()) { fieldName = arg.getDirectColumn(); } else { - final RexNode resolutionArg = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - Iterables.getOnlyElement(aggregateCall.getArgList()) - ); + final RexNode resolutionArg = inputAccessor.getField( + Iterables.getOnlyElement(aggregateCall.getArgList())); fieldName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(arg, resolutionArg.getType()); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java index d8758141dfba..a5c7fb61cff8 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java @@ -21,8 +21,6 @@ import com.google.common.collect.ImmutableSet; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; @@ -41,12 +39,12 @@ import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -122,12 +120,10 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) @@ -135,8 +131,8 @@ public Aggregation toDruidAggregation( final List arguments = aggregateCall .getArgList() .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode)) + .map(i -> inputAccessor.getField(i)) + .map(rexNode -> Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode)) .collect(Collectors.toList()); if (arguments.stream().anyMatch(Objects::isNull)) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java index e4dedd95ce27..699c7a8d1c6b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BuiltinApproxCountDistinctSqlAggregator.java @@ -22,9 +22,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; @@ -42,7 +40,6 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; @@ -50,6 +47,7 @@ import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -72,26 +70,20 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { // Don't use Aggregations.getArgumentsForSimpleAggregator, since it won't let us use direct column access // for string columns. - final RexNode rexNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - Iterables.getOnlyElement(aggregateCall.getArgList()) - ); + final RexNode rexNode = inputAccessor.getField( + Iterables.getOnlyElement(aggregateCall.getArgList())); - final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, rexNode); + final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode); if (arg == null) { return null; } @@ -100,7 +92,10 @@ public Aggregation toDruidAggregation( final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name; if (arg.isDirectColumnAccess() - && rowSignature.getColumnType(arg.getDirectColumn()).map(type -> type.is(ValueType.COMPLEX)).orElse(false)) { + && inputAccessor.getInputRowSignature() + .getColumnType(arg.getDirectColumn()) + .map(type -> type.is(ValueType.COMPLEX)) + .orElse(false)) { aggregatorFactory = new HyperUniquesAggregatorFactory(aggregatorName, arg.getDirectColumn(), false, true); } else { final RelDataType dataType = rexNode.getType(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java index edc7e3ce50a0..c28ac8eebb28 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java @@ -22,7 +22,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -40,6 +39,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -69,15 +69,10 @@ static AggregatorFactory createCountAggregatorFactory( final VirtualColumnRegistry virtualColumnRegistry, final RexBuilder rexBuilder, final AggregateCall aggregateCall, - final Project project + final InputAccessor inputAccessor ) { - final RexNode rexNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - Iterables.getOnlyElement(aggregateCall.getArgList()) - ); + final RexNode rexNode = inputAccessor.getField(Iterables.getOnlyElement(aggregateCall.getArgList())); if (rexNode.getType().isNullable()) { final DimFilter nonNullFilter = Expressions.toFilter( @@ -102,28 +97,25 @@ static AggregatorFactory createCountAggregatorFactory( @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { final List args = Aggregations.getArgumentsForSimpleAggregator( - rexBuilder, plannerContext, - rowSignature, aggregateCall, - project + inputAccessor ); if (args == null) { return null; } + // FIXME: is-all-literal if (args.isEmpty()) { // COUNT(*) return Aggregation.create(new CountAggregatorFactory(name)); @@ -132,12 +124,10 @@ public Aggregation toDruidAggregation( if (plannerContext.getPlannerConfig().isUseApproximateCountDistinct()) { return approxCountDistinctAggregator.toDruidAggregation( plannerContext, - rowSignature, virtualColumnRegistry, - rexBuilder, name, aggregateCall, - project, + inputAccessor, existingAggregations, finalizeAggregations ); @@ -150,11 +140,11 @@ public Aggregation toDruidAggregation( AggregatorFactory theCount = createCountAggregatorFactory( name, plannerContext, - rowSignature, + inputAccessor.getInputRowSignature(), virtualColumnRegistry, - rexBuilder, + inputAccessor.getRexBuilder(), aggregateCall, - project + inputAccessor ); return Aggregation.create(theCount); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java index 5f1b3c3228d4..abaeede99484 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java @@ -20,9 +20,7 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -53,19 +51,18 @@ import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; public class EarliestLatestAnySqlAggregator implements SqlAggregator { @@ -180,23 +177,17 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { - final List rexNodes = aggregateCall - .getArgList() - .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .collect(Collectors.toList()); + final List rexNodes = inputAccessor.getFields(aggregateCall.getArgList()); - final List args = Expressions.toDruidExpressions(plannerContext, rowSignature, rexNodes); + final List args = Expressions.toDruidExpressions(plannerContext, inputAccessor.getInputRowSignature(), rexNodes); if (args == null) { return null; @@ -216,7 +207,8 @@ public Aggregation toDruidAggregation( final String fieldName = getColumnName(plannerContext, virtualColumnRegistry, args.get(0), rexNodes.get(0)); - if (!rowSignature.contains(ColumnHolder.TIME_COLUMN_NAME) && (aggregatorType == AggregatorType.LATEST || aggregatorType == AggregatorType.EARLIEST)) { + if (!inputAccessor.getInputRowSignature().contains(ColumnHolder.TIME_COLUMN_NAME) + && (aggregatorType == AggregatorType.LATEST || aggregatorType == AggregatorType.EARLIEST)) { // This code is being run as part of the exploratory volcano planner, currently, the definition of these // aggregators does not tell Calcite that they depend on a __time column being in existence, instead we are // allowing the volcano planner to explore paths that put projections which eliminate the time column in between diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java index 95b70e1f1e50..c12be459cf55 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java @@ -20,8 +20,6 @@ package org.apache.druid.sql.calcite.aggregation.builtin; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -38,19 +36,18 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; public class EarliestLatestBySqlAggregator implements SqlAggregator { @@ -76,23 +73,17 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) { - final List rexNodes = aggregateCall - .getArgList() - .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .collect(Collectors.toList()); + final List rexNodes = inputAccessor.getFields(aggregateCall.getArgList()); - final List args = Expressions.toDruidExpressions(plannerContext, rowSignature, rexNodes); + final List args = Expressions.toDruidExpressions(plannerContext, inputAccessor.getInputRowSignature(), rexNodes); if (args == null) { return null; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java index 156c3995c6fb..ec829df11d74 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java @@ -22,7 +22,6 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; @@ -34,6 +33,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -53,24 +53,22 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, - List existingAggregations, - boolean finalizeAggregations + final InputAccessor inputAccessor, + final List existingAggregations, + final boolean finalizeAggregations ) { List arguments = aggregateCall.getArgList() .stream() .map(i -> getColumnName( plannerContext, - rowSignature, - project, + inputAccessor.getInputRowSignature(), + inputAccessor.getProject(), virtualColumnRegistry, - rexBuilder.getTypeFactory(), + inputAccessor.getRexBuilder().getTypeFactory(), i )) .filter(Objects::nonNull) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LiteralSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LiteralSqlAggregator.java index 0eb2c1085c04..6e7de762b23a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LiteralSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/LiteralSqlAggregator.java @@ -21,18 +21,16 @@ import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlInternalOperators; import org.apache.druid.query.aggregation.post.ExpressionPostAggregator; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -59,12 +57,10 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) @@ -73,7 +69,7 @@ public Aggregation toDruidAggregation( return null; } final RexNode literal = aggregateCall.rexList.get(0); - final DruidExpression expr = Expressions.toDruidExpression(plannerContext, rowSignature, literal); + final DruidExpression expr = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), literal); if (expr == null) { return null; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SimpleSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SimpleSqlAggregator.java index 01782668663d..5da064c285d5 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SimpleSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/SimpleSqlAggregator.java @@ -21,18 +21,16 @@ import com.google.common.collect.Iterables; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.druid.error.DruidException; import org.apache.druid.error.InvalidSqlInput; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -57,12 +55,10 @@ public static DruidException badTypeException(String columnName, String agg, Col @Override public Aggregation toDruidAggregation( final PlannerContext plannerContext, - final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, - final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final boolean finalizeAggregations ) @@ -72,11 +68,9 @@ public Aggregation toDruidAggregation( } final List arguments = Aggregations.getArgumentsForSimpleAggregator( - rexBuilder, plannerContext, - rowSignature, aggregateCall, - project + inputAccessor ); if (arguments == null) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java index b391100ff3a1..7c1389de3fe3 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java @@ -21,9 +21,7 @@ import com.google.common.collect.ImmutableSet; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; @@ -47,13 +45,13 @@ import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.segment.column.ColumnType; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import org.apache.druid.sql.calcite.table.RowSignatures; @@ -89,12 +87,10 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) @@ -102,20 +98,15 @@ public Aggregation toDruidAggregation( final List arguments = aggregateCall .getArgList() .stream() - .map(i -> Expressions.fromFieldAccess(rexBuilder.getTypeFactory(), rowSignature, project, i)) - .map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode)) + .map(i -> inputAccessor.getField(i)) + .map(rexNode -> Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), rexNode)) .collect(Collectors.toList()); if (arguments.stream().anyMatch(Objects::isNull)) { return null; } - RexNode separatorNode = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ); + RexNode separatorNode = inputAccessor.getField(aggregateCall.getArgList().get(1)); if (!separatorNode.isA(SqlKind.LITERAL)) { // separator must be a literal return null; @@ -133,12 +124,7 @@ public Aggregation toDruidAggregation( Integer maxSizeBytes = null; if (arguments.size() > 2) { - RexNode maxBytes = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ); + RexNode maxBytes = inputAccessor.getField(aggregateCall.getArgList().get(2)); if (!maxBytes.isA(SqlKind.LITERAL)) { // maxBytes must be a literal return null; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/WindowSqlAggregate.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/WindowSqlAggregate.java index 7dd158d91f3a..00cd391eab20 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/WindowSqlAggregate.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/WindowSqlAggregate.java @@ -20,14 +20,12 @@ package org.apache.druid.sql.calcite.expression; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlAggFunction; import org.apache.druid.java.util.common.UOE; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.aggregation.Aggregation; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -55,12 +53,10 @@ public SqlAggFunction calciteFunction() @Override public Aggregation toDruidAggregation( PlannerContext plannerContext, - RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, - RexBuilder rexBuilder, String name, AggregateCall aggregateCall, - Project project, + InputAccessor inputAccessor, List existingAggregations, boolean finalizeAggregations ) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java index 9c41d79070bf..1cf79b6dc123 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java @@ -580,7 +580,11 @@ private static List computeAggregations( rowSignature, virtualColumnRegistry, rexBuilder, - partialQuery.getSelectProject(), + InputAccessor.buildFor( + rexBuilder, + rowSignature, + partialQuery.getSelectProject(), + null), aggregations, aggName, aggCall, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java new file mode 100644 index 000000000000..57b81c685368 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.rel; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.sql.calcite.expression.Expressions; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * Enables simpler access to input expressions. + * + * In case of aggregates it provides the constants transparently for aggregates. + */ +public class InputAccessor +{ + private final Project project; + private final ImmutableList constants; + private final RexBuilder rexBuilder; + private final RowSignature inputRowSignature; + private final int inputFieldCount; + + public static InputAccessor buildFor( + RexBuilder rexBuilder, + RowSignature inputRowSignature, + @Nullable Project project, + @Nullable ImmutableList constants) + { + return new InputAccessor(rexBuilder, inputRowSignature, project, constants); + } + + private InputAccessor( + RexBuilder rexBuilder, + RowSignature inputRowSignature, + Project project, + ImmutableList constants) + { + this.rexBuilder = rexBuilder; + this.inputRowSignature = inputRowSignature; + this.project = project; + this.constants = constants; + this.inputFieldCount = project != null ? project.getRowType().getFieldCount() : inputRowSignature.size(); + } + + public RexNode getField(int argIndex) + { + + if (argIndex < inputFieldCount) { + return Expressions.fromFieldAccess( + rexBuilder.getTypeFactory(), + inputRowSignature, + project, + argIndex); + } else { + return constants.get(argIndex - inputFieldCount); + } + } + + public List getFields(List argList) + { + return argList + .stream() + .map(i -> getField(i)) + .collect(Collectors.toList()); + } + + public @Nullable Project getProject() + { + return project; + } + + + public RexBuilder getRexBuilder() + { + return rexBuilder; + } + + + public RowSignature getInputRowSignature() + { + return inputRowSignature; + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java index 07c5544441dc..4039ca8914ab 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java @@ -177,7 +177,11 @@ public static Windowing fromCalciteStuff( sourceRowSignature, null, rexBuilder, - partialQuery.getSelectProject(), + InputAccessor.buildFor( + rexBuilder, + sourceRowSignature, + partialQuery.getSelectProject(), + window.constants), Collections.emptyList(), aggName, aggregateCall, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java index 50bdf80771a8..fecabd00ec39 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java @@ -20,7 +20,6 @@ package org.apache.druid.sql.calcite.rule; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.core.Project; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -32,6 +31,7 @@ import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.filtration.Filtration; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; @@ -58,7 +58,7 @@ public static Aggregation translateAggregateCall( final RowSignature rowSignature, @Nullable final VirtualColumnRegistry virtualColumnRegistry, final RexBuilder rexBuilder, - final Project project, + final InputAccessor inputAccessor, final List existingAggregations, final String name, final AggregateCall call, @@ -74,11 +74,7 @@ public static Aggregation translateAggregateCall( if (call.filterArg >= 0) { // AGG(xxx) FILTER(WHERE yyy) - final RexNode expression = Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - call.filterArg); + final RexNode expression = inputAccessor.getField(call.filterArg); final DimFilter nonOptimizedFilter = Expressions.toFilter( plannerContext, @@ -136,12 +132,10 @@ public static Aggregation translateAggregateCall( final Aggregation retVal = sqlAggregator.toDruidAggregation( plannerContext, - rowSignature, virtualColumnRegistry, - rexBuilder, name, call, - project, + inputAccessor, existingAggregationsWithSameFilter, finalizeAggregations ); diff --git a/sql/src/test/resources/calcite/tests/window/aggregateConstant.sqlTest b/sql/src/test/resources/calcite/tests/window/aggregateConstant.sqlTest new file mode 100644 index 000000000000..16dbe924fdb3 --- /dev/null +++ b/sql/src/test/resources/calcite/tests/window/aggregateConstant.sqlTest @@ -0,0 +1,26 @@ +type: "operatorValidation" + +sql: | + SELECT + dim1, + count(333) OVER () cc + FROM foo + WHERE length(dim1)>0 + +expectedOperators: + - type: naivePartition + partitionColumns: [] + - type: "window" + processor: + type: "framedAgg" + frame: { peerType: "ROWS", lowUnbounded: true, lowOffset: 0, uppUnbounded: true, uppOffset: 0 } + aggregations: + - { type: "count", name: "w0" } + +expectedResults: + - ["10.1",5] + - ["2",5] + - ["1",5] + - ["def",5] + - ["abc",5] + From c7d0615af3fe6fb7c500efd591299ae994225fe9 Mon Sep 17 00:00:00 2001 From: Pranav Date: Sun, 8 Oct 2023 21:05:39 -0700 Subject: [PATCH 05/14] Fix the build for #15013.: Lookup jitter upstream build fix (#15103) Fix the build for #15013. --- .../extensions-core/lookups-cached-global.md | 2 + .../lookup/namespace/ExtractionNamespace.java | 8 ++++ .../namespace/JdbcExtractionNamespace.java | 14 ++++++ .../namespace/cache/CacheScheduler.java | 4 +- .../JdbcExtractionNamespaceUrlCheckTest.java | 12 +++++- .../namespace/JdbcCacheGeneratorTest.java | 1 + .../namespace/cache/CacheSchedulerTest.java | 1 + .../cache/JdbcExtractionNamespaceTest.java | 43 +++++++++++++++++++ 8 files changed, 81 insertions(+), 4 deletions(-) diff --git a/docs/development/extensions-core/lookups-cached-global.md b/docs/development/extensions-core/lookups-cached-global.md index ebeca5a741e0..dc8827a5b368 100644 --- a/docs/development/extensions-core/lookups-cached-global.md +++ b/docs/development/extensions-core/lookups-cached-global.md @@ -352,6 +352,7 @@ The JDBC lookups will poll a database to populate its local cache. If the `tsCol |`filter`|The filter to use when selecting lookups, this is used to create a where clause on lookup population|No|No Filter| |`tsColumn`| The column in `table` which contains when the key was updated|No|Not used| |`pollPeriod`|How often to poll the DB|No|0 (only once)| +|`jitterSeconds`| How much jitter to add (in seconds) up to maximum as a delay (actual value will be used as random from 0 to `jitterSeconds`), used to distribute db load more evenly|No|0| |`maxHeapPercentage`|The maximum percentage of heap size that the lookup should consume. If the lookup grows beyond this size, warning messages will be logged in the respective service logs.|No|10% of JVM heap size| ```json @@ -367,6 +368,7 @@ The JDBC lookups will poll a database to populate its local cache. If the `tsCol "valueColumn":"the_new_dim_value", "tsColumn":"timestamp_column", "pollPeriod":600000, + "jitterSeconds": 120, "maxHeapPercentage": 10 } ``` diff --git a/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/query/lookup/namespace/ExtractionNamespace.java b/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/query/lookup/namespace/ExtractionNamespace.java index 86eb310b4df9..c52021bd18f7 100644 --- a/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/query/lookup/namespace/ExtractionNamespace.java +++ b/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/query/lookup/namespace/ExtractionNamespace.java @@ -41,4 +41,12 @@ default long getMaxHeapPercentage() { return -1L; } + + // For larger clusters, when they all startup at the same time and have lookups in the db, + // it overwhelms the database, this allows implementations to introduce a jitter, which + // should spread out the load. + default long getJitterMills() + { + return 0; + } } diff --git a/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/query/lookup/namespace/JdbcExtractionNamespace.java b/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/query/lookup/namespace/JdbcExtractionNamespace.java index 1495370a4519..32ceccd1a82a 100644 --- a/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/query/lookup/namespace/JdbcExtractionNamespace.java +++ b/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/query/lookup/namespace/JdbcExtractionNamespace.java @@ -34,6 +34,7 @@ import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; import java.util.Objects; +import java.util.concurrent.ThreadLocalRandom; /** * @@ -61,6 +62,8 @@ public class JdbcExtractionNamespace implements ExtractionNamespace private final Period pollPeriod; @JsonProperty private final long maxHeapPercentage; + @JsonProperty + private final int jitterSeconds; @JsonCreator public JdbcExtractionNamespace( @@ -73,6 +76,7 @@ public JdbcExtractionNamespace( @JsonProperty(value = "filter") @Nullable final String filter, @Min(0) @JsonProperty(value = "pollPeriod") @Nullable final Period pollPeriod, @JsonProperty(value = "maxHeapPercentage") @Nullable final Long maxHeapPercentage, + @JsonProperty(value = "jitterSeconds") @Nullable Integer jitterSeconds, @JacksonInject JdbcAccessSecurityConfig securityConfig ) { @@ -95,6 +99,7 @@ public JdbcExtractionNamespace( } else { this.pollPeriod = pollPeriod; } + this.jitterSeconds = jitterSeconds == null ? 0 : jitterSeconds; this.maxHeapPercentage = maxHeapPercentage == null ? DEFAULT_MAX_HEAP_PERCENTAGE : maxHeapPercentage; } @@ -162,6 +167,15 @@ public long getMaxHeapPercentage() return maxHeapPercentage; } + @Override + public long getJitterMills() + { + if (jitterSeconds == 0) { + return jitterSeconds; + } + return 1000L * ThreadLocalRandom.current().nextInt(jitterSeconds + 1); + } + @Override public String toString() { diff --git a/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/server/lookup/namespace/cache/CacheScheduler.java b/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/server/lookup/namespace/cache/CacheScheduler.java index 63471afe3d17..61e580563f8c 100644 --- a/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/server/lookup/namespace/cache/CacheScheduler.java +++ b/extensions-core/lookups-cached-global/src/main/java/org/apache/druid/server/lookup/namespace/cache/CacheScheduler.java @@ -180,9 +180,9 @@ private Future schedule(final T namespace) final long updateMs = namespace.getPollMs(); Runnable command = this::updateCache; if (updateMs > 0) { - return cacheManager.scheduledExecutorService().scheduleAtFixedRate(command, 0, updateMs, TimeUnit.MILLISECONDS); + return cacheManager.scheduledExecutorService().scheduleAtFixedRate(command, namespace.getJitterMills(), updateMs, TimeUnit.MILLISECONDS); } else { - return cacheManager.scheduledExecutorService().schedule(command, 0, TimeUnit.MILLISECONDS); + return cacheManager.scheduledExecutorService().schedule(command, namespace.getJitterMills(), TimeUnit.MILLISECONDS); } } diff --git a/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/query/lookup/namespace/JdbcExtractionNamespaceUrlCheckTest.java b/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/query/lookup/namespace/JdbcExtractionNamespaceUrlCheckTest.java index 44bb67eac00d..f4fffef5fff3 100644 --- a/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/query/lookup/namespace/JdbcExtractionNamespaceUrlCheckTest.java +++ b/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/query/lookup/namespace/JdbcExtractionNamespaceUrlCheckTest.java @@ -63,7 +63,7 @@ public String getConnectURI() "some filter", new Period(10), null, - new JdbcAccessSecurityConfig() + 0, new JdbcAccessSecurityConfig() { @Override public Set getAllowedProperties() @@ -101,6 +101,7 @@ public String getConnectURI() "some filter", new Period(10), null, + 0, new JdbcAccessSecurityConfig() { @Override @@ -137,6 +138,7 @@ public String getConnectURI() "some filter", new Period(10), null, + 0, new JdbcAccessSecurityConfig() { @Override @@ -175,6 +177,7 @@ public String getConnectURI() "some filter", new Period(10), null, + 0, new JdbcAccessSecurityConfig() { @Override @@ -217,6 +220,7 @@ public String getConnectURI() "some filter", new Period(10), null, + 0, new JdbcAccessSecurityConfig() { @Override @@ -255,6 +259,7 @@ public String getConnectURI() "some filter", new Period(10), 10L, + 0, new JdbcAccessSecurityConfig() { @Override @@ -291,7 +296,7 @@ public String getConnectURI() "some filter", new Period(10), null, - new JdbcAccessSecurityConfig() + 0, new JdbcAccessSecurityConfig() { @Override public Set getAllowedProperties() @@ -329,6 +334,7 @@ public String getConnectURI() "some filter", new Period(10), null, + 0, new JdbcAccessSecurityConfig() { @Override @@ -373,6 +379,7 @@ public String getConnectURI() "some filter", new Period(10), null, + 0, new JdbcAccessSecurityConfig() { @Override @@ -415,6 +422,7 @@ public String getConnectURI() "some filter", new Period(10), null, + 0, new JdbcAccessSecurityConfig() { @Override diff --git a/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/JdbcCacheGeneratorTest.java b/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/JdbcCacheGeneratorTest.java index ff27b50fd86e..1eb74630fda0 100644 --- a/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/JdbcCacheGeneratorTest.java +++ b/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/JdbcCacheGeneratorTest.java @@ -137,6 +137,7 @@ private static JdbcExtractionNamespace createJdbcExtractionNamespace( "filter", Period.ZERO, null, + 0, new JdbcAccessSecurityConfig() ); } diff --git a/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/cache/CacheSchedulerTest.java b/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/cache/CacheSchedulerTest.java index 44289b048339..fd96529ae992 100644 --- a/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/cache/CacheSchedulerTest.java +++ b/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/cache/CacheSchedulerTest.java @@ -458,6 +458,7 @@ public String getConnectURI() "some filter", new Period(10_000), null, + 0, new JdbcAccessSecurityConfig() { @Override diff --git a/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/cache/JdbcExtractionNamespaceTest.java b/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/cache/JdbcExtractionNamespaceTest.java index b6a37240cea4..e0c651724d75 100644 --- a/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/cache/JdbcExtractionNamespaceTest.java +++ b/extensions-core/lookups-cached-global/src/test/java/org/apache/druid/server/lookup/namespace/cache/JdbcExtractionNamespaceTest.java @@ -328,6 +328,7 @@ public void testMappingWithoutFilter() null, new Period(0), null, + 0, new JdbcAccessSecurityConfig() ); try (CacheScheduler.Entry entry = scheduler.schedule(extractionNamespace)) { @@ -361,6 +362,7 @@ public void testMappingWithFilter() FILTER_COLUMN + "='1'", new Period(0), null, + 0, new JdbcAccessSecurityConfig() ); try (CacheScheduler.Entry entry = scheduler.schedule(extractionNamespace)) { @@ -399,6 +401,45 @@ public void testSkipOld() } } + @Test + public void testRandomJitter() + { + JdbcExtractionNamespace extractionNamespace = new JdbcExtractionNamespace( + derbyConnectorRule.getMetadataConnectorConfig(), + TABLE_NAME, + KEY_NAME, + VAL_NAME, + tsColumn, + FILTER_COLUMN + "='1'", + new Period(0), + null, + 120, + new JdbcAccessSecurityConfig() + ); + long jitter = extractionNamespace.getJitterMills(); + // jitter will be a random value between 0 and 120 seconds. + Assert.assertTrue(jitter >= 0 && jitter <= 120000); + } + + @Test + public void testRandomJitterNotSpecified() + { + JdbcExtractionNamespace extractionNamespace = new JdbcExtractionNamespace( + derbyConnectorRule.getMetadataConnectorConfig(), + TABLE_NAME, + KEY_NAME, + VAL_NAME, + tsColumn, + FILTER_COLUMN + "='1'", + new Period(0), + null, + null, + new JdbcAccessSecurityConfig() + ); + // jitter will be a random value between 0 and 120 seconds. + Assert.assertEquals(0, extractionNamespace.getJitterMills()); + } + @Test(timeout = 60_000L) public void testFindNew() throws InterruptedException @@ -436,6 +477,7 @@ public void testSerde() throws IOException "some filter", new Period(10), null, + 0, securityConfig ); final ObjectMapper mapper = new DefaultObjectMapper(); @@ -461,6 +503,7 @@ private CacheScheduler.Entry ensureEntry() null, new Period(10), null, + 0, new JdbcAccessSecurityConfig() ); CacheScheduler.Entry entry = scheduler.schedule(extractionNamespace); From c483cb863d420461acd3a98666eec17716c62421 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Sun, 8 Oct 2023 22:42:28 -0700 Subject: [PATCH 06/14] Fix IndexerWorkerClient#fetchChannelData when response has data and error. (#15084) * Fix IndexerWorkerClient#fetchChannelData when response has data and error. When a channel data response from a worker includes some data and then some I/O error, then when the call is retried, we will re-read the set of data that was read by the previous connection and add it to the local channel again. This causes the local channel to become corrupted. The patch fixes this case by skipping data that has already been read. --- .../file/FrameFileHttpResponseHandler.java | 46 ++++++++--- .../frame/file/FrameFilePartialFetch.java | 11 +++ .../FrameFileHttpResponseHandlerTest.java | 80 +++++++++++++++++++ 3 files changed, 126 insertions(+), 11 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/frame/file/FrameFileHttpResponseHandler.java b/processing/src/main/java/org/apache/druid/frame/file/FrameFileHttpResponseHandler.java index 0f70fb3d6983..661ba351dfc2 100644 --- a/processing/src/main/java/org/apache/druid/frame/file/FrameFileHttpResponseHandler.java +++ b/processing/src/main/java/org/apache/druid/frame/file/FrameFileHttpResponseHandler.java @@ -21,6 +21,7 @@ import com.google.common.base.Preconditions; import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.channel.ReadableByteChunksFrameChannel; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.http.client.response.ClientResponse; @@ -49,11 +50,20 @@ public class FrameFileHttpResponseHandler implements HttpResponseHandler response( return ClientResponse.finished(clientResponseObj); } - final byte[] chunk = new byte[content.readableBytes()]; - content.getBytes(content.readerIndex(), chunk); + final byte[] chunk; + final int chunkSize = content.readableBytes(); - try { - final ListenableFuture backpressureFuture = channel.addChunk(chunk); + // Potentially skip some of this chunk, if the relevant bytes have already been read by the handler. This can + // happen if a request reads some data, then fails with a retryable I/O error, and then is retried. The retry + // will re-read some data that has already been added to the channel, so we need to skip it. + final long readByThisHandler = channel.getBytesAdded() - startOffset; + final long readByThisRequest = clientResponseObj.getBytesRead(); // Prior to the current chunk + final long toSkip = readByThisHandler - readByThisRequest; - if (backpressureFuture != null) { - clientResponseObj.setBackpressureFuture(backpressureFuture); - } + if (toSkip < 0) { + throw DruidException.defensive("Expected toSkip[%d] to be nonnegative", toSkip); + } else if (toSkip < chunkSize) { // When toSkip >= chunkSize, we skip the entire chunk and do not toucn the channel + chunk = new byte[chunkSize - (int) toSkip]; + content.getBytes(content.readerIndex() + (int) toSkip, chunk); - clientResponseObj.addBytesRead(chunk.length); - } - catch (Exception e) { - clientResponseObj.exceptionCaught(e); + try { + final ListenableFuture backpressureFuture = channel.addChunk(chunk); + + if (backpressureFuture != null) { + clientResponseObj.setBackpressureFuture(backpressureFuture); + } + } + catch (Exception e) { + clientResponseObj.exceptionCaught(e); + } } + // Call addBytesRead even if we skipped some or all of the chunk, because that lets us know when to stop skipping. + clientResponseObj.addBytesRead(chunkSize); return ClientResponse.unfinished(clientResponseObj); } } diff --git a/processing/src/main/java/org/apache/druid/frame/file/FrameFilePartialFetch.java b/processing/src/main/java/org/apache/druid/frame/file/FrameFilePartialFetch.java index 8c2056dcbe43..9e6b84c6bbf7 100644 --- a/processing/src/main/java/org/apache/druid/frame/file/FrameFilePartialFetch.java +++ b/processing/src/main/java/org/apache/druid/frame/file/FrameFilePartialFetch.java @@ -74,6 +74,14 @@ public boolean isExceptionCaught() return exceptionCaught != null; } + /** + * Number of bytes read so far by this request. + */ + public long getBytesRead() + { + return bytesRead; + } + /** * Future that resolves when it is a good time to request the next chunk of the frame file. * @@ -105,6 +113,9 @@ void exceptionCaught(final Throwable t) } } + /** + * Increment the value returned by {@link #getBytesRead()}. Called whenever a chunk of data is read from the response. + */ void addBytesRead(final long n) { bytesRead += n; diff --git a/processing/src/test/java/org/apache/druid/frame/file/FrameFileHttpResponseHandlerTest.java b/processing/src/test/java/org/apache/druid/frame/file/FrameFileHttpResponseHandlerTest.java index 4eeaaddbe892..06c160e68409 100644 --- a/processing/src/test/java/org/apache/druid/frame/file/FrameFileHttpResponseHandlerTest.java +++ b/processing/src/test/java/org/apache/druid/frame/file/FrameFileHttpResponseHandlerTest.java @@ -346,6 +346,86 @@ public void testCaughtExceptionDuringChunkedResponse() throws Exception ); } + @Test + public void testCaughtExceptionDuringChunkedResponseRetryWithSameHandler() throws Exception + { + // Split file into 12 chunks after the first 100 bytes. + final int firstPart = 100; + final int chunkSize = Ints.checkedCast(LongMath.divide(file.length() - firstPart, 12, RoundingMode.CEILING)); + final byte[] allBytes = Files.readAllBytes(file.toPath()); + + // Add firstPart and be done. + ClientResponse response = handler.done( + handler.handleResponse( + makeResponse(HttpResponseStatus.OK, byteSlice(allBytes, 0, firstPart)), + null + ) + ); + + Assert.assertEquals(firstPart, channel.getBytesAdded()); + Assert.assertTrue(response.isFinished()); + + // Add first quarter after firstPart using a new handler. + handler = new FrameFileHttpResponseHandler(channel); + response = handler.handleResponse( + makeResponse(HttpResponseStatus.OK, byteSlice(allBytes, firstPart, chunkSize * 3)), + null + ); + + // Set an exception. + handler.exceptionCaught(response, new ISE("Oh no!")); + + // Add another chunk after the exception is caught (this can happen in real life!). We expect it to be ignored. + response = handler.handleChunk( + response, + makeChunk(byteSlice(allBytes, firstPart + chunkSize * 3, chunkSize * 3)), + 2 + ); + + // Verify that the exception handler was called. + Assert.assertTrue(response.getObj().isExceptionCaught()); + final Throwable e = response.getObj().getExceptionCaught(); + MatcherAssert.assertThat(e, CoreMatchers.instanceOf(IllegalStateException.class)); + MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("Oh no!"))); + + // Retry connection with the same handler and same initial offset firstPart (don't recreate handler), but now use + // thirds instead of quarters as chunks. (ServiceClientImpl would retry from the same offset with the same handler + // if the exception is retryable.) + response = handler.handleResponse( + makeResponse(HttpResponseStatus.OK, byteSlice(allBytes, firstPart, chunkSize * 4)), + null + ); + + Assert.assertEquals(firstPart + chunkSize * 4L, channel.getBytesAdded()); + Assert.assertFalse(response.isFinished()); + + // Send the rest of the data. + response = handler.handleChunk( + response, + makeChunk(byteSlice(allBytes, firstPart + chunkSize * 4, chunkSize * 4)), + 1 + ); + Assert.assertEquals(firstPart + chunkSize * 8L, channel.getBytesAdded()); + + response = handler.handleChunk( + response, + makeChunk(byteSlice(allBytes, firstPart + chunkSize * 8, chunkSize * 4)), + 2 + ); + response = handler.done(response); + + Assert.assertTrue(response.isFinished()); + Assert.assertFalse(response.getObj().isExceptionCaught()); + + // Verify channel. + Assert.assertEquals(allBytes.length, channel.getBytesAdded()); + channel.doneWriting(); + FrameTestUtil.assertRowsEqual( + FrameTestUtil.readRowsFromAdapter(adapter, null, false), + FrameTestUtil.readRowsFromFrameChannel(channel, FrameReader.create(adapter.getRowSignature())) + ); + } + private static HttpResponse makeResponse(final HttpResponseStatus status, final byte[] content) { final ByteBufferBackedChannelBuffer channelBuffer = new ByteBufferBackedChannelBuffer(ByteBuffer.wrap(content)); From e2cc1c4ad19dbf005a947f25f3833c55da76f1fb Mon Sep 17 00:00:00 2001 From: kaisun2000 <52840222+kaisun2000@users.noreply.github.com> Date: Mon, 9 Oct 2023 00:26:23 -0700 Subject: [PATCH 07/14] Add metric -- count of queries waiting for merge buffers (#15025) Add 'mergeBuffer/pendingRequests' metric that exposes the count of waiting queries (threads) blocking in the merge buffers pools. --- docs/operations/metrics.md | 2 + .../druid/collections/BlockingPool.java | 9 ++- .../collections/DefaultBlockingPool.java | 18 ++++++ .../druid/collections/DummyBlockingPool.java | 6 ++ .../apache/druid/query/TestBufferPool.java | 6 ++ .../metrics/QueryCountStatsMonitor.java | 11 +++- .../metrics/QueryCountStatsMonitorTest.java | 57 ++++++++++++++++++- 7 files changed, 104 insertions(+), 5 deletions(-) diff --git a/docs/operations/metrics.md b/docs/operations/metrics.md index 77a79170bec8..28e8a9fa9646 100644 --- a/docs/operations/metrics.md +++ b/docs/operations/metrics.md @@ -62,6 +62,7 @@ Most metric values reset each emission period, as specified in `druid.monitoring |`query/failed/count`|Number of failed queries.|This metric is only available if the `QueryCountStatsMonitor` module is included.| | |`query/interrupted/count`|Number of queries interrupted due to cancellation.|This metric is only available if the `QueryCountStatsMonitor` module is included.| | |`query/timeout/count`|Number of timed out queries.|This metric is only available if the `QueryCountStatsMonitor` module is included.| | +|`mergeBuffer/pendingRequests`|Number of requests waiting to acquire a batch of buffers from the merge buffer pool.|This metric is only available if the `QueryCountStatsMonitor` module is included.| | |`query/segments/count`|This metric is not enabled by default. See the `QueryMetrics` Interface for reference regarding enabling this metric. Number of segments that will be touched by the query. In the broker, it makes a plan to distribute the query to realtime tasks and historicals based on a snapshot of segment distribution state. If there are some segments moved after this snapshot is created, certain historicals and realtime tasks can report those segments as missing to the broker. The broker will resend the query to the new servers that serve those segments after move. In this case, those segments can be counted more than once in this metric.||Varies| |`query/priority`|Assigned lane and priority, only if Laning strategy is enabled. Refer to [Laning strategies](../configuration/index.md#laning-strategies)|`lane`, `dataSource`, `type`|0| |`sqlQuery/time`|Milliseconds taken to complete a SQL query.|`id`, `nativeQueryIds`, `dataSource`, `remoteAddress`, `success`, `engine`|< 1s| @@ -97,6 +98,7 @@ Most metric values reset each emission period, as specified in `druid.monitoring |`query/failed/count`|Number of failed queries.|This metric is only available if the `QueryCountStatsMonitor` module is included.|| |`query/interrupted/count`|Number of queries interrupted due to cancellation.|This metric is only available if the `QueryCountStatsMonitor` module is included.|| |`query/timeout/count`|Number of timed out queries.|This metric is only available if the `QueryCountStatsMonitor` module is included.|| +|`mergeBuffer/pendingRequests`|Number of requests waiting to acquire a batch of buffers from the merge buffer pool.|This metric is only available if the `QueryCountStatsMonitor` module is included.|| ### Real-time diff --git a/processing/src/main/java/org/apache/druid/collections/BlockingPool.java b/processing/src/main/java/org/apache/druid/collections/BlockingPool.java index c17329917cd2..4fb3ff66d8bf 100644 --- a/processing/src/main/java/org/apache/druid/collections/BlockingPool.java +++ b/processing/src/main/java/org/apache/druid/collections/BlockingPool.java @@ -31,7 +31,6 @@ public interface BlockingPool * * @param elementNum number of resources to take * @param timeoutMs maximum time to wait for resources, in milliseconds. - * * @return a list of resource holders. An empty list is returned if {@code elementNum} resources aren't available. */ List> takeBatch(int elementNum, long timeoutMs); @@ -40,8 +39,14 @@ public interface BlockingPool * Take resources from the pool, waiting if necessary until the elements of the given number become available. * * @param elementNum number of resources to take - * * @return a list of resource holders. An empty list is returned if {@code elementNum} resources aren't available. */ List> takeBatch(int elementNum); + + /** + * Returns the count of the requests waiting to acquire a batch of resources. + * + * @return count of pending requests + */ + long getPendingRequests(); } diff --git a/processing/src/main/java/org/apache/druid/collections/DefaultBlockingPool.java b/processing/src/main/java/org/apache/druid/collections/DefaultBlockingPool.java index 1021974b1b4e..e41a9e5d75d4 100644 --- a/processing/src/main/java/org/apache/druid/collections/DefaultBlockingPool.java +++ b/processing/src/main/java/org/apache/druid/collections/DefaultBlockingPool.java @@ -30,6 +30,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; import java.util.stream.Collectors; @@ -48,6 +49,8 @@ public class DefaultBlockingPool implements BlockingPool private final Condition notEnough; private final int maxSize; + private final AtomicLong pendingRequests; + public DefaultBlockingPool( Supplier generator, int limit @@ -62,6 +65,7 @@ public DefaultBlockingPool( this.lock = new ReentrantLock(); this.notEnough = lock.newCondition(); + this.pendingRequests = new AtomicLong(); } @Override @@ -91,12 +95,16 @@ public List> takeBatch(final int elementNum, Preconditions.checkArgument(timeoutMs >= 0, "timeoutMs must be a non-negative value, but was [%s]", timeoutMs); checkInitialized(); try { + pendingRequests.incrementAndGet(); final List objects = timeoutMs > 0 ? pollObjects(elementNum, timeoutMs) : pollObjects(elementNum); return objects.stream().map(this::wrapObject).collect(Collectors.toList()); } catch (InterruptedException e) { throw new RuntimeException(e); } + finally { + pendingRequests.decrementAndGet(); + } } @Override @@ -104,11 +112,21 @@ public List> takeBatch(final int elementNum) { checkInitialized(); try { + pendingRequests.incrementAndGet(); return takeObjects(elementNum).stream().map(this::wrapObject).collect(Collectors.toList()); } catch (InterruptedException e) { throw new RuntimeException(e); } + finally { + pendingRequests.incrementAndGet(); + } + } + + @Override + public long getPendingRequests() + { + return pendingRequests.get(); } private List pollObjects(int elementNum) throws InterruptedException diff --git a/processing/src/main/java/org/apache/druid/collections/DummyBlockingPool.java b/processing/src/main/java/org/apache/druid/collections/DummyBlockingPool.java index dcd6cea07aa7..2553f9ab425f 100644 --- a/processing/src/main/java/org/apache/druid/collections/DummyBlockingPool.java +++ b/processing/src/main/java/org/apache/druid/collections/DummyBlockingPool.java @@ -55,4 +55,10 @@ public List> takeBatch(int elementNum) { throw new UnsupportedOperationException(); } + + @Override + public long getPendingRequests() + { + return 0; + } } diff --git a/processing/src/test/java/org/apache/druid/query/TestBufferPool.java b/processing/src/test/java/org/apache/druid/query/TestBufferPool.java index 10690d31be13..a650437f83f0 100644 --- a/processing/src/test/java/org/apache/druid/query/TestBufferPool.java +++ b/processing/src/test/java/org/apache/druid/query/TestBufferPool.java @@ -132,6 +132,12 @@ public List> takeBatch(int elementNu } } + @Override + public long getPendingRequests() + { + return 0; + } + public long getOutstandingObjectCount() { return takenFromMap.size(); diff --git a/server/src/main/java/org/apache/druid/server/metrics/QueryCountStatsMonitor.java b/server/src/main/java/org/apache/druid/server/metrics/QueryCountStatsMonitor.java index da2017dbc00a..ce951d5933f7 100644 --- a/server/src/main/java/org/apache/druid/server/metrics/QueryCountStatsMonitor.java +++ b/server/src/main/java/org/apache/druid/server/metrics/QueryCountStatsMonitor.java @@ -21,24 +21,30 @@ import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; +import org.apache.druid.collections.BlockingPool; +import org.apache.druid.guice.annotations.Merging; import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.java.util.emitter.service.ServiceMetricEvent; import org.apache.druid.java.util.metrics.AbstractMonitor; import org.apache.druid.java.util.metrics.KeyedDiff; +import java.nio.ByteBuffer; import java.util.Map; public class QueryCountStatsMonitor extends AbstractMonitor { private final KeyedDiff keyedDiff = new KeyedDiff(); private final QueryCountStatsProvider statsProvider; + private final BlockingPool mergeBufferPool; @Inject public QueryCountStatsMonitor( - QueryCountStatsProvider statsProvider + QueryCountStatsProvider statsProvider, + @Merging BlockingPool mergeBufferPool ) { this.statsProvider = statsProvider; + this.mergeBufferPool = mergeBufferPool; } @Override @@ -65,6 +71,9 @@ public boolean doMonitor(ServiceEmitter emitter) emitter.emit(builder.setMetric(diffEntry.getKey(), diffEntry.getValue())); } } + + long pendingQueries = this.mergeBufferPool.getPendingRequests(); + emitter.emit(builder.setMetric("mergeBuffer/pendingRequests", pendingQueries)); return true; } diff --git a/server/src/test/java/org/apache/druid/server/metrics/QueryCountStatsMonitorTest.java b/server/src/test/java/org/apache/druid/server/metrics/QueryCountStatsMonitorTest.java index 95b9f27d1c26..717c95d62c5b 100644 --- a/server/src/test/java/org/apache/druid/server/metrics/QueryCountStatsMonitorTest.java +++ b/server/src/test/java/org/apache/druid/server/metrics/QueryCountStatsMonitorTest.java @@ -19,17 +19,27 @@ package org.apache.druid.server.metrics; +import org.apache.druid.collections.BlockingPool; +import org.apache.druid.collections.DefaultBlockingPool; import org.apache.druid.java.util.metrics.StubServiceEmitter; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.stream.Collectors; public class QueryCountStatsMonitorTest { private QueryCountStatsProvider queryCountStatsProvider; + private BlockingPool mergeBufferPool; + private ExecutorService executorService; @Before public void setUp() @@ -69,14 +79,24 @@ public long getTimedOutQueryCount() return timedOutEmitCount; } }; + + mergeBufferPool = new DefaultBlockingPool(() -> ByteBuffer.allocate(1024), 1); + executorService = Executors.newSingleThreadExecutor(); + } + + @After + public void tearDown() + { + executorService.shutdown(); } @Test public void testMonitor() { - final QueryCountStatsMonitor monitor = new QueryCountStatsMonitor(queryCountStatsProvider); + final QueryCountStatsMonitor monitor = new QueryCountStatsMonitor(queryCountStatsProvider, mergeBufferPool); final StubServiceEmitter emitter = new StubServiceEmitter("service", "host"); monitor.doMonitor(emitter); + emitter.flush(); // Trigger metric emission monitor.doMonitor(emitter); Map resultMap = emitter.getEvents() @@ -85,12 +105,45 @@ public void testMonitor() event -> (String) event.toMap().get("metric"), event -> (Long) event.toMap().get("value") )); - Assert.assertEquals(5, resultMap.size()); + Assert.assertEquals(6, resultMap.size()); Assert.assertEquals(1L, (long) resultMap.get("query/success/count")); Assert.assertEquals(2L, (long) resultMap.get("query/failed/count")); Assert.assertEquals(3L, (long) resultMap.get("query/interrupted/count")); Assert.assertEquals(4L, (long) resultMap.get("query/timeout/count")); Assert.assertEquals(10L, (long) resultMap.get("query/count")); + Assert.assertEquals(0, (long) resultMap.get("mergeBuffer/pendingRequests")); + + } + + @Test(timeout = 2_000L) + public void testMonitoringMergeBuffer() + { + executorService.submit(() -> { + mergeBufferPool.takeBatch(10); + }); + + int count = 0; + try { + // wait at most 10 secs for the executor thread to block + while (mergeBufferPool.getPendingRequests() == 0) { + Thread.sleep(100); + count++; + if (count >= 20) { + break; + } + } + + final QueryCountStatsMonitor monitor = new QueryCountStatsMonitor(queryCountStatsProvider, mergeBufferPool); + final StubServiceEmitter emitter = new StubServiceEmitter("DummyService", "DummyHost"); + boolean ret = monitor.doMonitor(emitter); + Assert.assertTrue(ret); + List numbers = emitter.getMetricValues("mergeBuffer/pendingRequests", Collections.emptyMap()); + Assert.assertEquals(1, numbers.size()); + Assert.assertEquals(1, numbers.get(0).intValue()); + } + catch (InterruptedException e) { + // do nothing + } } } From 7a35ce886d8e0c80653adf01779a97268b9f939d Mon Sep 17 00:00:00 2001 From: Adarsh Sanjeev Date: Mon, 9 Oct 2023 15:14:03 +0530 Subject: [PATCH 08/14] Add ability for MSQ tasks to query realtime tasks (#15024) This PR aims to add the capabilities to: 1. Fetch the realtime segment metadata from the coordinator server view, 2. Adds the ability for workers to query indexers, similar to how brokers do the same for native queries. --- docs/multi-stage-query/known-issues.md | 2 - docs/multi-stage-query/reference.md | 1 + .../apache/druid/msq/exec/ControllerImpl.java | 69 +++- .../msq/exec/LoadedSegmentDataProvider.java | 268 +++++++++++++++ .../LoadedSegmentDataProviderFactory.java | 94 ++++++ .../apache/druid/msq/exec/SegmentSource.java | 64 ++++ .../apache/druid/msq/exec/WorkerContext.java | 1 + .../org/apache/druid/msq/exec/WorkerImpl.java | 3 +- .../msq/indexing/IndexerFrameContext.java | 11 + .../msq/indexing/IndexerWorkerContext.java | 27 +- .../external/ExternalInputSliceReader.java | 4 +- .../input/inline/InlineInputSliceReader.java | 6 +- .../input/lookup/LookupInputSliceReader.java | 4 +- .../input/table/DataSegmentWithLocation.java | 98 ++++++ .../input/table/RichSegmentDescriptor.java | 55 +++- .../input/table/SegmentWithDescriptor.java | 46 ++- .../input/table/SegmentsInputSliceReader.java | 9 +- .../msq/input/table/TableInputSpecSlicer.java | 3 +- .../apache/druid/msq/kernel/FrameContext.java | 2 + .../msq/querykit/BaseLeafFrameProcessor.java | 8 +- .../msq/querykit/DataSegmentProvider.java | 2 +- .../GroupByPreShuffleFrameProcessor.java | 27 ++ .../scan/ScanQueryFrameProcessor.java | 69 ++++ .../msq/util/MultiStageQueryContext.java | 13 + .../exec/LoadedSegmentDataProviderTest.java | 247 ++++++++++++++ .../druid/msq/exec/MSQLoadedSegmentTests.java | 308 ++++++++++++++++++ .../indexing/IndexerWorkerContextTest.java | 1 + .../table/RichSegmentDescriptorTest.java | 22 +- .../table/SegmentWithDescriptorTest.java | 7 +- .../input/table/SegmentsInputSliceTest.java | 16 +- .../input/table/TableInputSpecSlicerTest.java | 63 ++-- .../msq/test/CalciteArraysQueryMSQTest.java | 4 +- .../druid/msq/test/CalciteMSQTestsHelper.java | 26 ++ .../test/CalciteSelectJoinQueryMSQTest.java | 4 +- .../msq/test/CalciteSelectQueryMSQTest.java | 3 +- .../apache/druid/msq/test/MSQTestBase.java | 23 +- .../msq/test/MSQTestControllerContext.java | 17 +- .../test/MSQTestOverlordServiceClient.java | 18 +- .../druid/msq/test/MSQTestWorkerContext.java | 9 + .../druid/query/IterableRowsCursorHelper.java | 33 ++ .../druid/query/groupby/GroupingEngine.java | 83 ++--- .../query/IterableRowsCursorHelperTest.java | 8 + .../client/coordinator/CoordinatorClient.java | 6 + .../coordinator/CoordinatorClientImpl.java | 33 ++ .../druid/discovery/DataServerClient.java | 175 ++++++++++ .../discovery/DataServerResponseHandler.java | 115 +++++++ .../druid/rpc/FixedSetServiceLocator.java | 85 +++++ .../org/apache/druid/rpc/ServiceLocation.java | 44 +++ .../coordination/DruidServerMetadata.java | 5 +- .../CoordinatorClientImplTest.java | 78 ++++- .../coordinator/NoopCoordinatorClient.java | 7 + .../druid/discovery/DataServerClientTest.java | 119 +++++++ .../druid/rpc/FixedSetServiceLocatorTest.java | 63 ++++ .../apache/druid/rpc/ServiceLocationTest.java | 41 +++ 54 files changed, 2437 insertions(+), 112 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/LoadedSegmentDataProvider.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/LoadedSegmentDataProviderFactory.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentSource.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/DataSegmentWithLocation.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/LoadedSegmentDataProviderTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQLoadedSegmentTests.java create mode 100644 server/src/main/java/org/apache/druid/discovery/DataServerClient.java create mode 100644 server/src/main/java/org/apache/druid/discovery/DataServerResponseHandler.java create mode 100644 server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java create mode 100644 server/src/test/java/org/apache/druid/discovery/DataServerClientTest.java create mode 100644 server/src/test/java/org/apache/druid/rpc/FixedSetServiceLocatorTest.java diff --git a/docs/multi-stage-query/known-issues.md b/docs/multi-stage-query/known-issues.md index bccb9779a835..62a31ecf41af 100644 --- a/docs/multi-stage-query/known-issues.md +++ b/docs/multi-stage-query/known-issues.md @@ -39,8 +39,6 @@ an [UnknownError](./reference.md#error_UnknownError) with a message including "N ## `SELECT` Statement -- `SELECT` from a Druid datasource does not include unpublished real-time data. - - `GROUPING SETS` and `UNION ALL` are not implemented. Queries using these features return a [QueryNotSupported](reference.md#error_QueryNotSupported) error. diff --git a/docs/multi-stage-query/reference.md b/docs/multi-stage-query/reference.md index 010bbff2a270..5e80e318b8c8 100644 --- a/docs/multi-stage-query/reference.md +++ b/docs/multi-stage-query/reference.md @@ -247,6 +247,7 @@ The following table lists the context parameters for the MSQ task engine: | `faultTolerance` | SELECT, INSERT, REPLACE

Whether to turn on fault tolerance mode or not. Failed workers are retried based on [Limits](#limits). Cannot be used when `durableShuffleStorage` is explicitly set to false. | `false` | | `selectDestination` | SELECT

Controls where the final result of the select query is written.
Use `taskReport`(the default) to write select results to the task report. This is not scalable since task reports size explodes for large results
Use `durableStorage` to write results to durable storage location. For large results sets, its recommended to use `durableStorage` . To configure durable storage see [`this`](#durable-storage) section. | `taskReport` | | `waitTillSegmentsLoad` | INSERT, REPLACE

If set, the ingest query waits for the generated segment to be loaded before exiting, else the ingest query exits without waiting. The task and live reports contain the information about the status of loading segments if this flag is set. This will ensure that any future queries made after the ingestion exits will include results from the ingestion. The drawback is that the controller task will stall till the segments are loaded. | `false` | +| `includeSegmentSource` | SELECT, INSERT, REPLACE

Controls the sources, which will be queried for results in addition to the segments present on deep storage. Can be `NONE` or `REALTIME`. If this value is `NONE`, only non-realtime (published and used) segments will be downloaded from deep storage. If this value is `REALTIME`, results will also be included from realtime tasks. | `NONE` | ## Joins diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index c7b10f245c1d..58768644bf69 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -24,6 +24,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.util.concurrent.FutureCallback; @@ -39,6 +40,7 @@ import it.unimi.dsi.fastutil.ints.IntList; import it.unimi.dsi.fastutil.ints.IntSet; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.data.input.StringTuple; import org.apache.druid.data.input.impl.DimensionSchema; @@ -140,6 +142,7 @@ import org.apache.druid.msq.input.stage.StageInputSlice; import org.apache.druid.msq.input.stage.StageInputSpec; import org.apache.druid.msq.input.stage.StageInputSpecSlicer; +import org.apache.druid.msq.input.table.DataSegmentWithLocation; import org.apache.druid.msq.input.table.TableInputSpec; import org.apache.druid.msq.input.table.TableInputSpecSlicer; import org.apache.druid.msq.kernel.GlobalSortTargetSizeShuffleSpec; @@ -187,6 +190,7 @@ import org.apache.druid.segment.realtime.appenderator.SegmentIdWithShardSpec; import org.apache.druid.segment.transform.TransformSpec; import org.apache.druid.server.DruidNode; +import org.apache.druid.server.coordination.DruidServerMetadata; import org.apache.druid.sql.calcite.planner.ColumnMapping; import org.apache.druid.sql.calcite.planner.ColumnMappings; import org.apache.druid.sql.calcite.rel.DruidQuery; @@ -1163,14 +1167,73 @@ private QueryKit makeQueryControllerToolKit() private DataSegmentTimelineView makeDataSegmentTimelineView() { + final SegmentSource includeSegmentSource = MultiStageQueryContext.getSegmentSources( + task.getQuerySpec() + .getQuery() + .context() + ); + + final boolean includeRealtime = SegmentSource.shouldQueryRealtimeServers(includeSegmentSource); + return (dataSource, intervals) -> { - final Collection dataSegments = + final Iterable realtimeAndHistoricalSegments; + + // Fetch the realtime segments and segments loaded on the historical. Do this first so that we don't miss any + // segment if they get handed off between the two calls. Segments loaded on historicals are deduplicated below, + // since we are only interested in realtime segments for now. + if (includeRealtime) { + realtimeAndHistoricalSegments = context.coordinatorClient().fetchServerViewSegments(dataSource, intervals); + } else { + realtimeAndHistoricalSegments = ImmutableList.of(); + } + + // Fetch all published, used segments (all non-realtime segments) from the metadata store. + final Collection publishedUsedSegments = FutureUtils.getUnchecked(context.coordinatorClient().fetchUsedSegments(dataSource, intervals), true); - if (dataSegments.isEmpty()) { + int realtimeCount = 0; + + // Deduplicate segments, giving preference to published used segments. + // We do this so that if any segments have been handed off in between the two metadata calls above, + // we directly fetch it from deep storage. + Set unifiedSegmentView = new HashSet<>(publishedUsedSegments); + + // Iterate over the realtime segments and segments loaded on the historical + for (ImmutableSegmentLoadInfo segmentLoadInfo : realtimeAndHistoricalSegments) { + ImmutableSet servers = segmentLoadInfo.getServers(); + // Filter out only realtime servers. We don't want to query historicals for now, but we can in the future. + // This check can be modified then. + Set realtimeServerMetadata + = servers.stream() + .filter(druidServerMetadata -> includeSegmentSource.getUsedServerTypes() + .contains(druidServerMetadata.getType()) + ) + .collect(Collectors.toSet()); + if (!realtimeServerMetadata.isEmpty()) { + realtimeCount += 1; + DataSegmentWithLocation dataSegmentWithLocation = new DataSegmentWithLocation( + segmentLoadInfo.getSegment(), + realtimeServerMetadata + ); + unifiedSegmentView.add(dataSegmentWithLocation); + } else { + // We don't have any segments of the required segment source, ignore the segment + } + } + + if (includeRealtime) { + log.info( + "Fetched total [%d] segments from coordinator: [%d] from metadata stoure, [%d] from server view", + unifiedSegmentView.size(), + publishedUsedSegments.size(), + realtimeCount + ); + } + + if (unifiedSegmentView.isEmpty()) { return Optional.empty(); } else { - return Optional.of(SegmentTimeline.forSegments(dataSegments)); + return Optional.of(SegmentTimeline.forSegments(unifiedSegmentView)); } }; } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/LoadedSegmentDataProvider.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/LoadedSegmentDataProvider.java new file mode 100644 index 000000000000..d9d789e3d2ba --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/LoadedSegmentDataProvider.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; +import org.apache.druid.client.coordinator.CoordinatorClient; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.discovery.DataServerClient; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.IOE; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.RetryUtils; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.guava.Yielder; +import org.apache.druid.java.util.common.guava.Yielders; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.counters.ChannelCounters; +import org.apache.druid.msq.input.table.RichSegmentDescriptor; +import org.apache.druid.query.Queries; +import org.apache.druid.query.Query; +import org.apache.druid.query.QueryInterruptedException; +import org.apache.druid.query.QueryToolChest; +import org.apache.druid.query.QueryToolChestWarehouse; +import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.aggregation.MetricManipulationFn; +import org.apache.druid.query.aggregation.MetricManipulatorFns; +import org.apache.druid.query.context.DefaultResponseContext; +import org.apache.druid.query.context.ResponseContext; +import org.apache.druid.rpc.FixedSetServiceLocator; +import org.apache.druid.rpc.RpcException; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.rpc.ServiceLocation; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.utils.CollectionUtils; + +import java.io.IOException; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Function; + +/** + * Class responsible for querying dataservers and retriving results for a given query. Also queries the coordinator + * to check if a segment has been handed off. + */ +public class LoadedSegmentDataProvider +{ + private static final Logger log = new Logger(LoadedSegmentDataProvider.class); + private static final int DEFAULT_NUM_TRIES = 5; + private final String dataSource; + private final ChannelCounters channelCounters; + private final ServiceClientFactory serviceClientFactory; + private final CoordinatorClient coordinatorClient; + private final ObjectMapper objectMapper; + private final QueryToolChestWarehouse warehouse; + private final ScheduledExecutorService queryCancellationExecutor; + + public LoadedSegmentDataProvider( + String dataSource, + ChannelCounters channelCounters, + ServiceClientFactory serviceClientFactory, + CoordinatorClient coordinatorClient, + ObjectMapper objectMapper, + QueryToolChestWarehouse warehouse, + ScheduledExecutorService queryCancellationExecutor + ) + { + this.dataSource = dataSource; + this.channelCounters = channelCounters; + this.serviceClientFactory = serviceClientFactory; + this.coordinatorClient = coordinatorClient; + this.objectMapper = objectMapper; + this.warehouse = warehouse; + this.queryCancellationExecutor = queryCancellationExecutor; + } + + @VisibleForTesting + DataServerClient makeDataServerClient(ServiceLocation serviceLocation) + { + return new DataServerClient(serviceClientFactory, serviceLocation, objectMapper, queryCancellationExecutor); + } + + /** + * Performs some necessary transforms to the query, so that the dataserver is able to understand it first. + * - Changing the datasource to a {@link TableDataSource} + * - Limiting the query to a single required segment with {@link Queries#withSpecificSegments(Query, List)} + *
+ * Then queries a data server and returns a {@link Yielder} for the results, retrying if needed. If a dataserver + * indicates that the segment was not found, checks with the coordinator to see if the segment was handed off. + * - If the segment was handed off, returns with a {@link DataServerQueryStatus#HANDOFF} status. + * - If the segment was not handed off, retries with the known list of servers and throws an exception if the retry + * count is exceeded. + * - If the servers could not be found, checks if the segment was handed-off. If it was, returns with a + * {@link DataServerQueryStatus#HANDOFF} status. Otherwise, throws an exception. + *
+ * Also applies {@link QueryToolChest#makePreComputeManipulatorFn(Query, MetricManipulationFn)} and reports channel + * metrics on the returned results. + * + * @param result return type for the query from the data server + * @param type of the result rows after parsing from QueryType object + */ + public Pair> fetchRowsFromDataServer( + Query query, + RichSegmentDescriptor segmentDescriptor, + Function, Sequence> mappingFunction, + Closer closer + ) throws IOException + { + final Query preparedQuery = Queries.withSpecificSegments( + query.withDataSource(new TableDataSource(dataSource)), + ImmutableList.of(segmentDescriptor) + ); + + final Set servers = segmentDescriptor.getServers(); + final FixedSetServiceLocator fixedSetServiceLocator = FixedSetServiceLocator.forDruidServerMetadata(servers); + final QueryToolChest> toolChest = warehouse.getToolChest(query); + final Function preComputeManipulatorFn = + toolChest.makePreComputeManipulatorFn(query, MetricManipulatorFns.deserializing()); + + final JavaType queryResultType = toolChest.getBaseResultType(); + final int numRetriesOnMissingSegments = preparedQuery.context().getNumRetriesOnMissingSegments(DEFAULT_NUM_TRIES); + + log.debug("Querying severs[%s] for segment[%s], retries:[%d]", servers, segmentDescriptor, numRetriesOnMissingSegments); + final ResponseContext responseContext = new DefaultResponseContext(); + + Pair> statusSequencePair; + try { + // We need to check for handoff to decide if we need to retry. Therefore, we handle it here instead of inside + // the client. + statusSequencePair = RetryUtils.retry( + () -> { + ServiceLocation serviceLocation = CollectionUtils.getOnlyElement( + fixedSetServiceLocator.locate().get().getLocations(), + serviceLocations -> { + throw DruidException.defensive("Should only have one location"); + } + ); + DataServerClient dataServerClient = makeDataServerClient(serviceLocation); + Sequence sequence = dataServerClient.run(preparedQuery, responseContext, queryResultType, closer) + .map(preComputeManipulatorFn); + final List missingSegments = getMissingSegments(responseContext); + // Only one segment is fetched, so this should be empty if it was fetched + if (missingSegments.isEmpty()) { + log.debug("Successfully fetched rows from server for segment[%s]", segmentDescriptor); + // Segment was found + Yielder yielder = closer.register( + Yielders.each(mappingFunction.apply(sequence) + .map(row -> { + channelCounters.incrementRowCount(); + return row; + })) + ); + return Pair.of(DataServerQueryStatus.SUCCESS, yielder); + } else { + Boolean wasHandedOff = checkSegmentHandoff(coordinatorClient, dataSource, segmentDescriptor); + if (Boolean.TRUE.equals(wasHandedOff)) { + log.debug("Segment[%s] was handed off.", segmentDescriptor); + return Pair.of(DataServerQueryStatus.HANDOFF, null); + } else { + log.error("Segment[%s] could not be found on data server, but segment was not handed off.", segmentDescriptor); + throw new IOE( + "Segment[%s] could not be found on data server, but segment was not handed off.", + segmentDescriptor + ); + } + } + }, + throwable -> !(throwable instanceof QueryInterruptedException && throwable.getCause() instanceof InterruptedException), + numRetriesOnMissingSegments + ); + + return statusSequencePair; + } + catch (QueryInterruptedException e) { + if (e.getCause() instanceof RpcException) { + // In the case that all the realtime servers for a segment are gone (for example, if they were scaled down), + // we would also be unable to fetch the segment. Check if the segment was handed off, just in case, instead of + // failing the query. + boolean wasHandedOff = checkSegmentHandoff(coordinatorClient, dataSource, segmentDescriptor); + if (wasHandedOff) { + log.debug("Segment[%s] was handed off.", segmentDescriptor); + return Pair.of(DataServerQueryStatus.HANDOFF, null); + } + } + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build(e, "Exception while fetching rows for query from dataservers[%s]", servers); + } + catch (Exception e) { + Throwables.propagateIfPossible(e, IOE.class); + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build(e, "Exception while fetching rows for query from dataservers[%s]", servers); + } + } + + /** + * Retreives the list of missing segments from the response context. + */ + private static List getMissingSegments(final ResponseContext responseContext) + { + List missingSegments = responseContext.getMissingSegments(); + if (missingSegments == null) { + return ImmutableList.of(); + } + return missingSegments; + } + + /** + * Queries the coordinator to check if a segment has been handed off. + *
+ * See {@link org.apache.druid.server.http.DataSourcesResource#isHandOffComplete(String, String, int, String)} + */ + private static boolean checkSegmentHandoff( + CoordinatorClient coordinatorClient, + String dataSource, + SegmentDescriptor segmentDescriptor + ) throws IOE + { + Boolean wasHandedOff; + try { + wasHandedOff = FutureUtils.get(coordinatorClient.isHandoffComplete(dataSource, segmentDescriptor), true); + } + catch (Exception e) { + throw new IOE(e, "Could not contact coordinator for segment[%s]", segmentDescriptor); + } + return Boolean.TRUE.equals(wasHandedOff); + } + + /** + * Represents the status of fetching a segment from a data server + */ + public enum DataServerQueryStatus + { + /** + * Segment was found on the data server and fetched successfully. + */ + SUCCESS, + /** + * Segment was not found on the realtime server as it has been handed off to a historical. Only returned while + * querying a realtime server. + */ + HANDOFF + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/LoadedSegmentDataProviderFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/LoadedSegmentDataProviderFactory.java new file mode 100644 index 000000000000..48ed57be8701 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/LoadedSegmentDataProviderFactory.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.client.coordinator.CoordinatorClient; +import org.apache.druid.java.util.common.RE; +import org.apache.druid.java.util.common.concurrent.ScheduledExecutors; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.counters.ChannelCounters; +import org.apache.druid.query.QueryToolChestWarehouse; +import org.apache.druid.rpc.ServiceClientFactory; + +import java.io.Closeable; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * Creates new instances of {@link LoadedSegmentDataProvider} and manages the cancellation threadpool. + */ +public class LoadedSegmentDataProviderFactory implements Closeable +{ + private static final Logger log = new Logger(LoadedSegmentDataProviderFactory.class); + private static final int DEFAULT_THREAD_COUNT = 4; + private final CoordinatorClient coordinatorClient; + private final ServiceClientFactory serviceClientFactory; + private final ObjectMapper objectMapper; + private final QueryToolChestWarehouse warehouse; + private final ScheduledExecutorService queryCancellationExecutor; + + public LoadedSegmentDataProviderFactory( + CoordinatorClient coordinatorClient, + ServiceClientFactory serviceClientFactory, + ObjectMapper objectMapper, + QueryToolChestWarehouse warehouse + ) + { + this.coordinatorClient = coordinatorClient; + this.serviceClientFactory = serviceClientFactory; + this.objectMapper = objectMapper; + this.warehouse = warehouse; + this.queryCancellationExecutor = ScheduledExecutors.fixed(DEFAULT_THREAD_COUNT, "query-cancellation-executor"); + } + + public LoadedSegmentDataProvider createLoadedSegmentDataProvider( + String dataSource, + ChannelCounters channelCounters + ) + { + return new LoadedSegmentDataProvider( + dataSource, + channelCounters, + serviceClientFactory, + coordinatorClient, + objectMapper, + warehouse, + queryCancellationExecutor + ); + } + + @Override + public void close() + { + // Wait for all query cancellations to be complete. + log.info("Waiting for any data server queries to be canceled."); + queryCancellationExecutor.shutdown(); + try { + if (!queryCancellationExecutor.awaitTermination(1, TimeUnit.MINUTES)) { + log.error("Unable to cancel all ongoing queries."); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RE(e); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentSource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentSource.java new file mode 100644 index 000000000000..22f3a5df973c --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentSource.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.google.common.collect.ImmutableSet; +import org.apache.druid.server.coordination.ServerType; + +import java.util.Set; + +/** + * Decides the types of data servers contacted by MSQ tasks to fetch results. + */ +public enum SegmentSource +{ + /** + * Include only segments from deep storage. + */ + NONE(ImmutableSet.of()), + /** + * Include segments from realtime tasks as well as segments from deep storage. + */ + REALTIME(ImmutableSet.of(ServerType.REALTIME, ServerType.INDEXER_EXECUTOR)); + + /** + * The type of dataservers (if any) to include. This does not include segments queried from deep storage, which are + * always included in queries. + */ + private final Set usedServerTypes; + + SegmentSource(Set usedServerTypes) + { + this.usedServerTypes = usedServerTypes; + } + + public Set getUsedServerTypes() + { + return usedServerTypes; + } + + /** + * Whether realtime servers should be included for the segmentSource. + */ + public static boolean shouldQueryRealtimeServers(SegmentSource segmentSource) + { + return REALTIME.equals(segmentSource); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java index d017feb099fb..a3d4fde6c1a5 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java @@ -74,6 +74,7 @@ public interface WorkerContext DruidNode selfNode(); Bouncer processorBouncer(); + LoadedSegmentDataProviderFactory loadedSegmentDataProviderFactory(); default File tempDir(int stageNumber, String id) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java index d522c3a7f169..6ee45bc158e8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java @@ -295,6 +295,7 @@ public Optional runTask(final Closer closer) throws Exception { this.controllerClient = context.makeControllerClient(task.getControllerTaskId()); closer.register(controllerClient::close); + closer.register(context.loadedSegmentDataProviderFactory()); context.registerWorker(this, closer); // Uses controllerClient, so must be called after that is initialized this.workerClient = new ExceptionWrappingWorkerClient(context.makeWorkerClient()); @@ -1103,7 +1104,7 @@ private void makeInputSliceReader() .put( SegmentsInputSlice.class, new SegmentsInputSliceReader( - frameContext.dataSegmentProvider(), + frameContext, MultiStageQueryContext.isReindex(QueryContext.of(task().getContext())) ) ) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java index c35832992f93..d522a8a7d88f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.indexing; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.msq.exec.LoadedSegmentDataProviderFactory; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.kernel.FrameContext; import org.apache.druid.msq.querykit.DataSegmentProvider; @@ -38,17 +39,20 @@ public class IndexerFrameContext implements FrameContext private final IndexIO indexIO; private final DataSegmentProvider dataSegmentProvider; private final WorkerMemoryParameters memoryParameters; + private final LoadedSegmentDataProviderFactory loadedSegmentDataProviderFactory; public IndexerFrameContext( IndexerWorkerContext context, IndexIO indexIO, DataSegmentProvider dataSegmentProvider, + LoadedSegmentDataProviderFactory loadedSegmentDataProviderFactory, WorkerMemoryParameters memoryParameters ) { this.context = context; this.indexIO = indexIO; this.dataSegmentProvider = dataSegmentProvider; + this.loadedSegmentDataProviderFactory = loadedSegmentDataProviderFactory; this.memoryParameters = memoryParameters; } @@ -76,6 +80,13 @@ public DataSegmentProvider dataSegmentProvider() return dataSegmentProvider; } + @Override + public LoadedSegmentDataProviderFactory loadedSegmentDataProviderFactory() + { + return loadedSegmentDataProviderFactory; + } + + @Override public File tempDir() { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java index 43d067dd6c90..709b019891f0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java @@ -27,12 +27,14 @@ import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.guice.annotations.EscalatedGlobal; import org.apache.druid.guice.annotations.Self; +import org.apache.druid.guice.annotations.Smile; import org.apache.druid.indexing.common.SegmentCacheManagerFactory; import org.apache.druid.indexing.common.TaskToolbox; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.exec.LoadedSegmentDataProviderFactory; import org.apache.druid.msq.exec.TaskDataSegmentProvider; import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.exec.WorkerClient; @@ -43,6 +45,7 @@ import org.apache.druid.msq.indexing.client.WorkerChatHandler; import org.apache.druid.msq.kernel.FrameContext; import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.query.QueryToolChestWarehouse; import org.apache.druid.rpc.ServiceClientFactory; import org.apache.druid.rpc.ServiceLocations; import org.apache.druid.rpc.ServiceLocator; @@ -68,6 +71,7 @@ public class IndexerWorkerContext implements WorkerContext private final Injector injector; private final IndexIO indexIO; private final TaskDataSegmentProvider dataSegmentProvider; + private final LoadedSegmentDataProviderFactory loadedSegmentDataProviderFactory; private final ServiceClientFactory clientFactory; @GuardedBy("this") @@ -81,6 +85,7 @@ public IndexerWorkerContext( final Injector injector, final IndexIO indexIO, final TaskDataSegmentProvider dataSegmentProvider, + final LoadedSegmentDataProviderFactory loadedSegmentDataProviderFactory, final ServiceClientFactory clientFactory ) { @@ -88,6 +93,7 @@ public IndexerWorkerContext( this.injector = injector; this.indexIO = indexIO; this.dataSegmentProvider = dataSegmentProvider; + this.loadedSegmentDataProviderFactory = loadedSegmentDataProviderFactory; this.clientFactory = clientFactory; } @@ -99,12 +105,24 @@ public static IndexerWorkerContext createProductionInstance(final TaskToolbox to .manufacturate(new File(toolbox.getIndexingTmpDir(), "segment-fetch")); final ServiceClientFactory serviceClientFactory = injector.getInstance(Key.get(ServiceClientFactory.class, EscalatedGlobal.class)); + final ObjectMapper smileMapper = injector.getInstance(Key.get(ObjectMapper.class, Smile.class)); + final QueryToolChestWarehouse warehouse = injector.getInstance(QueryToolChestWarehouse.class); return new IndexerWorkerContext( toolbox, injector, indexIO, - new TaskDataSegmentProvider(toolbox.getCoordinatorClient(), segmentCacheManager, indexIO), + new TaskDataSegmentProvider( + toolbox.getCoordinatorClient(), + segmentCacheManager, + indexIO + ), + new LoadedSegmentDataProviderFactory( + toolbox.getCoordinatorClient(), + serviceClientFactory, + smileMapper, + warehouse + ), serviceClientFactory ); } @@ -227,6 +245,7 @@ public FrameContext frameContext(QueryDefinition queryDef, int stageNumber) this, indexIO, dataSegmentProvider, + loadedSegmentDataProviderFactory, WorkerMemoryParameters.createProductionInstanceForWorker(injector, queryDef, stageNumber) ); } @@ -249,6 +268,12 @@ public Bouncer processorBouncer() return injector.getInstance(Bouncer.class); } + @Override + public LoadedSegmentDataProviderFactory loadedSegmentDataProviderFactory() + { + return loadedSegmentDataProviderFactory; + } + private synchronized OverlordClient makeOverlordClient() { if (overlordClient == null) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java index 3dbd3da0a026..084d58e217d6 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java @@ -42,6 +42,7 @@ import org.apache.druid.msq.input.NilInputSource; import org.apache.druid.msq.input.ReadableInput; import org.apache.druid.msq.input.ReadableInputs; +import org.apache.druid.msq.input.table.RichSegmentDescriptor; import org.apache.druid.msq.input.table.SegmentWithDescriptor; import org.apache.druid.msq.util.DimensionSchemaUtils; import org.apache.druid.segment.RowBasedSegment; @@ -159,7 +160,8 @@ private static Iterator inputSourceSegmentIterator( ); return new SegmentWithDescriptor( () -> ResourceHolder.fromCloseable(segment), - segmentId.toDescriptor() + null, + new RichSegmentDescriptor(segmentId.toDescriptor(), null, null) ); } ); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/inline/InlineInputSliceReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/inline/InlineInputSliceReader.java index 143fb49692f2..25f06c7cd40c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/inline/InlineInputSliceReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/inline/InlineInputSliceReader.java @@ -27,9 +27,9 @@ import org.apache.druid.msq.input.InputSliceReader; import org.apache.druid.msq.input.ReadableInput; import org.apache.druid.msq.input.ReadableInputs; +import org.apache.druid.msq.input.table.RichSegmentDescriptor; import org.apache.druid.msq.input.table.SegmentWithDescriptor; import org.apache.druid.query.InlineDataSource; -import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.segment.InlineSegmentWrangler; import org.apache.druid.segment.SegmentWrangler; import org.apache.druid.timeline.SegmentId; @@ -43,7 +43,8 @@ public class InlineInputSliceReader implements InputSliceReader { public static final String SEGMENT_ID = "__inline"; - private static final SegmentDescriptor DUMMY_SEGMENT_DESCRIPTOR = SegmentId.dummy(SEGMENT_ID).toDescriptor(); + private static final RichSegmentDescriptor DUMMY_SEGMENT_DESCRIPTOR + = new RichSegmentDescriptor(SegmentId.dummy(SEGMENT_ID).toDescriptor(), null, null); private final SegmentWrangler segmentWrangler; @@ -74,6 +75,7 @@ public ReadableInputs attach( segment -> ReadableInput.segment( new SegmentWithDescriptor( () -> ResourceHolder.fromCloseable(segment), + null, DUMMY_SEGMENT_DESCRIPTOR ) ) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/lookup/LookupInputSliceReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/lookup/LookupInputSliceReader.java index 648527ce0061..b601b043ac13 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/lookup/LookupInputSliceReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/lookup/LookupInputSliceReader.java @@ -29,6 +29,7 @@ import org.apache.druid.msq.input.InputSliceReader; import org.apache.druid.msq.input.ReadableInput; import org.apache.druid.msq.input.ReadableInputs; +import org.apache.druid.msq.input.table.RichSegmentDescriptor; import org.apache.druid.msq.input.table.SegmentWithDescriptor; import org.apache.druid.query.LookupDataSource; import org.apache.druid.segment.Segment; @@ -99,7 +100,8 @@ public ReadableInputs attach( return ResourceHolder.fromCloseable(segment); }, - SegmentId.dummy(lookupName).toDescriptor() + null, + new RichSegmentDescriptor(SegmentId.dummy(lookupName).toDescriptor(), null, null) ) ) ) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/DataSegmentWithLocation.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/DataSegmentWithLocation.java new file mode 100644 index 000000000000..0e83e9c3edee --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/DataSegmentWithLocation.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.input.table; + +import com.fasterxml.jackson.annotation.JacksonInject; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.google.common.base.Preconditions; +import org.apache.druid.jackson.CommaListJoinDeserializer; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.timeline.CompactionState; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.partition.ShardSpec; +import org.joda.time.Interval; + +import javax.annotation.Nullable; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Data segment including the locations which contain the segment. Used if MSQ needs to fetch the segment from a server + * instead of from deep storage. + */ +public class DataSegmentWithLocation extends DataSegment +{ + private final Set servers; + + @JsonCreator + public DataSegmentWithLocation( + @JsonProperty("dataSource") String dataSource, + @JsonProperty("interval") Interval interval, + @JsonProperty("version") String version, + // use `Map` *NOT* `LoadSpec` because we want to do lazy materialization to prevent dependency pollution + @JsonProperty("loadSpec") @Nullable Map loadSpec, + @JsonProperty("dimensions") + @JsonDeserialize(using = CommaListJoinDeserializer.class) + @Nullable + List dimensions, + @JsonProperty("metrics") + @JsonDeserialize(using = CommaListJoinDeserializer.class) + @Nullable + List metrics, + @JsonProperty("shardSpec") @Nullable ShardSpec shardSpec, + @JsonProperty("lastCompactionState") @Nullable CompactionState lastCompactionState, + @JsonProperty("binaryVersion") Integer binaryVersion, + @JsonProperty("size") long size, + @JsonProperty("servers") Set servers, + @JacksonInject PruneSpecsHolder pruneSpecsHolder + ) + { + super(dataSource, interval, version, loadSpec, dimensions, metrics, shardSpec, lastCompactionState, binaryVersion, size, pruneSpecsHolder); + this.servers = Preconditions.checkNotNull(servers, "servers"); + } + + public DataSegmentWithLocation( + DataSegment dataSegment, + Set servers + ) + { + super( + dataSegment.getDataSource(), + dataSegment.getInterval(), + dataSegment.getVersion(), + dataSegment.getLoadSpec(), + dataSegment.getDimensions(), + dataSegment.getMetrics(), + dataSegment.getShardSpec(), + dataSegment.getBinaryVersion(), + dataSegment.getSize() + ); + this.servers = servers; + } + + @JsonProperty("servers") + public Set getServers() + { + return servers; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/RichSegmentDescriptor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/RichSegmentDescriptor.java index 3ca48ef9cbdf..04e4e601b073 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/RichSegmentDescriptor.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/RichSegmentDescriptor.java @@ -23,36 +23,54 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.utils.CollectionUtils; import org.joda.time.Interval; import javax.annotation.Nullable; import java.util.Objects; +import java.util.Set; /** - * Like {@link SegmentDescriptor}, but provides both the full interval and the clipped interval for a segment. - * (SegmentDescriptor only provides the clipped interval.) - * + * Like {@link SegmentDescriptor}, but provides both the full interval and the clipped interval for a segment + * (SegmentDescriptor only provides the clipped interval.), as well as the metadata of the servers it is loaded on. + *
* To keep the serialized form lightweight, the full interval is only serialized if it is different from the * clipped interval. - * + *
* It is possible to deserialize this class as {@link SegmentDescriptor}. However, going the other direction is - * not a good idea, because the {@link #fullInterval} will not end up being set correctly. + * not a good idea, because the {@link #fullInterval} and {@link #servers} will not end up being set correctly. */ public class RichSegmentDescriptor extends SegmentDescriptor { @Nullable private final Interval fullInterval; + private final Set servers; public RichSegmentDescriptor( final Interval fullInterval, final Interval interval, final String version, - final int partitionNumber + final int partitionNumber, + final Set servers ) { super(interval, version, partitionNumber); this.fullInterval = interval.equals(Preconditions.checkNotNull(fullInterval, "fullInterval")) ? null : fullInterval; + this.servers = servers; + } + + public RichSegmentDescriptor( + SegmentDescriptor segmentDescriptor, + @Nullable Interval fullInterval, + Set servers + ) + { + super(segmentDescriptor.getInterval(), segmentDescriptor.getVersion(), segmentDescriptor.getPartitionNumber()); + this.fullInterval = fullInterval; + this.servers = servers; } @JsonCreator @@ -60,17 +78,33 @@ static RichSegmentDescriptor fromJson( @JsonProperty("fi") @Nullable final Interval fullInterval, @JsonProperty("itvl") final Interval interval, @JsonProperty("ver") final String version, - @JsonProperty("part") final int partitionNumber + @JsonProperty("part") final int partitionNumber, + @JsonProperty("servers") @Nullable final Set servers ) { return new RichSegmentDescriptor( fullInterval != null ? fullInterval : interval, interval, version, - partitionNumber + partitionNumber, + servers == null ? ImmutableSet.of() : servers ); } + /** + * Returns true if the location the segment is loaded is available, and false if it is not. + */ + public boolean isLoadedOnServer() + { + return !CollectionUtils.isNullOrEmpty(getServers()); + } + + @JsonProperty("servers") + public Set getServers() + { + return servers; + } + public Interval getFullInterval() { return fullInterval == null ? getInterval() : fullInterval; @@ -97,13 +131,13 @@ public boolean equals(Object o) return false; } RichSegmentDescriptor that = (RichSegmentDescriptor) o; - return Objects.equals(fullInterval, that.fullInterval); + return Objects.equals(fullInterval, that.fullInterval) && Objects.equals(servers, that.servers); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), fullInterval); + return Objects.hash(super.hashCode(), fullInterval, servers); } @Override @@ -111,6 +145,7 @@ public String toString() { return "RichSegmentDescriptor{" + "fullInterval=" + (fullInterval == null ? getInterval() : fullInterval) + + ", servers=" + getServers() + ", interval=" + getInterval() + ", version='" + getVersion() + '\'' + ", partitionNumber=" + getPartitionNumber() + diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentWithDescriptor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentWithDescriptor.java index 020b9f2a5bb0..137129ed338b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentWithDescriptor.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentWithDescriptor.java @@ -21,10 +21,19 @@ import com.google.common.base.Preconditions; import org.apache.druid.collections.ResourceHolder; -import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.guava.Yielder; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.msq.exec.LoadedSegmentDataProvider; +import org.apache.druid.query.Query; import org.apache.druid.segment.Segment; +import javax.annotation.Nullable; +import java.io.IOException; import java.util.Objects; +import java.util.function.Function; import java.util.function.Supplier; /** @@ -33,30 +42,36 @@ public class SegmentWithDescriptor { private final Supplier> segmentSupplier; - private final SegmentDescriptor descriptor; + @Nullable + private final LoadedSegmentDataProvider loadedSegmentDataProvider; + private final RichSegmentDescriptor descriptor; /** * Create a new instance. * - * @param segmentSupplier supplier of a {@link ResourceHolder} of segment. The {@link ResourceHolder#close()} logic - * must include a delegated call to {@link Segment#close()}. - * @param descriptor segment descriptor + * @param segmentSupplier supplier of a {@link ResourceHolder} of segment. The {@link ResourceHolder#close()} + * logic must include a delegated call to {@link Segment#close()}. + * @param loadedSegmentDataProvider {@link LoadedSegmentDataProvider} which fetches the corresponding results from a + * data server where the segment is loaded. The call will fetch the + * @param descriptor segment descriptor */ public SegmentWithDescriptor( final Supplier> segmentSupplier, - final SegmentDescriptor descriptor + final @Nullable LoadedSegmentDataProvider loadedSegmentDataProvider, + final RichSegmentDescriptor descriptor ) { this.segmentSupplier = Preconditions.checkNotNull(segmentSupplier, "segment"); + this.loadedSegmentDataProvider = loadedSegmentDataProvider; this.descriptor = Preconditions.checkNotNull(descriptor, "descriptor"); } /** * The physical segment. - * + *
* Named "getOrLoad" because the segment may be generated by a lazy supplier. In this case, the segment is acquired * as part of the call to this method. - * + *
* It is not necessary to call {@link Segment#close()} on the returned segment. Calling {@link ResourceHolder#close()} * is enough. */ @@ -65,10 +80,23 @@ public ResourceHolder getOrLoad() return segmentSupplier.get(); } + public Pair> fetchRowsFromDataServer( + Query query, + Function, Sequence> mappingFunction, + Closer closer + ) throws IOException + { + if (loadedSegmentDataProvider == null) { + throw DruidException.defensive("loadedSegmentDataProvider was null. Fetching segments from servers is not " + + "supported for segment[%s]", descriptor); + } + return loadedSegmentDataProvider.fetchRowsFromDataServer(query, descriptor, mappingFunction, closer); + } + /** * The segment descriptor associated with this physical segment. */ - public SegmentDescriptor getDescriptor() + public RichSegmentDescriptor getDescriptor() { return descriptor; } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentsInputSliceReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentsInputSliceReader.java index 5334c4cb2abd..8bc67dbb4e88 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentsInputSliceReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/SegmentsInputSliceReader.java @@ -23,10 +23,12 @@ import org.apache.druid.msq.counters.ChannelCounters; import org.apache.druid.msq.counters.CounterNames; import org.apache.druid.msq.counters.CounterTracker; +import org.apache.druid.msq.exec.LoadedSegmentDataProviderFactory; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.msq.input.InputSliceReader; import org.apache.druid.msq.input.ReadableInput; import org.apache.druid.msq.input.ReadableInputs; +import org.apache.druid.msq.kernel.FrameContext; import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.timeline.SegmentId; @@ -40,11 +42,13 @@ public class SegmentsInputSliceReader implements InputSliceReader { private final DataSegmentProvider dataSegmentProvider; + private final LoadedSegmentDataProviderFactory loadedSegmentDataProviderFactory; private final boolean isReindex; - public SegmentsInputSliceReader(final DataSegmentProvider dataSegmentProvider, final boolean isReindex) + public SegmentsInputSliceReader(final FrameContext frameContext, final boolean isReindex) { - this.dataSegmentProvider = dataSegmentProvider; + this.dataSegmentProvider = frameContext.dataSegmentProvider(); + this.loadedSegmentDataProviderFactory = frameContext.loadedSegmentDataProviderFactory(); this.isReindex = isReindex; } @@ -94,6 +98,7 @@ private Iterator dataSegmentIterator( return new SegmentWithDescriptor( dataSegmentProvider.fetchSegment(segmentId, channelCounters, isReindex), + descriptor.isLoadedOnServer() ? loadedSegmentDataProviderFactory.createLoadedSegmentDataProvider(dataSource, channelCounters) : null, descriptor ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java index 37a97d33be5c..91f2e681e1ea 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java @@ -166,7 +166,8 @@ public RichSegmentDescriptor toRichSegmentDescriptor() segment.getInterval(), interval, segment.getVersion(), - segment.getShardSpec().getPartitionNum() + segment.getShardSpec().getPartitionNum(), + segment instanceof DataSegmentWithLocation ? ((DataSegmentWithLocation) segment).getServers() : null ); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java index 2339ac5537a0..49871cecc1d4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.kernel; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.msq.exec.LoadedSegmentDataProviderFactory; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.query.groupby.GroupingEngine; @@ -43,6 +44,7 @@ public interface FrameContext RowIngestionMeters rowIngestionMeters(); DataSegmentProvider dataSegmentProvider(); + LoadedSegmentDataProviderFactory loadedSegmentDataProviderFactory(); File tempDir(); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessor.java index 3974b7e1e1d6..f67f30d0c5c6 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessor.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessor.java @@ -82,7 +82,12 @@ public ReturnOrAwait runIncrementally(final IntSet readableInputs) throw final ReturnOrAwait retVal; if (baseInput.hasSegment()) { - retVal = runWithSegment(baseInput.getSegment()); + SegmentWithDescriptor segment = baseInput.getSegment(); + if (segment.getDescriptor().isLoadedOnServer()) { + retVal = runWithLoadedSegment(baseInput.getSegment()); + } else { + retVal = runWithSegment(baseInput.getSegment()); + } } else { retVal = runWithInputChannel(baseInput.getChannel(), baseInput.getChannelFrameReader()); } @@ -105,6 +110,7 @@ protected FrameWriterFactory getFrameWriterFactory() } protected abstract ReturnOrAwait runWithSegment(SegmentWithDescriptor segment) throws IOException; + protected abstract ReturnOrAwait runWithLoadedSegment(SegmentWithDescriptor segment) throws IOException; protected abstract ReturnOrAwait runWithInputChannel( ReadableFrameChannel inputChannel, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSegmentProvider.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSegmentProvider.java index 0e931c7f8ef0..91ee4a487885 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSegmentProvider.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSegmentProvider.java @@ -32,7 +32,7 @@ public interface DataSegmentProvider * Returns a supplier that fetches the segment corresponding to the provided segmentId from deep storage. The segment * is not actually fetched until you call {@link Supplier#get()}. Once you call this, make sure to also call * {@link ResourceHolder#close()}. - * + *
* It is not necessary to call {@link ResourceHolder#close()} if you never call {@link Supplier#get()}. */ Supplier> fetchSegment( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessor.java index 63f5ad6650ae..1e9eedc4c436 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessor.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleFrameProcessor.java @@ -33,11 +33,14 @@ import org.apache.druid.frame.write.FrameWriter; import org.apache.druid.frame.write.FrameWriterFactory; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Unit; import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.exec.LoadedSegmentDataProvider; import org.apache.druid.msq.input.ReadableInput; import org.apache.druid.msq.input.table.SegmentWithDescriptor; import org.apache.druid.msq.querykit.BaseLeafFrameProcessor; @@ -62,6 +65,7 @@ */ public class GroupByPreShuffleFrameProcessor extends BaseLeafFrameProcessor { + private static final Logger log = new Logger(GroupByPreShuffleFrameProcessor.class); private final GroupByQuery query; private final GroupingEngine groupingEngine; private final ColumnSelectorFactory frameWriterColumnSelectorFactory; @@ -95,6 +99,29 @@ public GroupByPreShuffleFrameProcessor( ); } + @Override + protected ReturnOrAwait runWithLoadedSegment(SegmentWithDescriptor segment) throws IOException + { + if (resultYielder == null) { + Pair> statusSequencePair = + segment.fetchRowsFromDataServer(groupingEngine.prepareGroupByQuery(query), Function.identity(), closer); + if (LoadedSegmentDataProvider.DataServerQueryStatus.HANDOFF.equals(statusSequencePair.lhs)) { + log.info("Segment[%s] was handed off, falling back to fetching the segment from deep storage.", + segment.getDescriptor()); + return runWithSegment(segment); + } + resultYielder = statusSequencePair.rhs; + } + + populateFrameWriterAndFlushIfNeeded(); + + if (resultYielder == null || resultYielder.isDone()) { + return ReturnOrAwait.returnObject(Unit.instance()); + } else { + return ReturnOrAwait.runAgain(); + } + } + @Override protected ReturnOrAwait runWithSegment(final SegmentWithDescriptor segment) throws IOException { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessor.java index 99ea8037b7bb..1541d314f215 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessor.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessor.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import it.unimi.dsi.fastutil.ints.IntSet; import org.apache.druid.collections.ResourceHolder; @@ -40,20 +41,26 @@ import org.apache.druid.frame.write.InvalidNullByteException; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Unit; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.exec.LoadedSegmentDataProvider; import org.apache.druid.msq.input.ParseExceptionUtils; import org.apache.druid.msq.input.ReadableInput; import org.apache.druid.msq.input.external.ExternalSegment; import org.apache.druid.msq.input.table.SegmentWithDescriptor; import org.apache.druid.msq.querykit.BaseLeafFrameProcessor; import org.apache.druid.msq.querykit.QueryKitUtils; +import org.apache.druid.query.IterableRowsCursorHelper; import org.apache.druid.query.filter.Filter; import org.apache.druid.query.scan.ScanQuery; +import org.apache.druid.query.scan.ScanResultValue; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.query.spec.SpecificSegmentSpec; import org.apache.druid.segment.ColumnSelectorFactory; @@ -65,11 +72,13 @@ import org.apache.druid.segment.StorageAdapter; import org.apache.druid.segment.VirtualColumn; import org.apache.druid.segment.VirtualColumns; +import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.filter.Filters; import org.apache.druid.timeline.SegmentId; import org.joda.time.Interval; import javax.annotation.Nullable; +import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -82,8 +91,10 @@ */ public class ScanQueryFrameProcessor extends BaseLeafFrameProcessor { + private static final Logger log = new Logger(ScanQueryFrameProcessor.class); private final ScanQuery query; private final AtomicLong runningCountForLimit; + private final ObjectMapper jsonMapper; private final SettableLongVirtualColumn partitionBoostVirtualColumn; private final VirtualColumns frameWriterVirtualColumns; private final Closer closer = Closer.create(); @@ -112,6 +123,7 @@ public ScanQueryFrameProcessor( ); this.query = query; this.runningCountForLimit = runningCountForLimit; + this.jsonMapper = jsonMapper; this.partitionBoostVirtualColumn = new SettableLongVirtualColumn(QueryKitUtils.PARTITION_BOOST_COLUMN); final List frameWriterVirtualColumns = new ArrayList<>(); @@ -152,6 +164,63 @@ public void cleanup() throws IOException closer.close(); } + public static Sequence mappingFunction(Sequence inputSequence) + { + return inputSequence.flatMap(resultRow -> { + List> events = (List>) resultRow.getEvents(); + return Sequences.simple(events); + }).map(List::toArray); + } + + @Override + protected ReturnOrAwait runWithLoadedSegment(final SegmentWithDescriptor segment) throws IOException + { + if (cursor == null) { + final Pair> statusSequencePair = + segment.fetchRowsFromDataServer( + query, + ScanQueryFrameProcessor::mappingFunction, + closer + ); + if (LoadedSegmentDataProvider.DataServerQueryStatus.HANDOFF.equals(statusSequencePair.lhs)) { + log.info("Segment[%s] was handed off, falling back to fetching the segment from deep storage.", + segment.getDescriptor()); + return runWithSegment(segment); + } + + RowSignature rowSignature = ScanQueryKit.getAndValidateSignature(query, jsonMapper); + Pair cursorFromIterable = IterableRowsCursorHelper.getCursorFromYielder( + statusSequencePair.rhs, + rowSignature + ); + + closer.register(cursorFromIterable.rhs); + final Yielder cursorYielder = Yielders.each(Sequences.simple(ImmutableList.of(cursorFromIterable.lhs))); + + if (cursorYielder.isDone()) { + // No cursors! + cursorYielder.close(); + return ReturnOrAwait.returnObject(Unit.instance()); + } else { + final long rowsFlushed = setNextCursor(cursorYielder.get(), null); + assert rowsFlushed == 0; // There's only ever one cursor when running with a segment + closer.register(cursorYielder); + } + } + + populateFrameWriterAndFlushIfNeededWithExceptionHandling(); + + if (cursor.isDone()) { + flushFrameWriter(); + } + + if (cursor.isDone() && (frameWriter == null || frameWriter.getNumRows() == 0)) { + return ReturnOrAwait.returnObject(Unit.instance()); + } else { + return ReturnOrAwait.runAgain(); + } + } + @Override protected ReturnOrAwait runWithSegment(final SegmentWithDescriptor segment) throws IOException { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java index 98dcd471d0fe..6e477d0c364b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java @@ -27,6 +27,7 @@ import com.opencsv.RFC4180ParserBuilder; import org.apache.druid.msq.exec.ClusterStatisticsMergeMode; import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.exec.SegmentSource; import org.apache.druid.msq.indexing.destination.MSQSelectDestination; import org.apache.druid.msq.kernel.WorkerAssignmentStrategy; import org.apache.druid.msq.sql.MSQMode; @@ -90,6 +91,9 @@ public class MultiStageQueryContext public static final String CTX_FINALIZE_AGGREGATIONS = "finalizeAggregations"; private static final boolean DEFAULT_FINALIZE_AGGREGATIONS = true; + public static final String CTX_INCLUDE_SEGMENT_SOURCE = "includeSegmentSource"; + public static final SegmentSource DEFAULT_INCLUDE_SEGMENT_SOURCE = SegmentSource.NONE; + public static final String CTX_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage"; private static final boolean DEFAULT_DURABLE_SHUFFLE_STORAGE = false; public static final String CTX_SELECT_DESTINATION = "selectDestination"; @@ -191,6 +195,15 @@ public static boolean isFinalizeAggregations(final QueryContext queryContext) ); } + public static SegmentSource getSegmentSources(final QueryContext queryContext) + { + return queryContext.getEnum( + CTX_INCLUDE_SEGMENT_SOURCE, + SegmentSource.class, + DEFAULT_INCLUDE_SEGMENT_SOURCE + ); + } + public static WorkerAssignmentStrategy getAssignmentStrategy(final QueryContext queryContext) { return QueryContexts.getAsEnum( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/LoadedSegmentDataProviderTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/LoadedSegmentDataProviderTest.java new file mode 100644 index 000000000000..6c6ad1b3fa9a --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/LoadedSegmentDataProviderTest.java @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.Futures; +import org.apache.druid.client.coordinator.CoordinatorClient; +import org.apache.druid.discovery.DataServerClient; +import org.apache.druid.discovery.DruidServiceTestUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.IOE; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.java.util.common.guava.Yielder; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.msq.counters.ChannelCounters; +import org.apache.druid.msq.input.table.RichSegmentDescriptor; +import org.apache.druid.msq.querykit.InputNumberDataSource; +import org.apache.druid.msq.querykit.scan.ScanQueryFrameProcessor; +import org.apache.druid.query.MapQueryToolChestWarehouse; +import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContexts; +import org.apache.druid.query.QueryInterruptedException; +import org.apache.druid.query.QueryToolChest; +import org.apache.druid.query.QueryToolChestWarehouse; +import org.apache.druid.query.context.ResponseContext; +import org.apache.druid.query.scan.ScanQuery; +import org.apache.druid.query.scan.ScanQueryQueryToolChest; +import org.apache.druid.query.scan.ScanResultValue; +import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; +import org.apache.druid.rpc.RpcException; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +import java.io.IOException; +import java.util.List; + +import static org.apache.druid.query.Druids.newScanQueryBuilder; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +@RunWith(MockitoJUnitRunner.class) +public class LoadedSegmentDataProviderTest +{ + private static final String DATASOURCE1 = "dataSource1"; + private static final DruidServerMetadata DRUID_SERVER_1 = new DruidServerMetadata( + "name1", + "host1:5050", + null, + 100L, + ServerType.REALTIME, + "tier1", + 0 + ); + private static final RichSegmentDescriptor SEGMENT_1 = new RichSegmentDescriptor( + Intervals.of("2003/2004"), + Intervals.of("2003/2004"), + "v1", + 1, + ImmutableSet.of(DRUID_SERVER_1) + ); + private DataServerClient dataServerClient; + private CoordinatorClient coordinatorClient; + private ScanResultValue scanResultValue; + private ScanQuery query; + private LoadedSegmentDataProvider target; + + @Before + public void setUp() + { + dataServerClient = mock(DataServerClient.class); + coordinatorClient = mock(CoordinatorClient.class); + scanResultValue = new ScanResultValue( + null, + ImmutableList.of(), + ImmutableList.of( + ImmutableList.of("abc", "123"), + ImmutableList.of("ghi", "456"), + ImmutableList.of("xyz", "789") + ) + ); + query = newScanQueryBuilder() + .dataSource(new InputNumberDataSource(1)) + .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("2003/2004")))) + .columns("__time", "cnt", "dim1", "dim2", "m1", "m2", "unique_dim1") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(ImmutableMap.of(QueryContexts.NUM_RETRIES_ON_MISSING_SEGMENTS_KEY, 1)) + .build(); + QueryToolChestWarehouse queryToolChestWarehouse = new MapQueryToolChestWarehouse( + ImmutableMap., QueryToolChest>builder() + .put(ScanQuery.class, new ScanQueryQueryToolChest(null, null)) + .build() + ); + target = spy( + new LoadedSegmentDataProvider( + DATASOURCE1, + new ChannelCounters(), + mock(ServiceClientFactory.class), + coordinatorClient, + DruidServiceTestUtils.newJsonMapper(), + queryToolChestWarehouse, + Execs.scheduledSingleThreaded("query-cancellation-executor") + ) + ); + doReturn(dataServerClient).when(target).makeDataServerClient(any()); + } + + @Test + public void testFetchRowsFromServer() throws IOException + { + doReturn(Sequences.simple(ImmutableList.of(scanResultValue))).when(dataServerClient).run(any(), any(), any(), any()); + + Pair> dataServerQueryStatusYielderPair = target.fetchRowsFromDataServer( + query, + SEGMENT_1, + ScanQueryFrameProcessor::mappingFunction, + Closer.create() + ); + + Assert.assertEquals(LoadedSegmentDataProvider.DataServerQueryStatus.SUCCESS, dataServerQueryStatusYielderPair.lhs); + List> events = (List>) scanResultValue.getEvents(); + Yielder yielder = dataServerQueryStatusYielderPair.rhs; + events.forEach( + event -> { + Assert.assertArrayEquals(event.toArray(), yielder.get()); + yielder.next(null); + } + ); + } + + @Test + public void testHandoff() throws IOException + { + doAnswer(invocation -> { + ResponseContext responseContext = invocation.getArgument(1); + responseContext.addMissingSegments(ImmutableList.of(SEGMENT_1)); + return Sequences.empty(); + }).when(dataServerClient).run(any(), any(), any(), any()); + doReturn(Futures.immediateFuture(Boolean.TRUE)).when(coordinatorClient).isHandoffComplete(DATASOURCE1, SEGMENT_1); + + Pair> dataServerQueryStatusYielderPair = target.fetchRowsFromDataServer( + query, + SEGMENT_1, + ScanQueryFrameProcessor::mappingFunction, + Closer.create() + ); + + Assert.assertEquals(LoadedSegmentDataProvider.DataServerQueryStatus.HANDOFF, dataServerQueryStatusYielderPair.lhs); + Assert.assertNull(dataServerQueryStatusYielderPair.rhs); + } + + @Test + public void testServerNotFoundWithoutHandoffShouldThrowException() + { + doThrow( + new QueryInterruptedException(new RpcException("Could not connect to server")) + ).when(dataServerClient).run(any(), any(), any(), any()); + + doReturn(Futures.immediateFuture(Boolean.FALSE)).when(coordinatorClient).isHandoffComplete(DATASOURCE1, SEGMENT_1); + + ScanQuery queryWithRetry = query.withOverriddenContext(ImmutableMap.of(QueryContexts.NUM_RETRIES_ON_MISSING_SEGMENTS_KEY, 3)); + + Assert.assertThrows(DruidException.class, () -> + target.fetchRowsFromDataServer( + queryWithRetry, + SEGMENT_1, + ScanQueryFrameProcessor::mappingFunction, + Closer.create() + ) + ); + + verify(dataServerClient, times(3)).run(any(), any(), any(), any()); + } + + @Test + public void testServerNotFoundButHandoffShouldReturnWithStatus() throws IOException + { + doThrow( + new QueryInterruptedException(new RpcException("Could not connect to server")) + ).when(dataServerClient).run(any(), any(), any(), any()); + + doReturn(Futures.immediateFuture(Boolean.TRUE)).when(coordinatorClient).isHandoffComplete(DATASOURCE1, SEGMENT_1); + + Pair> dataServerQueryStatusYielderPair = target.fetchRowsFromDataServer( + query, + SEGMENT_1, + ScanQueryFrameProcessor::mappingFunction, + Closer.create() + ); + + Assert.assertEquals(LoadedSegmentDataProvider.DataServerQueryStatus.HANDOFF, dataServerQueryStatusYielderPair.lhs); + Assert.assertNull(dataServerQueryStatusYielderPair.rhs); + } + + @Test + public void testQueryFail() + { + doAnswer(invocation -> { + ResponseContext responseContext = invocation.getArgument(1); + responseContext.addMissingSegments(ImmutableList.of(SEGMENT_1)); + return Sequences.empty(); + }).when(dataServerClient).run(any(), any(), any(), any()); + doReturn(Futures.immediateFuture(Boolean.FALSE)).when(coordinatorClient).isHandoffComplete(DATASOURCE1, SEGMENT_1); + + Assert.assertThrows(IOE.class, () -> + target.fetchRowsFromDataServer( + query, + SEGMENT_1, + ScanQueryFrameProcessor::mappingFunction, + Closer.create() + ) + ); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQLoadedSegmentTests.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQLoadedSegmentTests.java new file mode 100644 index 000000000000..b2c07e267e4c --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQLoadedSegmentTests.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.apache.druid.client.ImmutableSegmentLoadInfo; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.java.util.common.guava.Yielders; +import org.apache.druid.msq.indexing.MSQSpec; +import org.apache.druid.msq.indexing.MSQTuningConfig; +import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; +import org.apache.druid.msq.test.MSQTestBase; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.groupby.GroupByQuery; +import org.apache.druid.query.groupby.ResultRow; +import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; +import org.apache.druid.sql.calcite.filtration.Filtration; +import org.apache.druid.sql.calcite.planner.ColumnMapping; +import org.apache.druid.sql.calcite.planner.ColumnMappings; +import org.apache.druid.sql.calcite.util.CalciteTests; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.partition.LinearShardSpec; +import org.hamcrest.CoreMatchers; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; + +public class MSQLoadedSegmentTests extends MSQTestBase +{ + public static final Map REALTIME_QUERY_CTX = + ImmutableMap.builder() + .putAll(DEFAULT_MSQ_CONTEXT) + .put(MultiStageQueryContext.CTX_INCLUDE_SEGMENT_SOURCE, SegmentSource.REALTIME.name()) + .build(); + public static final DataSegment LOADED_SEGMENT_1 = + DataSegment.builder() + .dataSource(CalciteTests.DATASOURCE1) + .interval(Intervals.of("2003-01-01T00:00:00.000Z/2004-01-01T00:00:00.000Z")) + .version("1") + .shardSpec(new LinearShardSpec(0)) + .size(0) + .build(); + + public static final DruidServerMetadata DATA_SERVER_1 = new DruidServerMetadata( + "TestDataServer", + "hostName:9092", + null, + 2, + ServerType.REALTIME, + "tier1", + 2 + ); + + @Before + public void setUp() + { + loadedSegmentsMetadata.add(new ImmutableSegmentLoadInfo(LOADED_SEGMENT_1, ImmutableSet.of(DATA_SERVER_1))); + } + + @Test + public void testSelectWithLoadedSegmentsOnFoo() throws IOException + { + RowSignature resultSignature = RowSignature.builder() + .add("cnt", ColumnType.LONG) + .add("dim1", ColumnType.STRING) + .build(); + + doReturn( + Pair.of( + LoadedSegmentDataProvider.DataServerQueryStatus.SUCCESS, + Yielders.each( + Sequences.simple( + ImmutableList.of( + new Object[]{1L, "qwe"}, + new Object[]{1L, "tyu"} + ) + ) + ) + ) + ) + .when(loadedSegmentDataProvider) + .fetchRowsFromDataServer(any(), any(), any(), any()); + + testSelectQuery() + .setSql("select cnt, dim1 from foo") + .setExpectedMSQSpec( + MSQSpec.builder() + .query( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("cnt", "dim1") + .context(defaultScanQueryContext(REALTIME_QUERY_CTX, resultSignature)) + .build() + ) + .columnMappings(ColumnMappings.identity(resultSignature)) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .destination(TaskReportMSQDestination.INSTANCE) + .build() + ) + .setQueryContext(REALTIME_QUERY_CTX) + .setExpectedRowSignature(resultSignature) + .setExpectedResultRows(ImmutableList.of( + new Object[]{1L, ""}, + new Object[]{1L, "qwe"}, + new Object[]{1L, "10.1"}, + new Object[]{1L, "tyu"}, + new Object[]{1L, "2"}, + new Object[]{1L, "1"}, + new Object[]{1L, "def"}, + new Object[]{1L, "abc"} + )) + .verifyResults(); + } + + @Test + public void testGroupByWithLoadedSegmentsOnFoo() throws IOException + { + RowSignature rowSignature = RowSignature.builder() + .add("cnt", ColumnType.LONG) + .add("cnt1", ColumnType.LONG) + .build(); + + doReturn( + Pair.of(LoadedSegmentDataProvider.DataServerQueryStatus.SUCCESS, + Yielders.each( + Sequences.simple( + ImmutableList.of( + ResultRow.of(1L, 2L) + ) + ) + ) + ) + ) + .when(loadedSegmentDataProvider) + .fetchRowsFromDataServer(any(), any(), any(), any()); + + testSelectQuery() + .setSql("select cnt,count(*) as cnt1 from foo group by cnt") + .setExpectedMSQSpec( + MSQSpec.builder() + .query(GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration + .eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions( + new DefaultDimensionSpec( + "cnt", + "d0", + ColumnType.LONG + ) + )) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory( + "a0"))) + .setContext(REALTIME_QUERY_CTX) + .build()) + .columnMappings( + new ColumnMappings(ImmutableList.of( + new ColumnMapping("d0", "cnt"), + new ColumnMapping("a0", "cnt1"))) + ) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .destination(TaskReportMSQDestination.INSTANCE) + .build() + ) + .setQueryContext(REALTIME_QUERY_CTX) + .setExpectedRowSignature(rowSignature) + .setExpectedResultRows(ImmutableList.of(new Object[]{1L, 8L})) + .verifyResults(); + } + + @Test + public void testGroupByWithOnlyLoadedSegmentsOnFoo() throws IOException + { + RowSignature rowSignature = RowSignature.builder() + .add("cnt", ColumnType.LONG) + .add("cnt1", ColumnType.LONG) + .build(); + + doReturn( + Pair.of(LoadedSegmentDataProvider.DataServerQueryStatus.SUCCESS, + Yielders.each( + Sequences.simple( + ImmutableList.of( + ResultRow.of(1L, 2L))))) + ).when(loadedSegmentDataProvider) + .fetchRowsFromDataServer(any(), any(), any(), any()); + + testSelectQuery() + .setSql("select cnt,count(*) as cnt1 from foo where (TIMESTAMP '2003-01-01 00:00:00' <= \"__time\" AND \"__time\" < TIMESTAMP '2005-01-01 00:00:00') group by cnt") + .setExpectedMSQSpec( + MSQSpec.builder() + .query(GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(Intervals.of("2003-01-01T00:00:00.000Z/2005-01-01T00:00:00.000Z")) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions( + new DefaultDimensionSpec( + "cnt", + "d0", + ColumnType.LONG + ) + )) + .setQuerySegmentSpec(new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("2003-01-01T00:00:00.000Z/2005-01-01T00:00:00.000Z")))) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory( + "a0"))) + .setContext(REALTIME_QUERY_CTX) + .build()) + .columnMappings( + new ColumnMappings(ImmutableList.of( + new ColumnMapping("d0", "cnt"), + new ColumnMapping("a0", "cnt1"))) + ) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .destination(TaskReportMSQDestination.INSTANCE) + .build() + ) + .setQueryContext(REALTIME_QUERY_CTX) + .setExpectedRowSignature(rowSignature) + .setExpectedResultRows(ImmutableList.of(new Object[]{1L, 2L})) + .verifyResults(); + } + + @Test + public void testDataServerQueryFailedShouldFail() throws IOException + { + RowSignature rowSignature = RowSignature.builder() + .add("cnt", ColumnType.LONG) + .add("cnt1", ColumnType.LONG) + .build(); + + doThrow( + new ISE("Segment could not be found on data server, but segment was not handed off.") + ) + .when(loadedSegmentDataProvider) + .fetchRowsFromDataServer(any(), any(), any(), any()); + + testSelectQuery() + .setSql("select cnt,count(*) as cnt1 from foo where (TIMESTAMP '2003-01-01 00:00:00' <= \"__time\" AND \"__time\" < TIMESTAMP '2005-01-01 00:00:00') group by cnt") + .setExpectedMSQSpec( + MSQSpec.builder() + .query(GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(Intervals.of("2003-01-01T00:00:00.000Z/2005-01-01T00:00:00.000Z")) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions( + new DefaultDimensionSpec( + "cnt", + "d0", + ColumnType.LONG + ) + )) + .setQuerySegmentSpec(new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("2003-01-01T00:00:00.000Z/2005-01-01T00:00:00.000Z")))) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory( + "a0"))) + .setContext(REALTIME_QUERY_CTX) + .build()) + .columnMappings( + new ColumnMappings(ImmutableList.of( + new ColumnMapping("d0", "cnt"), + new ColumnMapping("a0", "cnt1"))) + ) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .destination(TaskReportMSQDestination.INSTANCE) + .build() + ) + .setQueryContext(REALTIME_QUERY_CTX) + .setExpectedRowSignature(rowSignature) + .setExpectedExecutionErrorMatcher(CoreMatchers.instanceOf(ISE.class)) + .verifyExecutionError(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java index 0ea9ab45f482..2ae8d155d4dc 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java @@ -49,6 +49,7 @@ public void setup() injectorMock, null, null, + null, null ); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/RichSegmentDescriptorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/RichSegmentDescriptorTest.java index 8884a92a665b..935b464e0386 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/RichSegmentDescriptorTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/RichSegmentDescriptorTest.java @@ -20,15 +20,28 @@ package org.apache.druid.msq.input.table; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableSet; import nl.jqno.equalsverifier.EqualsVerifier; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.segment.TestHelper; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; import org.junit.Assert; import org.junit.Test; public class RichSegmentDescriptorTest { + private static final DruidServerMetadata DRUID_SERVER_1 = new DruidServerMetadata( + "name1", + "host1", + null, + 100L, + ServerType.REALTIME, + "tier1", + 0 + ); + @Test public void testSerdeWithFullIntervalDifferentFromInterval() throws Exception { @@ -37,7 +50,8 @@ public void testSerdeWithFullIntervalDifferentFromInterval() throws Exception Intervals.of("2000/2002"), Intervals.of("2000/2001"), "2", - 3 + 3, + ImmutableSet.of(DRUID_SERVER_1) ); Assert.assertEquals( @@ -54,7 +68,8 @@ public void testSerdeWithFullIntervalSameAsInterval() throws Exception Intervals.of("2000/2001"), Intervals.of("2000/2001"), "2", - 3 + 3, + ImmutableSet.of(DRUID_SERVER_1) ); Assert.assertEquals( @@ -71,7 +86,8 @@ public void testDeserializeRichSegmentDescriptorAsSegmentDescriptor() throws Exc Intervals.of("2000/2002"), Intervals.of("2000/2001"), "2", - 3 + 3, + ImmutableSet.of(DRUID_SERVER_1) ); Assert.assertEquals( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/SegmentWithDescriptorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/SegmentWithDescriptorTest.java index 29a0ebef4ba4..875bef371e90 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/SegmentWithDescriptorTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/SegmentWithDescriptorTest.java @@ -19,6 +19,7 @@ package org.apache.druid.msq.input.table; +import com.fasterxml.jackson.databind.ObjectMapper; import nl.jqno.equalsverifier.EqualsVerifier; import org.junit.Test; @@ -27,6 +28,10 @@ public class SegmentWithDescriptorTest @Test public void testEquals() { - EqualsVerifier.forClass(SegmentWithDescriptor.class).usingGetClass().verify(); + EqualsVerifier.forClass(SegmentWithDescriptor.class) + .withPrefabValues(ObjectMapper.class, new ObjectMapper(), new ObjectMapper()) + .withIgnoredFields("loadedSegmentDataProvider") + .usingGetClass() + .verify(); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/SegmentsInputSliceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/SegmentsInputSliceTest.java index df2937f30036..55bb424512d8 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/SegmentsInputSliceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/SegmentsInputSliceTest.java @@ -21,11 +21,14 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import nl.jqno.equalsverifier.EqualsVerifier; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.segment.TestHelper; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; import org.junit.Assert; import org.junit.Test; @@ -44,7 +47,18 @@ public void testSerde() throws Exception Intervals.of("2000/P1M"), Intervals.of("2000/P1M"), "1", - 0 + 0, + ImmutableSet.of( + new DruidServerMetadata( + "name1", + "host1", + null, + 100L, + ServerType.REALTIME, + "tier1", + 0 + ) + ) ) ) ); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java index dbcb3646e887..fd5db7e75f64 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java @@ -146,25 +146,29 @@ public void test_sliceStatic_intervalFilter() SEGMENT1.getInterval(), Intervals.of("2000/P1M"), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ), new RichSegmentDescriptor( SEGMENT2.getInterval(), Intervals.of("2000/P1M"), SEGMENT2.getVersion(), - SEGMENT2.getShardSpec().getPartitionNum() + SEGMENT2.getShardSpec().getPartitionNum(), + null ), new RichSegmentDescriptor( SEGMENT1.getInterval(), Intervals.of("2000-06-01/P1M"), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ), new RichSegmentDescriptor( SEGMENT2.getInterval(), Intervals.of("2000-06-01/P1M"), SEGMENT2.getVersion(), - SEGMENT2.getShardSpec().getPartitionNum() + SEGMENT2.getShardSpec().getPartitionNum(), + null ) ) ) @@ -206,7 +210,8 @@ public void test_sliceStatic_dimFilter() SEGMENT1.getInterval(), SEGMENT1.getInterval(), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ) ) ), @@ -237,7 +242,8 @@ public void test_sliceStatic_intervalAndDimFilter() SEGMENT1.getInterval(), Intervals.of("2000/P1M"), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ) ) ), @@ -248,7 +254,8 @@ public void test_sliceStatic_intervalAndDimFilter() SEGMENT1.getInterval(), Intervals.of("2000-06-01/P1M"), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ) ) ) @@ -270,13 +277,15 @@ public void test_sliceStatic_oneSlice() SEGMENT1.getInterval(), SEGMENT1.getInterval(), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ), new RichSegmentDescriptor( SEGMENT2.getInterval(), SEGMENT2.getInterval(), SEGMENT2.getVersion(), - SEGMENT2.getShardSpec().getPartitionNum() + SEGMENT2.getShardSpec().getPartitionNum(), + null ) ) ) @@ -298,7 +307,8 @@ public void test_sliceStatic_needTwoSlices() SEGMENT1.getInterval(), SEGMENT1.getInterval(), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ) ) ), @@ -309,7 +319,8 @@ public void test_sliceStatic_needTwoSlices() SEGMENT2.getInterval(), SEGMENT2.getInterval(), SEGMENT2.getVersion(), - SEGMENT2.getShardSpec().getPartitionNum() + SEGMENT2.getShardSpec().getPartitionNum(), + null ) ) ) @@ -331,7 +342,8 @@ public void test_sliceStatic_threeSlices() SEGMENT1.getInterval(), SEGMENT1.getInterval(), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ) ) ), @@ -342,7 +354,8 @@ public void test_sliceStatic_threeSlices() SEGMENT2.getInterval(), SEGMENT2.getInterval(), SEGMENT2.getVersion(), - SEGMENT2.getShardSpec().getPartitionNum() + SEGMENT2.getShardSpec().getPartitionNum(), + null ) ) ), @@ -385,13 +398,15 @@ public void test_sliceDynamic_maxOneSlice() SEGMENT1.getInterval(), Intervals.of("2000/P1M"), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ), new RichSegmentDescriptor( SEGMENT2.getInterval(), Intervals.of("2000/P1M"), SEGMENT2.getVersion(), - SEGMENT2.getShardSpec().getPartitionNum() + SEGMENT2.getShardSpec().getPartitionNum(), + null ) ) ) @@ -418,13 +433,15 @@ public void test_sliceDynamic_needOne() SEGMENT1.getInterval(), Intervals.of("2000/P1M"), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ), new RichSegmentDescriptor( SEGMENT2.getInterval(), Intervals.of("2000/P1M"), SEGMENT2.getVersion(), - SEGMENT2.getShardSpec().getPartitionNum() + SEGMENT2.getShardSpec().getPartitionNum(), + null ) ) ) @@ -451,7 +468,8 @@ public void test_sliceDynamic_needTwoDueToFiles() SEGMENT1.getInterval(), Intervals.of("2000/P1M"), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ) ) ), @@ -462,7 +480,8 @@ public void test_sliceDynamic_needTwoDueToFiles() SEGMENT2.getInterval(), Intervals.of("2000/P1M"), SEGMENT2.getVersion(), - SEGMENT2.getShardSpec().getPartitionNum() + SEGMENT2.getShardSpec().getPartitionNum(), + null ) ) ) @@ -489,7 +508,8 @@ public void test_sliceDynamic_needTwoDueToBytes() SEGMENT1.getInterval(), Intervals.of("2000/P1M"), SEGMENT1.getVersion(), - SEGMENT1.getShardSpec().getPartitionNum() + SEGMENT1.getShardSpec().getPartitionNum(), + null ) ) ), @@ -500,7 +520,8 @@ public void test_sliceDynamic_needTwoDueToBytes() SEGMENT2.getInterval(), Intervals.of("2000/P1M"), SEGMENT2.getVersion(), - SEGMENT2.getShardSpec().getPartitionNum() + SEGMENT2.getShardSpec().getPartitionNum(), + null ) ) ) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java index 82301f4ddfed..abefe6a378d8 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.test; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; import com.google.inject.Injector; import com.google.inject.Module; import org.apache.druid.guice.DruidInjectorBuilder; @@ -80,7 +81,8 @@ public SqlEngine createEngine( queryJsonMapper, injector, new MSQTestTaskActionClient(queryJsonMapper), - workerMemoryParameters + workerMemoryParameters, + ImmutableList.of() ); return new MSQTaskSqlEngine(indexingServiceClient, queryJsonMapper); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java index d7c0ea1f2d5f..c68b2331c7d9 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java @@ -40,6 +40,8 @@ import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.msq.exec.LoadedSegmentDataProvider; +import org.apache.druid.msq.exec.LoadedSegmentDataProviderFactory; import org.apache.druid.msq.guice.MSQExternalDataSourceModule; import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.msq.querykit.DataSegmentProvider; @@ -78,6 +80,7 @@ import org.easymock.EasyMock; import org.joda.time.Interval; import org.junit.rules.TemporaryFolder; +import org.mockito.Mockito; import javax.annotation.Nullable; import java.io.File; @@ -96,6 +99,10 @@ import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS1_WITH_NUMERIC_DIMS; import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS2; import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS_LOTS_OF_COLUMNS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; /** * Helper class aiding in wiring up the Guice bindings required for MSQ engine to work with the Calcite's tests @@ -165,6 +172,7 @@ public String getFormatString() binder.bind(DataSegmentAnnouncer.class).toInstance(new NoopDataSegmentAnnouncer()); binder.bind(DataSegmentProvider.class) .toInstance((segmentId, channelCounters, isReindex) -> getSupplierForSegment(segmentId)); + binder.bind(LoadedSegmentDataProviderFactory.class).toInstance(getTestLoadedSegmentDataProviderFactory()); GroupByQueryConfig groupByQueryConfig = new GroupByQueryConfig(); GroupingEngine groupingEngine = GroupByQueryRunnerTest.makeQueryRunnerFactory( @@ -182,6 +190,24 @@ public String getFormatString() ); } + private static LoadedSegmentDataProviderFactory getTestLoadedSegmentDataProviderFactory() + { + // Currently, there is no metadata in this test for loaded segments. Therefore, this should not be called. + // In the future, if this needs to be supported, mocks for LoadedSegmentDataProvider should be added like + // org.apache.druid.msq.exec.MSQLoadedSegmentTests. + LoadedSegmentDataProviderFactory mockFactory = Mockito.mock(LoadedSegmentDataProviderFactory.class); + LoadedSegmentDataProvider loadedSegmentDataProvider = Mockito.mock(LoadedSegmentDataProvider.class); + try { + doThrow(new AssertionError("Test does not support loaded segment query")) + .when(loadedSegmentDataProvider).fetchRowsFromDataServer(any(), any(), any(), any()); + doReturn(loadedSegmentDataProvider).when(mockFactory).createLoadedSegmentDataProvider(anyString(), any()); + } + catch (IOException e) { + throw new RuntimeException(e); + } + return mockFactory; + } + private static Supplier> getSupplierForSegment(SegmentId segmentId) { final TemporaryFolder temporaryFolder = new TemporaryFolder(); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java index ab7b1ed7d7cf..114583d31a1a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.test; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; import com.google.inject.Injector; import com.google.inject.Module; import org.apache.calcite.rel.RelRoot; @@ -128,7 +129,8 @@ public SqlEngine createEngine( queryJsonMapper, injector, new MSQTestTaskActionClient(queryJsonMapper), - workerMemoryParameters + workerMemoryParameters, + ImmutableList.of() ); return new MSQTaskSqlEngine(indexingServiceClient, queryJsonMapper) { diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java index 5ee3ba875388..e9c54cfc54b4 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java @@ -85,7 +85,8 @@ public SqlEngine createEngine( queryJsonMapper, injector, new MSQTestTaskActionClient(queryJsonMapper), - workerMemoryParameters + workerMemoryParameters, + ImmutableList.of() ); return new MSQTaskSqlEngine(indexingServiceClient, queryJsonMapper); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 1146bb1c9d16..4d97b911a091 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -36,6 +36,7 @@ import com.google.inject.TypeLiteral; import com.google.inject.util.Modules; import com.google.inject.util.Providers; +import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.collections.ReferenceCountingResourceHolder; import org.apache.druid.collections.ResourceHolder; import org.apache.druid.common.config.NullHandling; @@ -83,6 +84,8 @@ import org.apache.druid.msq.counters.QueryCounterSnapshot; import org.apache.druid.msq.exec.ClusterStatisticsMergeMode; import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.LoadedSegmentDataProvider; +import org.apache.druid.msq.exec.LoadedSegmentDataProviderFactory; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.guice.MSQDurableStorageModule; import org.apache.druid.msq.guice.MSQExternalDataSourceModule; @@ -293,6 +296,11 @@ public class MSQTestBase extends BaseCalciteQueryTest protected AuthorizerMapper authorizerMapper; private IndexIO indexIO; + // Contains the metadata of loaded segments + protected List loadedSegmentsMetadata = new ArrayList<>(); + // Mocks the return of data from data servers + protected LoadedSegmentDataProvider loadedSegmentDataProvider = mock(LoadedSegmentDataProvider.class); + private MSQTestSegmentManager segmentManager; private SegmentCacheManager segmentCacheManager; @Rule @@ -416,7 +424,8 @@ public String getFormatString() binder.bind(QueryProcessingPool.class) .toInstance(new ForwardingQueryProcessingPool(Execs.singleThreaded("Test-runner-processing-pool"))); binder.bind(DataSegmentProvider.class) - .toInstance((dataSegment, channelCounters, isReindex) -> getSupplierForSegment(dataSegment)); + .toInstance((segmentId, channelCounters, isReindex) -> getSupplierForSegment(segmentId)); + binder.bind(LoadedSegmentDataProviderFactory.class).toInstance(getTestLoadedSegmentDataProviderFactory()); binder.bind(IndexIO.class).toInstance(indexIO); binder.bind(SpecificSegmentsQuerySegmentWalker.class).toInstance(qf.walker()); @@ -497,7 +506,8 @@ public String getFormatString() objectMapper, injector, testTaskActionClient, - workerMemoryParameters + workerMemoryParameters, + loadedSegmentsMetadata ); CatalogResolver catalogResolver = createMockCatalogResolver(); final InProcessViewManager viewManager = new InProcessViewManager(SqlTestFramework.DRUID_VIEW_MACRO_FACTORY); @@ -570,6 +580,15 @@ protected long[] createExpectedFrameArray(int length, int value) return array; } + private LoadedSegmentDataProviderFactory getTestLoadedSegmentDataProviderFactory() + { + LoadedSegmentDataProviderFactory mockFactory = Mockito.mock(LoadedSegmentDataProviderFactory.class); + doReturn(loadedSegmentDataProvider) + .when(mockFactory) + .createLoadedSegmentDataProvider(anyString(), any()); + return mockFactory; + } + @Nonnull private Supplier> getSupplierForSegment(SegmentId segmentId) { diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java index 027d2a913b21..c62be112eed4 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java @@ -26,6 +26,7 @@ import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.inject.Injector; +import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.client.coordinator.CoordinatorClient; import org.apache.druid.indexer.TaskLocation; import org.apache.druid.indexer.TaskState; @@ -56,6 +57,7 @@ import javax.annotation.Nullable; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -94,7 +96,8 @@ public MSQTestControllerContext( ObjectMapper mapper, Injector injector, TaskActionClient taskActionClient, - WorkerMemoryParameters workerMemoryParameters + WorkerMemoryParameters workerMemoryParameters, + List loadedSegments ) { this.mapper = mapper; @@ -115,6 +118,18 @@ public MSQTestControllerContext( .collect(Collectors.toList()) ) ); + + Mockito.when(coordinatorClient.fetchServerViewSegments( + ArgumentMatchers.anyString(), + ArgumentMatchers.any() + ) + ).thenAnswer(invocation -> loadedSegments.stream() + .filter(immutableSegmentLoadInfo -> + immutableSegmentLoadInfo.getSegment() + .getDataSource() + .equals(invocation.getArguments()[0])) + .collect(Collectors.toList()) + ); this.workerMemoryParameters = workerMemoryParameters; } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java index 1b49982cad46..c5f601d875ef 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java @@ -25,6 +25,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.inject.Injector; +import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.client.indexing.NoopOverlordClient; import org.apache.druid.client.indexing.TaskPayloadResponse; import org.apache.druid.client.indexing.TaskStatusResponse; @@ -43,6 +44,7 @@ import javax.annotation.Nullable; import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.Map; public class MSQTestOverlordServiceClient extends NoopOverlordClient @@ -51,10 +53,11 @@ public class MSQTestOverlordServiceClient extends NoopOverlordClient private final ObjectMapper objectMapper; private final TaskActionClient taskActionClient; private final WorkerMemoryParameters workerMemoryParameters; - private Map inMemoryControllers = new HashMap<>(); - private Map> reports = new HashMap<>(); - private Map inMemoryControllerTask = new HashMap<>(); - private Map inMemoryTaskStatus = new HashMap<>(); + private final List loadedSegmentMetadata; + private final Map inMemoryControllers = new HashMap<>(); + private final Map> reports = new HashMap<>(); + private final Map inMemoryControllerTask = new HashMap<>(); + private final Map inMemoryTaskStatus = new HashMap<>(); public static final DateTime CREATED_TIME = DateTimes.of("2023-05-31T12:00Z"); public static final DateTime QUEUE_INSERTION_TIME = DateTimes.of("2023-05-31T12:01Z"); @@ -65,13 +68,15 @@ public MSQTestOverlordServiceClient( ObjectMapper objectMapper, Injector injector, TaskActionClient taskActionClient, - WorkerMemoryParameters workerMemoryParameters + WorkerMemoryParameters workerMemoryParameters, + List loadedSegmentMetadata ) { this.objectMapper = objectMapper; this.injector = injector; this.taskActionClient = taskActionClient; this.workerMemoryParameters = workerMemoryParameters; + this.loadedSegmentMetadata = loadedSegmentMetadata; } @Override @@ -84,7 +89,8 @@ public ListenableFuture runTask(String taskId, Object taskObject) objectMapper, injector, taskActionClient, - workerMemoryParameters + workerMemoryParameters, + loadedSegmentMetadata ); MSQControllerTask cTask = objectMapper.convertValue(taskObject, MSQControllerTask.class); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java index a478d1c3c171..51d83397ccae 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java @@ -29,6 +29,7 @@ import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.exec.LoadedSegmentDataProviderFactory; import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.exec.WorkerClient; import org.apache.druid.msq.exec.WorkerContext; @@ -154,10 +155,12 @@ public void setObjectMapper(ObjectMapper objectMapper) injector, indexIO, null, + null, null ), indexIO, injector.getInstance(DataSegmentProvider.class), + injector.getInstance(LoadedSegmentDataProviderFactory.class), workerMemoryParameters ); } @@ -179,4 +182,10 @@ public Bouncer processorBouncer() { return injector.getInstance(Bouncer.class); } + + @Override + public LoadedSegmentDataProviderFactory loadedSegmentDataProviderFactory() + { + return injector.getInstance(LoadedSegmentDataProviderFactory.class); + } } diff --git a/processing/src/main/java/org/apache/druid/query/IterableRowsCursorHelper.java b/processing/src/main/java/org/apache/druid/query/IterableRowsCursorHelper.java index b4d06edc77cf..4bf1cb92a610 100644 --- a/processing/src/main/java/org/apache/druid/query/IterableRowsCursorHelper.java +++ b/processing/src/main/java/org/apache/druid/query/IterableRowsCursorHelper.java @@ -24,6 +24,7 @@ import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.segment.Cursor; import org.apache.druid.segment.RowAdapter; import org.apache.druid.segment.RowBasedCursor; @@ -32,6 +33,7 @@ import org.apache.druid.segment.column.RowSignature; import java.io.Closeable; +import java.util.Iterator; /** * Helper methods to create cursor from iterable of rows @@ -82,4 +84,35 @@ public static Pair getCursorFromSequence(Sequence r return Pair.of(baseCursor, rowWalker); } + + public static Pair getCursorFromYielder(Yielder yielderParam, RowSignature rowSignature) + { + return getCursorFromIterable( + new Iterable() + { + Yielder yielder = yielderParam; + @Override + public Iterator iterator() + { + return new Iterator() + { + @Override + public boolean hasNext() + { + return !yielder.isDone(); + } + + @Override + public Object[] next() + { + Object[] retVal = yielder.get(); + yielder = yielder.next(null); + return retVal; + } + }; + } + }, + rowSignature + ); + } } diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupingEngine.java b/processing/src/main/java/org/apache/druid/query/groupby/GroupingEngine.java index b79c4358a3de..b242ff98555a 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/GroupingEngine.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupingEngine.java @@ -177,40 +177,8 @@ public BinaryOperator createMergeFn(Query queryParam) return new GroupByBinaryFnV2((GroupByQuery) queryParam); } - /** - * Runs a provided {@link QueryRunner} on a provided {@link GroupByQuery}, which is assumed to return rows that are - * properly sorted (by timestamp and dimensions) but not necessarily fully merged (that is, there may be adjacent - * rows with the same timestamp and dimensions) and without PostAggregators computed. This method will fully merge - * the rows, apply PostAggregators, and return the resulting {@link Sequence}. - * - * The query will be modified before passing it down to the base runner. For example, "having" clauses will be - * removed and various context parameters will be adjusted. - * - * Despite the similar name, this method is much reduced in scope compared to - * {@link GroupByQueryQueryToolChest#mergeResults(QueryRunner)}. That method does delegate to this one at some points, - * but has a truckload of other responsibility, including computing outer query results (if there are subqueries), - * computing subtotals (like GROUPING SETS), and computing the havingSpec and limitSpec. - * - * @param baseRunner base query runner - * @param query the groupBy query to run inside the base query runner - * @param responseContext the response context to pass to the base query runner - * - * @return merged result sequence - */ - public Sequence mergeResults( - final QueryRunner baseRunner, - final GroupByQuery query, - final ResponseContext responseContext - ) + public GroupByQuery prepareGroupByQuery(GroupByQuery query) { - // Merge streams using ResultMergeQueryRunner, then apply postaggregators, then apply limit (which may - // involve materialization) - final ResultMergeQueryRunner mergingQueryRunner = new ResultMergeQueryRunner<>( - baseRunner, - this::createResultComparator, - this::createMergeFn - ); - // Set up downstream context. final ImmutableMap.Builder context = ImmutableMap.builder(); context.put(QueryContexts.FINALIZE_KEY, false); @@ -224,7 +192,6 @@ public Sequence mergeResults( final boolean hasTimestampResultField = (timestampResultField != null && !timestampResultField.isEmpty()) && queryContext.getBoolean(CTX_KEY_OUTERMOST, true) && !query.isApplyLimitPushDown(); - int timestampResultFieldIndex = 0; if (hasTimestampResultField) { // sql like "group by city_id,time_floor(__time to day)", // the original translated query is granularity=all and dimensions:[d0, d1] @@ -257,7 +224,7 @@ public Sequence mergeResults( granularity = timestampResultFieldGranularity; // when timestampResultField is the last dimension, should set sortByDimsFirst=true, // otherwise the downstream is sorted by row's timestamp first which makes the final ordering not as expected - timestampResultFieldIndex = queryContext.getInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX); + int timestampResultFieldIndex = queryContext.getInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX, 0); if (!query.getContextSortByDimsFirst() && timestampResultFieldIndex == query.getDimensions().size() - 1) { context.put(GroupByQuery.CTX_KEY_SORT_BY_DIMS_FIRST, true); } @@ -269,7 +236,6 @@ public Sequence mergeResults( // when hasTimestampResultField=true and timestampResultField is neither first nor last dimension, // the DefaultLimitSpec will always do the reordering } - final int timestampResultFieldIndexInOriginalDimensions = timestampResultFieldIndex; if (query.getUniversalTimestamp() != null && !hasTimestampResultField) { // universalTimestamp works only when granularity is all // hasTimestampResultField works only when granularity is all @@ -283,7 +249,7 @@ public Sequence mergeResults( // Always request array result rows when passing the query downstream. context.put(GroupByQueryConfig.CTX_KEY_ARRAY_RESULT_ROWS, true); - final GroupByQuery newQuery = new GroupByQuery( + return new GroupByQuery( query.getDataSource(), query.getQuerySegmentSpec(), query.getVirtualColumns(), @@ -305,6 +271,49 @@ public Sequence mergeResults( ).withOverriddenContext( context.build() ); + } + + /** + * Runs a provided {@link QueryRunner} on a provided {@link GroupByQuery}, which is assumed to return rows that are + * properly sorted (by timestamp and dimensions) but not necessarily fully merged (that is, there may be adjacent + * rows with the same timestamp and dimensions) and without PostAggregators computed. This method will fully merge + * the rows, apply PostAggregators, and return the resulting {@link Sequence}. + * + * The query will be modified using {@link #prepareGroupByQuery(GroupByQuery)} before passing it down to the base + * runner. For example, "having" clauses will be removed and various context parameters will be adjusted. + * + * Despite the similar name, this method is much reduced in scope compared to + * {@link GroupByQueryQueryToolChest#mergeResults(QueryRunner)}. That method does delegate to this one at some points, + * but has a truckload of other responsibility, including computing outer query results (if there are subqueries), + * computing subtotals (like GROUPING SETS), and computing the havingSpec and limitSpec. + * + * @param baseRunner base query runner + * @param query the groupBy query to run inside the base query runner + * @param responseContext the response context to pass to the base query runner + * + * @return merged result sequence + */ + public Sequence mergeResults( + final QueryRunner baseRunner, + final GroupByQuery query, + final ResponseContext responseContext + ) + { + // Merge streams using ResultMergeQueryRunner, then apply postaggregators, then apply limit (which may + // involve materialization) + final ResultMergeQueryRunner mergingQueryRunner = new ResultMergeQueryRunner<>( + baseRunner, + this::createResultComparator, + this::createMergeFn + ); + + final QueryContext queryContext = query.context(); + final String timestampResultField = queryContext.getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD); + final boolean hasTimestampResultField = (timestampResultField != null && !timestampResultField.isEmpty()) + && queryContext.getBoolean(CTX_KEY_OUTERMOST, true) + && !query.isApplyLimitPushDown(); + final int timestampResultFieldIndexInOriginalDimensions = hasTimestampResultField ? queryContext.getInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX) : 0; + final GroupByQuery newQuery = prepareGroupByQuery(query); final Sequence mergedResults = mergingQueryRunner.run(QueryPlus.wrap(newQuery), responseContext); diff --git a/processing/src/test/java/org/apache/druid/query/IterableRowsCursorHelperTest.java b/processing/src/test/java/org/apache/druid/query/IterableRowsCursorHelperTest.java index 45f14b80976c..7628c3289dd1 100644 --- a/processing/src/test/java/org/apache/druid/query/IterableRowsCursorHelperTest.java +++ b/processing/src/test/java/org/apache/druid/query/IterableRowsCursorHelperTest.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableList; import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.Cursor; import org.apache.druid.segment.column.ColumnType; @@ -60,6 +61,13 @@ public void getCursorFromSequence() testCursorMatchesRowSequence(cursor, rowSignature, rows); } + @Test + public void getCursorFromYielder() + { + Cursor cursor = IterableRowsCursorHelper.getCursorFromYielder(Yielders.each(Sequences.simple(rows)), rowSignature).lhs; + testCursorMatchesRowSequence(cursor, rowSignature, rows); + } + private void testCursorMatchesRowSequence( Cursor cursor, RowSignature expectedRowSignature, diff --git a/server/src/main/java/org/apache/druid/client/coordinator/CoordinatorClient.java b/server/src/main/java/org/apache/druid/client/coordinator/CoordinatorClient.java index 08110f61f059..336576b675dc 100644 --- a/server/src/main/java/org/apache/druid/client/coordinator/CoordinatorClient.java +++ b/server/src/main/java/org/apache/druid/client/coordinator/CoordinatorClient.java @@ -20,6 +20,7 @@ package org.apache.druid.client.coordinator; import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.rpc.ServiceRetryPolicy; import org.apache.druid.timeline.DataSegment; @@ -40,6 +41,11 @@ public interface CoordinatorClient */ ListenableFuture fetchSegment(String dataSource, String segmentId, boolean includeUnused); + /** + * Fetches segments from the coordinator server view for the given dataSource and intervals. + */ + Iterable fetchServerViewSegments(String dataSource, List intervals); + /** * Fetches segment metadata for the given dataSource and intervals. */ diff --git a/server/src/main/java/org/apache/druid/client/coordinator/CoordinatorClientImpl.java b/server/src/main/java/org/apache/druid/client/coordinator/CoordinatorClientImpl.java index e93cbe830b3f..f82beb2778d1 100644 --- a/server/src/main/java/org/apache/druid/client/coordinator/CoordinatorClientImpl.java +++ b/server/src/main/java/org/apache/druid/client/coordinator/CoordinatorClientImpl.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.jackson.JacksonUtils; @@ -34,6 +35,7 @@ import org.jboss.netty.handler.codec.http.HttpMethod; import org.joda.time.Interval; +import java.util.ArrayList; import java.util.List; public class CoordinatorClientImpl implements CoordinatorClient @@ -89,6 +91,37 @@ public ListenableFuture fetchSegment(String dataSource, String segm ); } + @Override + public Iterable fetchServerViewSegments(String dataSource, List intervals) + { + ArrayList retVal = new ArrayList<>(); + for (Interval interval : intervals) { + String intervalString = StringUtils.replace(interval.toString(), "/", "_"); + + final String path = StringUtils.format( + "/druid/coordinator/v1/datasources/%s/intervals/%s/serverview?full", + StringUtils.urlEncode(dataSource), + intervalString + ); + ListenableFuture> segments = FutureUtils.transform( + client.asyncRequest( + new RequestBuilder(HttpMethod.GET, path), + new BytesFullResponseHandler() + ), + holder -> JacksonUtils.readValue( + jsonMapper, + holder.getContent(), + new TypeReference>() + { + } + ) + ); + FutureUtils.getUnchecked(segments, true).forEach(retVal::add); + } + + return retVal; + } + @Override public ListenableFuture> fetchUsedSegments(String dataSource, List intervals) { diff --git a/server/src/main/java/org/apache/druid/discovery/DataServerClient.java b/server/src/main/java/org/apache/druid/discovery/DataServerClient.java new file mode 100644 index 000000000000..479ba9d4142d --- /dev/null +++ b/server/src/main/java/org/apache/druid/discovery/DataServerClient.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.discovery; + +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.smile.SmileFactory; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.client.JsonParserIterator; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.guava.BaseSequence; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.java.util.http.client.response.StatusResponseHandler; +import org.apache.druid.java.util.http.client.response.StatusResponseHolder; +import org.apache.druid.query.Query; +import org.apache.druid.query.context.ResponseContext; +import org.apache.druid.rpc.FixedSetServiceLocator; +import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.rpc.ServiceClient; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.rpc.ServiceLocation; +import org.apache.druid.rpc.StandardRetryPolicy; +import org.apache.druid.utils.CloseableUtils; +import org.jboss.netty.handler.codec.http.HttpMethod; + +import java.io.InputStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * Client to query data servers given a query. + */ +public class DataServerClient +{ + private static final String BASE_PATH = "/druid/v2/"; + private static final Logger log = new Logger(DataServerClient.class); + private final ServiceClient serviceClient; + private final ObjectMapper objectMapper; + private final ServiceLocation serviceLocation; + private final ScheduledExecutorService queryCancellationExecutor; + + public DataServerClient( + ServiceClientFactory serviceClientFactory, + ServiceLocation serviceLocation, + ObjectMapper objectMapper, + ScheduledExecutorService queryCancellationExecutor + ) + { + this.serviceClient = serviceClientFactory.makeClient( + serviceLocation.getHost(), + FixedSetServiceLocator.forServiceLocation(serviceLocation), + StandardRetryPolicy.noRetries() + ); + this.serviceLocation = serviceLocation; + this.objectMapper = objectMapper; + this.queryCancellationExecutor = queryCancellationExecutor; + } + + public Sequence run(Query query, ResponseContext responseContext, JavaType queryResultType, Closer closer) + { + final String cancelPath = BASE_PATH + query.getId(); + + RequestBuilder requestBuilder = new RequestBuilder(HttpMethod.POST, BASE_PATH); + final boolean isSmile = objectMapper.getFactory() instanceof SmileFactory; + if (isSmile) { + requestBuilder = requestBuilder.smileContent(objectMapper, query); + } else { + requestBuilder = requestBuilder.jsonContent(objectMapper, query); + } + + log.debug("Sending request to servers for query[%s], request[%s]", query.getId(), requestBuilder.toString()); + ListenableFuture resultStreamFuture = serviceClient.asyncRequest( + requestBuilder, + new DataServerResponseHandler(query, responseContext, objectMapper) + ); + + closer.register(() -> resultStreamFuture.cancel(true)); + Futures.addCallback( + resultStreamFuture, + new FutureCallback() + { + @Override + public void onSuccess(InputStream result) + { + // Do nothing + } + + @Override + public void onFailure(Throwable t) + { + if (resultStreamFuture.isCancelled()) { + cancelQuery(query, cancelPath); + } + } + }, + Execs.directExecutor() + ); + + return new BaseSequence<>( + new BaseSequence.IteratorMaker>() + { + @Override + public JsonParserIterator make() + { + return new JsonParserIterator<>( + queryResultType, + resultStreamFuture, + BASE_PATH, + query, + serviceLocation.getHost(), + objectMapper + ); + } + + @Override + public void cleanup(JsonParserIterator iterFromMake) + { + CloseableUtils.closeAndWrapExceptions(iterFromMake); + } + } + ); + } + + private void cancelQuery(Query query, String cancelPath) + { + Runnable cancelRunnable = () -> { + Future cancelFuture = serviceClient.asyncRequest( + new RequestBuilder(HttpMethod.DELETE, cancelPath), + StatusResponseHandler.getInstance()); + + Runnable checkRunnable = () -> { + try { + if (!cancelFuture.isDone()) { + log.error("Error cancelling query[%s]", query); + } + StatusResponseHolder response = cancelFuture.get(); + if (response.getStatus().getCode() >= 500) { + log.error("Error cancelling query[%s]: queryable node returned status[%d] [%s].", + query, + response.getStatus().getCode(), + response.getStatus().getReasonPhrase()); + } + } + catch (ExecutionException | InterruptedException e) { + log.error(e, "Error cancelling query[%s]", query); + } + }; + queryCancellationExecutor.schedule(checkRunnable, 5, TimeUnit.SECONDS); + }; + queryCancellationExecutor.submit(cancelRunnable); + } +} diff --git a/server/src/main/java/org/apache/druid/discovery/DataServerResponseHandler.java b/server/src/main/java/org/apache/druid/discovery/DataServerResponseHandler.java new file mode 100644 index 000000000000..7715000f8209 --- /dev/null +++ b/server/src/main/java/org/apache/druid/discovery/DataServerResponseHandler.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.discovery; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.java.util.http.client.io.AppendableByteArrayInputStream; +import org.apache.druid.java.util.http.client.response.ClientResponse; +import org.apache.druid.java.util.http.client.response.HttpResponseHandler; +import org.apache.druid.query.Query; +import org.apache.druid.query.context.ResponseContext; +import org.apache.druid.server.QueryResource; +import org.jboss.netty.buffer.ChannelBuffer; +import org.jboss.netty.handler.codec.http.HttpChunk; +import org.jboss.netty.handler.codec.http.HttpResponse; + +import java.io.IOException; +import java.io.InputStream; + +/** + * Response handler for the {@link DataServerClient}. Handles the input stream from the data server and handles updating + * the {@link ResponseContext} from the header. Does not apply backpressure or query timeout. + */ +public class DataServerResponseHandler implements HttpResponseHandler +{ + private static final Logger log = new Logger(DataServerResponseHandler.class); + private final String queryId; + private final ResponseContext responseContext; + private final ObjectMapper objectMapper; + + public DataServerResponseHandler(Query query, ResponseContext responseContext, ObjectMapper objectMapper) + { + this.queryId = query.getId(); + this.responseContext = responseContext; + this.objectMapper = objectMapper; + } + + @Override + public ClientResponse handleResponse(HttpResponse response, TrafficCop trafficCop) + { + log.debug("Received response status[%s] for queryId[%s]", response.getStatus(), queryId); + AppendableByteArrayInputStream in = new AppendableByteArrayInputStream(); + in.add(getContentBytes(response.getContent())); + + try { + final String queryResponseHeaders = response.headers().get(QueryResource.HEADER_RESPONSE_CONTEXT); + if (queryResponseHeaders != null) { + responseContext.merge(ResponseContext.deserialize(queryResponseHeaders, objectMapper)); + } + return ClientResponse.finished(in); + } + catch (IOException e) { + return ClientResponse.finished( + new AppendableByteArrayInputStream() + { + @Override + public int read() throws IOException + { + throw e; + } + } + ); + } + } + + @Override + public ClientResponse handleChunk( + ClientResponse clientResponse, + HttpChunk chunk, + long chunkNum + ) + { + clientResponse.getObj().add(getContentBytes(chunk.getContent())); + return clientResponse; + } + + @Override + public ClientResponse done(ClientResponse clientResponse) + { + final AppendableByteArrayInputStream obj = clientResponse.getObj(); + obj.done(); + return ClientResponse.finished(obj); + } + + @Override + public void exceptionCaught(ClientResponse clientResponse, Throwable e) + { + final AppendableByteArrayInputStream obj = clientResponse.getObj(); + obj.exceptionCaught(e); + } + + private byte[] getContentBytes(ChannelBuffer content) + { + byte[] contentBytes = new byte[content.readableBytes()]; + content.readBytes(contentBytes); + return contentBytes; + } +} diff --git a/server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java b/server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java new file mode 100644 index 000000000000..f4bfa18470d2 --- /dev/null +++ b/server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.rpc; + +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.jboss.netty.util.internal.ThreadLocalRandom; + +import javax.validation.constraints.NotNull; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Basic implmentation of {@link ServiceLocator} that returns a service location from a static set of locations. Returns + * a random location each time one is requested. + */ +public class FixedSetServiceLocator implements ServiceLocator +{ + private ServiceLocations serviceLocations; + + private FixedSetServiceLocator(ServiceLocations serviceLocations) + { + this.serviceLocations = serviceLocations; + } + + public static FixedSetServiceLocator forServiceLocation(@NotNull ServiceLocation serviceLocation) + { + return new FixedSetServiceLocator(ServiceLocations.forLocation(serviceLocation)); + } + + public static FixedSetServiceLocator forDruidServerMetadata(Set serverMetadataSet) + { + if (serverMetadataSet == null || serverMetadataSet.isEmpty()) { + return new FixedSetServiceLocator(ServiceLocations.closed()); + } else { + Set serviceLocationSet = serverMetadataSet.stream() + .map(ServiceLocation::fromDruidServerMetadata) + .collect(Collectors.toSet()); + + return new FixedSetServiceLocator(ServiceLocations.forLocations(serviceLocationSet)); + } + } + + @Override + public ListenableFuture locate() + { + if (serviceLocations.isClosed() || serviceLocations.getLocations().isEmpty()) { + return Futures.immediateFuture(ServiceLocations.closed()); + } + + Set locationSet = serviceLocations.getLocations(); + return Futures.immediateFuture( + ServiceLocations.forLocation( + locationSet.stream() + .skip(ThreadLocalRandom.current().nextInt(locationSet.size())) + .findFirst() + .orElse(null) + ) + ); + } + + @Override + public void close() + { + serviceLocations = ServiceLocations.closed(); + } +} diff --git a/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java b/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java index eab82df328a1..3a092d7cb8dd 100644 --- a/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java +++ b/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java @@ -20,8 +20,14 @@ package org.apache.druid.rpc; import com.google.common.base.Preconditions; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; +import org.apache.druid.java.util.common.ISE; import org.apache.druid.server.DruidNode; +import org.apache.druid.server.coordination.DruidServerMetadata; +import javax.validation.constraints.NotNull; +import java.util.Iterator; import java.util.Objects; /** @@ -47,6 +53,44 @@ public static ServiceLocation fromDruidNode(final DruidNode druidNode) return new ServiceLocation(druidNode.getHost(), druidNode.getPlaintextPort(), druidNode.getTlsPort(), ""); } + private static final Splitter SPLITTER = Splitter.on(":").limit(2); + + public static ServiceLocation fromDruidServerMetadata(final DruidServerMetadata druidServerMetadata) + { + final String host = getHostFromString( + Preconditions.checkNotNull( + druidServerMetadata.getHost(), + "Host was null for druid server metadata[%s]", + druidServerMetadata + ) + ); + int plaintextPort = getPortFromString(druidServerMetadata.getHostAndPort()); + int tlsPort = getPortFromString(druidServerMetadata.getHostAndTlsPort()); + return new ServiceLocation(host, plaintextPort, tlsPort, ""); + } + + private static String getHostFromString(@NotNull String s) + { + Iterator iterator = SPLITTER.split(s).iterator(); + ImmutableList strings = ImmutableList.copyOf(iterator); + return strings.get(0); + } + + private static int getPortFromString(String s) + { + if (s == null) { + return -1; + } + Iterator iterator = SPLITTER.split(s).iterator(); + ImmutableList strings = ImmutableList.copyOf(iterator); + try { + return Integer.parseInt(strings.get(1)); + } + catch (NumberFormatException e) { + throw new ISE(e, "Unable to parse port out of %s", strings.get(1)); + } + } + public String getHost() { return host; diff --git a/server/src/main/java/org/apache/druid/server/coordination/DruidServerMetadata.java b/server/src/main/java/org/apache/druid/server/coordination/DruidServerMetadata.java index 3fda41b08dab..fcb08d26a2bf 100644 --- a/server/src/main/java/org/apache/druid/server/coordination/DruidServerMetadata.java +++ b/server/src/main/java/org/apache/druid/server/coordination/DruidServerMetadata.java @@ -31,6 +31,7 @@ public class DruidServerMetadata { private final String name; + @Nullable private final String hostAndPort; @Nullable private final String hostAndTlsPort; @@ -39,10 +40,11 @@ public class DruidServerMetadata private final ServerType type; private final int priority; + // Either hostAndPort or hostAndTlsPort would be null depending on the type of connection. @JsonCreator public DruidServerMetadata( @JsonProperty("name") String name, - @JsonProperty("host") String hostAndPort, + @JsonProperty("host") @Nullable String hostAndPort, @JsonProperty("hostAndTlsPort") @Nullable String hostAndTlsPort, @JsonProperty("maxSize") long maxSize, @JsonProperty("type") ServerType type, @@ -70,6 +72,7 @@ public String getHost() return getHostAndTlsPort() != null ? getHostAndTlsPort() : getHostAndPort(); } + @Nullable @JsonProperty("host") public String getHostAndPort() { diff --git a/server/src/test/java/org/apache/druid/client/coordinator/CoordinatorClientImplTest.java b/server/src/test/java/org/apache/druid/client/coordinator/CoordinatorClientImplTest.java index f48e21327a0b..8977d64ee555 100644 --- a/server/src/test/java/org/apache/druid/client/coordinator/CoordinatorClientImplTest.java +++ b/server/src/test/java/org/apache/druid/client/coordinator/CoordinatorClientImplTest.java @@ -21,13 +21,18 @@ import com.fasterxml.jackson.databind.InjectableValues; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.rpc.MockServiceClient; import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.partition.NumberedShardSpec; import org.jboss.netty.handler.codec.http.HttpMethod; @@ -42,6 +47,7 @@ import javax.ws.rs.core.MediaType; import java.util.Collections; import java.util.List; +import java.util.Set; public class CoordinatorClientImplTest { @@ -130,7 +136,10 @@ public void test_fetchSegment() throws Exception .build(); serviceClient.expectAndRespond( - new RequestBuilder(HttpMethod.GET, "/druid/coordinator/v1/metadata/datasources/xyz/segments/def?includeUnused=true"), + new RequestBuilder( + HttpMethod.GET, + "/druid/coordinator/v1/metadata/datasources/xyz/segments/def?includeUnused=true" + ), HttpResponseStatus.OK, ImmutableMap.of(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON), jsonMapper.writeValueAsBytes(segment) @@ -168,4 +177,71 @@ public void test_fetchUsedSegments() throws Exception coordinatorClient.fetchUsedSegments("xyz", intervals).get() ); } + + @Test + public void test_fetchServerViewSegments() throws Exception + { + + final List intervals = ImmutableList.of( + Intervals.of("2001/2002"), + Intervals.of("2501/2502") + ); + + final Set serverMetadataSet = + ImmutableSet.of( + new DruidServerMetadata( + "TEST_SERVER", + "testhost:9092", + null, + 1, + ServerType.INDEXER_EXECUTOR, + "tier1", + 0 + ) + ); + + final ImmutableSegmentLoadInfo immutableSegmentLoadInfo1 = new ImmutableSegmentLoadInfo( + DataSegment.builder() + .dataSource("xyz") + .interval(intervals.get(0)) + .version("1") + .shardSpec(new NumberedShardSpec(0, 1)) + .size(1) + .build(), + serverMetadataSet + ); + + serviceClient.expectAndRespond( + new RequestBuilder(HttpMethod.GET, "/druid/coordinator/v1/datasources/xyz/intervals/2001-01-01T00:00:00.000Z_2002-01-01T00:00:00.000Z/serverview?full"), + HttpResponseStatus.OK, + ImmutableMap.of(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON), + jsonMapper.writeValueAsBytes(Collections.singletonList(immutableSegmentLoadInfo1)) + ); + + final ImmutableSegmentLoadInfo immutableSegmentLoadInfo2 = new ImmutableSegmentLoadInfo( + DataSegment.builder() + .dataSource("xyz") + .interval(intervals.get(1)) + .version("1") + .shardSpec(new NumberedShardSpec(0, 1)) + .size(1) + .build(), + serverMetadataSet + ); + + serviceClient.expectAndRespond( + new RequestBuilder(HttpMethod.GET, "/druid/coordinator/v1/datasources/xyz/intervals/2501-01-01T00:00:00.000Z_2502-01-01T00:00:00.000Z/serverview?full"), + HttpResponseStatus.OK, + ImmutableMap.of(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON), + jsonMapper.writeValueAsBytes(Collections.singletonList(immutableSegmentLoadInfo2)) + ); + + List segmentLoadInfoList = + ImmutableList.of(immutableSegmentLoadInfo1, immutableSegmentLoadInfo2); + + Assert.assertEquals( + segmentLoadInfoList, + coordinatorClient.fetchServerViewSegments("xyz", intervals) + ); + } } diff --git a/server/src/test/java/org/apache/druid/client/coordinator/NoopCoordinatorClient.java b/server/src/test/java/org/apache/druid/client/coordinator/NoopCoordinatorClient.java index 76e6346d3808..1bc23b48a478 100644 --- a/server/src/test/java/org/apache/druid/client/coordinator/NoopCoordinatorClient.java +++ b/server/src/test/java/org/apache/druid/client/coordinator/NoopCoordinatorClient.java @@ -20,6 +20,7 @@ package org.apache.druid.client.coordinator; import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.rpc.ServiceRetryPolicy; import org.apache.druid.timeline.DataSegment; @@ -41,6 +42,12 @@ public ListenableFuture fetchSegment(String dataSource, String segm throw new UnsupportedOperationException(); } + @Override + public Iterable fetchServerViewSegments(String dataSource, List intervals) + { + throw new UnsupportedOperationException(); + } + @Override public ListenableFuture> fetchUsedSegments(String dataSource, List intervals) { diff --git a/server/src/test/java/org/apache/druid/discovery/DataServerClientTest.java b/server/src/test/java/org/apache/druid/discovery/DataServerClientTest.java new file mode 100644 index 000000000000..7be5a13474d5 --- /dev/null +++ b/server/src/test/java/org/apache/druid/discovery/DataServerClientTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.discovery; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.query.context.DefaultResponseContext; +import org.apache.druid.query.context.ResponseContext; +import org.apache.druid.query.scan.ScanQuery; +import org.apache.druid.query.scan.ScanResultValue; +import org.apache.druid.query.spec.MultipleSpecificSegmentSpec; +import org.apache.druid.rpc.MockServiceClient; +import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.rpc.ServiceLocation; +import org.jboss.netty.handler.codec.http.HttpMethod; +import org.jboss.netty.handler.codec.http.HttpResponseStatus; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.MediaType; +import java.util.Collections; + +import static org.apache.druid.query.Druids.newScanQueryBuilder; +import static org.mockito.Mockito.mock; + +public class DataServerClientTest +{ + MockServiceClient serviceClient; + ServiceClientFactory serviceClientFactory; + ObjectMapper jsonMapper; + ScanQuery query; + DataServerClient target; + + @Before + public void setUp() + { + jsonMapper = DruidServiceTestUtils.newJsonMapper(); + serviceClient = new MockServiceClient(); + serviceClientFactory = (serviceName, serviceLocator, retryPolicy) -> serviceClient; + + query = newScanQueryBuilder() + .dataSource("dataSource1") + .intervals( + new MultipleSpecificSegmentSpec( + ImmutableList.of( + new SegmentDescriptor(Intervals.of("2003/2004"), "v0", 1) + ) + ) + ) + .columns("__time", "cnt", "dim1", "dim2", "m1", "m2", "unique_dim1") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .build(); + + target = new DataServerClient( + serviceClientFactory, + mock(ServiceLocation.class), + jsonMapper, + Execs.scheduledSingleThreaded("query-cancellation-executor") + ); + } + + @Test + public void testFetchSegmentFromDataServer() throws JsonProcessingException + { + ScanResultValue scanResultValue = new ScanResultValue( + null, + ImmutableList.of("id", "name"), + ImmutableList.of( + ImmutableList.of(1, "abc"), + ImmutableList.of(5, "efg") + )); + + RequestBuilder requestBuilder = new RequestBuilder(HttpMethod.POST, "/druid/v2/") + .jsonContent(jsonMapper, query); + serviceClient.expectAndRespond( + requestBuilder, + HttpResponseStatus.OK, + ImmutableMap.of(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON), + jsonMapper.writeValueAsBytes(Collections.singletonList(scanResultValue)) + ); + + ResponseContext responseContext = new DefaultResponseContext(); + Sequence result = target.run( + query, + responseContext, + jsonMapper.getTypeFactory().constructType(ScanResultValue.class), + Closer.create() + ); + + Assert.assertEquals(ImmutableList.of(scanResultValue), result.toList()); + } +} diff --git a/server/src/test/java/org/apache/druid/rpc/FixedSetServiceLocatorTest.java b/server/src/test/java/org/apache/druid/rpc/FixedSetServiceLocatorTest.java new file mode 100644 index 000000000000..e366f6030346 --- /dev/null +++ b/server/src/test/java/org/apache/druid/rpc/FixedSetServiceLocatorTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.rpc; + +import com.google.common.collect.ImmutableSet; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; +import org.junit.Assert; +import org.junit.Test; + +import java.util.concurrent.ExecutionException; + +public class FixedSetServiceLocatorTest +{ + public static final DruidServerMetadata DATA_SERVER_1 = new DruidServerMetadata( + "TestDataServer", + "hostName:9092", + null, + 2, + ServerType.REALTIME, + "tier1", + 2 + ); + + @Test + public void testLocateNullShouldBeClosed() throws ExecutionException, InterruptedException + { + FixedSetServiceLocator serviceLocator + = FixedSetServiceLocator.forDruidServerMetadata(null); + + Assert.assertTrue(serviceLocator.locate().get().isClosed()); + } + + + @Test + public void testLocateSingleServer() throws ExecutionException, InterruptedException + { + FixedSetServiceLocator serviceLocator + = FixedSetServiceLocator.forDruidServerMetadata(ImmutableSet.of(DATA_SERVER_1)); + + Assert.assertEquals( + ServiceLocations.forLocation(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1)), + serviceLocator.locate().get() + ); + } +} diff --git a/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java b/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java index 3fba0c409e08..6aec0e2b6060 100644 --- a/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java +++ b/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java @@ -20,10 +20,51 @@ package org.apache.druid.rpc; import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; +import org.junit.Assert; import org.junit.Test; public class ServiceLocationTest { + @Test + public void test_fromDruidServerMetadata_withPort() + { + DruidServerMetadata druidServerMetadata = new DruidServerMetadata( + "name", + "hostName:9092", + null, + 1, + ServerType.INDEXER_EXECUTOR, + "tier1", + 2 + ); + + Assert.assertEquals( + new ServiceLocation("hostName", 9092, -1, ""), + ServiceLocation.fromDruidServerMetadata(druidServerMetadata) + ); + } + + @Test + public void test_fromDruidServerMetadata_withTlsPort() + { + DruidServerMetadata druidServerMetadata = new DruidServerMetadata( + "name", + null, + "hostName:8100", + 1, + ServerType.INDEXER_EXECUTOR, + "tier1", + 2 + ); + + Assert.assertEquals( + new ServiceLocation("hostName", -1, 8100, ""), + ServiceLocation.fromDruidServerMetadata(druidServerMetadata) + ); + } + @Test public void test_equals() { From 40a6dc4631364bfce3bb54bb292092575a349c0b Mon Sep 17 00:00:00 2001 From: AmatyaAvadhanula Date: Mon, 9 Oct 2023 17:54:13 +0530 Subject: [PATCH 09/14] Optimize used segment fetching in Kill tasks (#15107) * Optimize used segment fetching in Kill tasks --- .../common/task/KillUnusedSegmentsTask.java | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/KillUnusedSegmentsTask.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/KillUnusedSegmentsTask.java index cc760894603e..1726a3e68003 100644 --- a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/KillUnusedSegmentsTask.java +++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/KillUnusedSegmentsTask.java @@ -222,18 +222,30 @@ public TaskStatus runTask(TaskToolbox toolbox) throws Exception toolbox.getTaskActionClient().submit(new SegmentNukeAction(new HashSet<>(unusedSegments))); - // Fetch the load specs of all segments overlapping with the given interval - final Set> usedSegmentLoadSpecs = toolbox - .getTaskActionClient() - .submit(new RetrieveUsedSegmentsAction(getDataSource(), getInterval(), null, Segments.INCLUDING_OVERSHADOWED)) - .stream() - .map(DataSegment::getLoadSpec) - .collect(Collectors.toSet()); + final Set unusedSegmentIntervals = unusedSegments.stream() + .map(DataSegment::getInterval) + .collect(Collectors.toSet()); + final Set> usedSegmentLoadSpecs = new HashSet<>(); + if (!unusedSegmentIntervals.isEmpty()) { + RetrieveUsedSegmentsAction retrieveUsedSegmentsAction = new RetrieveUsedSegmentsAction( + getDataSource(), + null, + unusedSegmentIntervals, + Segments.INCLUDING_OVERSHADOWED + ); + // Fetch the load specs of all segments overlapping with the unused segment intervals + usedSegmentLoadSpecs.addAll(toolbox.getTaskActionClient().submit(retrieveUsedSegmentsAction) + .stream() + .map(DataSegment::getLoadSpec) + .collect(Collectors.toSet()) + ); + } // Kill segments from the deep storage only if their load specs are not being used by any used segments final List segmentsToBeKilled = unusedSegments .stream() - .filter(unusedSegment -> !usedSegmentLoadSpecs.contains(unusedSegment.getLoadSpec())) + .filter(unusedSegment -> unusedSegment.getLoadSpec() == null + || !usedSegmentLoadSpecs.contains(unusedSegment.getLoadSpec())) .collect(Collectors.toList()); toolbox.getDataSegmentKiller().kill(segmentsToBeKilled); From 549ef5628845ca858f2ff0fda0007033662c2849 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 9 Oct 2023 18:18:15 +0530 Subject: [PATCH 10/14] UNION ALLs in MSQ (#14981) MSQ now supports UNION ALL with UnionDataSource --- .../druid/msq/querykit/DataSourcePlan.java | 61 ++ .../druid/msq/sql/MSQTaskSqlEngine.java | 1 + .../apache/druid/msq/exec/MSQFaultsTest.java | 44 ++ .../apache/druid/msq/exec/MSQSelectTest.java | 61 +- .../msq/test/CalciteUnionQueryMSQTest.java | 183 +++++ .../apache/druid/msq/test/MSQTestBase.java | 8 +- .../apache/druid/query/UnionDataSource.java | 73 +- .../apache/druid/query/UnionQueryRunner.java | 10 +- .../apache/druid/query/DataSourceTest.java | 2 +- .../druid/query/UnionDataSourceTest.java | 2 +- .../calcite/rel/DruidUnionDataSourceRel.java | 2 +- .../druid/sql/calcite/rule/DruidRules.java | 7 +- .../sql/calcite/rule/DruidSortUnionRule.java | 3 +- .../rule/DruidUnionDataSourceRule.java | 7 +- .../sql/calcite/rule/DruidUnionRule.java | 5 + .../druid/sql/calcite/run/EngineFeature.java | 18 +- .../sql/calcite/run/NativeSqlEngine.java | 1 + .../druid/sql/calcite/view/ViewSqlEngine.java | 1 + .../druid/sql/calcite/CalciteQueryTest.java | 625 +++--------------- .../sql/calcite/CalciteUnionQueryTest.java | 405 ++++++++++++ .../sql/calcite/IngestionTestSqlEngine.java | 1 + .../planner/CalcitePlannerModuleTest.java | 11 +- 22 files changed, 962 insertions(+), 569 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java create mode 100644 sql/src/test/java/org/apache/druid/sql/calcite/CalciteUnionQueryTest.java diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index d8481bf7a094..16eaef63c497 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -51,6 +51,7 @@ import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.UnionDataSource; import org.apache.druid.query.UnnestDataSource; import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.planning.DataSourceAnalysis; @@ -170,6 +171,18 @@ public static DataSourcePlan forDataSource( minStageNumber, broadcast ); + } else if (dataSource instanceof UnionDataSource) { + return forUnion( + queryKit, + queryId, + queryContext, + (UnionDataSource) dataSource, + querySegmentSpec, + filter, + maxWorkerCount, + minStageNumber, + broadcast + ); } else if (dataSource instanceof JoinDataSource) { final JoinAlgorithm preferredJoinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext); final JoinAlgorithm deducedJoinAlgorithm = deduceJoinAlgorithm( @@ -458,6 +471,54 @@ private static DataSourcePlan forUnnest( ); } + private static DataSourcePlan forUnion( + final QueryKit queryKit, + final String queryId, + final QueryContext queryContext, + final UnionDataSource unionDataSource, + final QuerySegmentSpec querySegmentSpec, + @Nullable DimFilter filter, + final int maxWorkerCount, + final int minStageNumber, + final boolean broadcast + ) + { + // This is done to prevent loss of generality since MSQ can plan any type of DataSource. + List children = unionDataSource.getDataSources(); + + final QueryDefinitionBuilder subqueryDefBuilder = QueryDefinition.builder(); + final List newChildren = new ArrayList<>(); + final List inputSpecs = new ArrayList<>(); + final IntSet broadcastInputs = new IntOpenHashSet(); + + for (DataSource child : children) { + DataSourcePlan childDataSourcePlan = forDataSource( + queryKit, + queryId, + queryContext, + child, + querySegmentSpec, + filter, + maxWorkerCount, + Math.max(minStageNumber, subqueryDefBuilder.getNextStageNumber()), + broadcast + ); + + int shift = inputSpecs.size(); + + newChildren.add(shiftInputNumbers(childDataSourcePlan.getNewDataSource(), shift)); + inputSpecs.addAll(childDataSourcePlan.getInputSpecs()); + childDataSourcePlan.getSubQueryDefBuilder().ifPresent(subqueryDefBuilder::addAll); + childDataSourcePlan.getBroadcastInputs().forEach(inp -> broadcastInputs.add(inp + shift)); + } + return new DataSourcePlan( + new UnionDataSource(newChildren), + inputSpecs, + broadcastInputs, + subqueryDefBuilder + ); + } + /** * Build a plan for broadcast hash-join. */ diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java index e6578388a40e..cb331760ca34 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java @@ -114,6 +114,7 @@ public boolean featureAvailable(EngineFeature feature, PlannerContext plannerCon case TIME_BOUNDARY_QUERY: case GROUPING_SETS: case WINDOW_FUNCTIONS: + case ALLOW_TOP_LEVEL_UNION_ALL: return false; case UNNEST: case CAN_SELECT: diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java index 42bb1506a307..4b77dd78b339 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java @@ -20,6 +20,8 @@ package org.apache.druid.msq.exec; import com.google.common.collect.ImmutableMap; +import org.apache.druid.error.DruidException; +import org.apache.druid.error.DruidExceptionMatcher; import org.apache.druid.indexing.common.actions.SegmentAllocateAction; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.StringUtils; @@ -330,4 +332,46 @@ public void testTooManyInputFiles() throws IOException .setExpectedMSQFault(new TooManyInputFilesFault(numFiles, Limits.MAX_INPUT_FILES_PER_WORKER, 2)) .verifyResults(); } + + @Test + public void testUnionAllWithDifferentColumnNames() + { + // This test fails till MSQ can support arbitrary column names and column types for UNION ALL + testIngestQuery() + .setSql( + "INSERT INTO druid.dst " + + "SELECT dim2, dim1, m1 FROM foo2 " + + "UNION ALL " + + "SELECT dim1, dim2, m1 FROM foo " + + "PARTITIONED BY ALL TIME") + .setExpectedValidationErrorMatcher( + new DruidExceptionMatcher( + DruidException.Persona.ADMIN, + DruidException.Category.INVALID_INPUT, + "general" + ).expectMessageContains("SQL requires union between two tables and column names queried for each table are different " + + "Left: [dim2, dim1, m1], Right: [dim1, dim2, m1].")) + .verifyPlanningErrors(); + } + + @Test + public void testTopLevelUnionAllWithJoins() + { + // This test fails becaues it is a top level UNION ALL which cannot be planned using MSQ. It will be supported once + // we support arbitrary types and column names for UNION ALL + testSelectQuery() + .setSql( + "(SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k) " + + "UNION ALL " + + "(SELECT SUM(cnt) FROM foo)" + ) + .setExpectedValidationErrorMatcher( + new DruidExceptionMatcher( + DruidException.Persona.ADMIN, + DruidException.Category.INVALID_INPUT, + "general" + ).expectMessageContains( + "SQL requires union between inputs that are not simple table scans and involve a filter or aliasing")) + .verifyPlanningErrors(); + } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index ac9ca855a635..d771f7497a8c 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -51,6 +51,7 @@ import org.apache.druid.query.Query; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.UnionDataSource; import org.apache.druid.query.UnnestDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; @@ -1929,8 +1930,8 @@ public void testGroupByOnFooWithDurableStoragePathAssertions() throws IOExceptio new ColumnMappings(ImmutableList.of( new ColumnMapping("d0", "cnt"), new ColumnMapping("a0", "cnt1") - ) )) + ) .tuningConfig(MSQTuningConfig.defaultConfig()) .destination(isDurableStorageDestination() ? DurableStorageMSQDestination.INSTANCE @@ -2322,6 +2323,64 @@ public void testSelectUnnestOnQueryFoo() .verifyResults(); } + @Test + public void testUnionAllUsingUnionDataSource() + { + + final RowSignature rowSignature = RowSignature.builder() + .add("__time", ColumnType.LONG) + .add("dim1", ColumnType.STRING) + .build(); + + final List results = ImmutableList.of( + new Object[]{946684800000L, ""}, + new Object[]{946684800000L, ""}, + new Object[]{946771200000L, "10.1"}, + new Object[]{946771200000L, "10.1"}, + new Object[]{946857600000L, "2"}, + new Object[]{946857600000L, "2"}, + new Object[]{978307200000L, "1"}, + new Object[]{978307200000L, "1"}, + new Object[]{978393600000L, "def"}, + new Object[]{978393600000L, "def"}, + new Object[]{978480000000L, "abc"}, + new Object[]{978480000000L, "abc"} + ); + // This plans the query using DruidUnionDataSourceRule since the DruidUnionDataSourceRule#isCompatible + // returns true (column names, types match, and it is a union on the table data sources). + // It gets planned correctly, however MSQ engine cannot plan the query correctly + testSelectQuery() + .setSql("SELECT __time, dim1 FROM foo\n" + + "UNION ALL\n" + + "SELECT __time, dim1 FROM foo\n") + .setExpectedRowSignature(rowSignature) + .setExpectedMSQSpec( + MSQSpec.builder() + .query(newScanQueryBuilder() + .dataSource(new UnionDataSource( + ImmutableList.of(new TableDataSource("foo"), new TableDataSource("foo")) + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(defaultScanQueryContext( + context, + rowSignature + )) + .columns(ImmutableList.of("__time", "dim1")) + .build()) + .columnMappings(ColumnMappings.identity(rowSignature)) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .destination(isDurableStorageDestination() + ? DurableStorageMSQDestination.INSTANCE + : TaskReportMSQDestination.INSTANCE) + .build() + ) + .setQueryContext(context) + .setExpectedResultRows(results) + .verifyResults(); + } + @Nonnull private List expectedMultiValueFooRowsGroup() { diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java new file mode 100644 index 000000000000..8ee9e78c8388 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.test; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.inject.Injector; +import com.google.inject.Module; +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.guice.DruidInjectorBuilder; +import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.sql.MSQTaskSqlEngine; +import org.apache.druid.query.QueryDataSource; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.UnionDataSource; +import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.LongSumAggregatorFactory; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.groupby.GroupByQuery; +import org.apache.druid.query.groupby.TestGroupByBuffers; +import org.apache.druid.server.QueryLifecycleFactory; +import org.apache.druid.sql.calcite.BaseCalciteQueryTest; +import org.apache.druid.sql.calcite.CalciteUnionQueryTest; +import org.apache.druid.sql.calcite.QueryTestBuilder; +import org.apache.druid.sql.calcite.filtration.Filtration; +import org.apache.druid.sql.calcite.run.SqlEngine; +import org.apache.druid.sql.calcite.util.CalciteTests; +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Runs {@link CalciteUnionQueryTest} but with MSQ engine + */ +public class CalciteUnionQueryMSQTest extends CalciteUnionQueryTest +{ + private TestGroupByBuffers groupByBuffers; + + @Before + public void setup2() + { + groupByBuffers = TestGroupByBuffers.createDefault(); + } + + @After + public void teardown2() + { + groupByBuffers.close(); + } + + @Override + public void configureGuice(DruidInjectorBuilder builder) + { + super.configureGuice(builder); + builder.addModules(CalciteMSQTestsHelper.fetchModules(temporaryFolder, groupByBuffers).toArray(new Module[0])); + } + + + @Override + public SqlEngine createEngine( + QueryLifecycleFactory qlf, + ObjectMapper queryJsonMapper, + Injector injector + ) + { + final WorkerMemoryParameters workerMemoryParameters = + WorkerMemoryParameters.createInstance( + WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, + 2, + 10, + 2, + 0, + 0 + ); + final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( + queryJsonMapper, + injector, + new MSQTestTaskActionClient(queryJsonMapper), + workerMemoryParameters + ); + return new MSQTaskSqlEngine(indexingServiceClient, queryJsonMapper); + } + + @Override + protected QueryTestBuilder testBuilder() + { + return new QueryTestBuilder(new BaseCalciteQueryTest.CalciteTestConfig(true)) + .addCustomRunner(new ExtractResultsFactory(() -> (MSQTestOverlordServiceClient) ((MSQTaskSqlEngine) queryFramework().engine()).overlordClient())) + .skipVectorize(true) + .verifyNativeQueries(new VerifyMSQSupportedNativeQueriesPredicate()) + .msqCompatible(msqCompatible); + } + + /** + * Generates a different error hint than what is required by the native engine, since planner does try to plan "UNION" + * using group by, however fails due to the column name mismatch. + * MSQ does wnat to support any type of data source, with least restrictive column names and types, therefore it + * should eventually work. + */ + @Test + @Override + public void testUnionIsUnplannable() + { + assertQueryIsUnplannable( + "SELECT dim2, dim1, m1 FROM foo2 UNION SELECT dim1, dim2, m1 FROM foo", + "SQL requires union between two tables and column names queried for each table are different Left: [dim2, dim1, m1], Right: [dim1, dim2, m1]." + ); + + } + + @Ignore("Ignored till MSQ can plan UNION ALL with any operand") + @Test + public void testUnionOnSubqueries() + { + testQuery( + "SELECT\n" + + " SUM(cnt),\n" + + " COUNT(*)\n" + + "FROM (\n" + + " (SELECT dim2, SUM(cnt) AS cnt FROM druid.foo GROUP BY dim2)\n" + + " UNION ALL\n" + + " (SELECT dim2, SUM(cnt) AS cnt FROM druid.foo GROUP BY dim2)\n" + + ")", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + new QueryDataSource( + GroupByQuery.builder() + .setDataSource( + new UnionDataSource( + ImmutableList.of( + new TableDataSource(CalciteTests.DATASOURCE1), + new TableDataSource(CalciteTests.DATASOURCE1) + ) + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0"))) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(aggregators( + new LongSumAggregatorFactory("_a0", "a0"), + new CountAggregatorFactory("_a1") + )) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.replaceWithDefault() ? + ImmutableList.of( + new Object[]{12L, 3L} + ) : + ImmutableList.of( + new Object[]{12L, 4L} + ) + ); + } + +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 4d97b911a091..31ece253ebd4 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -1412,9 +1412,11 @@ public Pair, List>> public void verifyResults() { - Preconditions.checkArgument(expectedResultRows != null, "Result rows cannot be null"); - Preconditions.checkArgument(expectedRowSignature != null, "Row signature cannot be null"); - Preconditions.checkArgument(expectedMSQSpec != null, "MultiStageQuery Query spec cannot be null "); + if (expectedMSQFault == null) { + Preconditions.checkArgument(expectedResultRows != null, "Result rows cannot be null"); + Preconditions.checkArgument(expectedRowSignature != null, "Row signature cannot be null"); + Preconditions.checkArgument(expectedMSQSpec != null, "MultiStageQuery Query spec cannot be null "); + } Pair, List>> specAndResults = runQueryWithResult(); if (specAndResults == null) { // A fault was expected and the assertion has been done in the runQueryWithResult diff --git a/processing/src/main/java/org/apache/druid/query/UnionDataSource.java b/processing/src/main/java/org/apache/druid/query/UnionDataSource.java index 3f538f5ad5aa..27a0113d76f1 100644 --- a/processing/src/main/java/org/apache/druid/query/UnionDataSource.java +++ b/processing/src/main/java/org/apache/druid/query/UnionDataSource.java @@ -23,11 +23,12 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Iterables; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.segment.SegmentReference; +import org.apache.druid.utils.CollectionUtils; import java.util.Collections; import java.util.List; @@ -36,13 +37,24 @@ import java.util.function.Function; import java.util.stream.Collectors; +/** + * Reperesents a UNION ALL of two or more datasources. + * + * Native engine can only work with table datasources that are scans or simple mappings (column rename without any + * expression applied on top). Therefore, it uses methods like {@link #getTableNames()} and + * {@link #getDataSourcesAsTableDataSources()} to assert that the children were TableDataSources. + * + * MSQ should be able to plan and work with arbitrary datasources. It also needs to replace the datasource with the + * InputNumberDataSource while preparing the query plan. + */ public class UnionDataSource implements DataSource { - @JsonProperty - private final List dataSources; + + @JsonProperty("dataSources") + private final List dataSources; @JsonCreator - public UnionDataSource(@JsonProperty("dataSources") List dataSources) + public UnionDataSource(@JsonProperty("dataSources") List dataSources) { if (dataSources == null || dataSources.isEmpty()) { throw new ISE("'dataSources' must be non-null and non-empty for 'union'"); @@ -51,18 +63,45 @@ public UnionDataSource(@JsonProperty("dataSources") List dataSo this.dataSources = dataSources; } + public List getDataSources() + { + return dataSources; + } + + + /** + * Asserts that the children of the union are all table data sources before returning the table names + */ @Override public Set getTableNames() { - return dataSources.stream() - .map(input -> Iterables.getOnlyElement(input.getTableNames())) - .collect(Collectors.toSet()); + return dataSources + .stream() + .map(input -> { + if (!(input instanceof TableDataSource)) { + throw DruidException.defensive("should be table"); + } + return CollectionUtils.getOnlyElement( + input.getTableNames(), + xs -> DruidException.defensive("Expected only single table name in TableDataSource") + ); + }) + .collect(Collectors.toSet()); } - @JsonProperty - public List getDataSources() + /** + * Asserts that the children of the union are all table data sources + */ + public List getDataSourcesAsTableDataSources() { - return dataSources; + return dataSources.stream() + .map(input -> { + if (!(input instanceof TableDataSource)) { + throw DruidException.defensive("should be table"); + } + return (TableDataSource) input; + }) + .collect(Collectors.toList()); } @Override @@ -78,13 +117,7 @@ public DataSource withChildren(List children) throw new IAE("Expected [%d] children, got [%d]", dataSources.size(), children.size()); } - if (!children.stream().allMatch(dataSource -> dataSource instanceof TableDataSource)) { - throw new IAE("All children must be tables"); - } - - return new UnionDataSource( - children.stream().map(dataSource -> (TableDataSource) dataSource).collect(Collectors.toList()) - ); + return new UnionDataSource(children); } @Override @@ -149,11 +182,7 @@ public boolean equals(Object o) UnionDataSource that = (UnionDataSource) o; - if (!dataSources.equals(that.dataSources)) { - return false; - } - - return true; + return dataSources.equals(that.dataSources); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/UnionQueryRunner.java b/processing/src/main/java/org/apache/druid/query/UnionQueryRunner.java index aeb3897e644b..5459e1d8c22e 100644 --- a/processing/src/main/java/org/apache/druid/query/UnionQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/UnionQueryRunner.java @@ -57,16 +57,16 @@ public Sequence run(final QueryPlus queryPlus, final ResponseContext respo final UnionDataSource unionDataSource = analysis.getBaseUnionDataSource().get(); - if (unionDataSource.getDataSources().isEmpty()) { + if (unionDataSource.getDataSourcesAsTableDataSources().isEmpty()) { // Shouldn't happen, because UnionDataSource doesn't allow empty unions. throw new ISE("Unexpectedly received empty union"); - } else if (unionDataSource.getDataSources().size() == 1) { + } else if (unionDataSource.getDataSourcesAsTableDataSources().size() == 1) { // Single table. Run as a normal query. return baseRunner.run( queryPlus.withQuery( Queries.withBaseDataSource( query, - Iterables.getOnlyElement(unionDataSource.getDataSources()) + Iterables.getOnlyElement(unionDataSource.getDataSourcesAsTableDataSources()) ) ), responseContext @@ -77,8 +77,8 @@ public Sequence run(final QueryPlus queryPlus, final ResponseContext respo query.getResultOrdering(), Sequences.simple( Lists.transform( - IntStream.range(0, unionDataSource.getDataSources().size()) - .mapToObj(i -> new Pair<>(unionDataSource.getDataSources().get(i), i + 1)) + IntStream.range(0, unionDataSource.getDataSourcesAsTableDataSources().size()) + .mapToObj(i -> new Pair<>(unionDataSource.getDataSourcesAsTableDataSources().get(i), i + 1)) .collect(Collectors.toList()), (Function, Sequence>) singleSourceWithIndex -> baseRunner.run( diff --git a/processing/src/test/java/org/apache/druid/query/DataSourceTest.java b/processing/src/test/java/org/apache/druid/query/DataSourceTest.java index 7c7f50f281bb..e7850953a609 100644 --- a/processing/src/test/java/org/apache/druid/query/DataSourceTest.java +++ b/processing/src/test/java/org/apache/druid/query/DataSourceTest.java @@ -89,7 +89,7 @@ public void testUnionDataSource() throws Exception Assert.assertTrue(dataSource instanceof UnionDataSource); Assert.assertEquals( Lists.newArrayList(new TableDataSource("ds1"), new TableDataSource("ds2")), - Lists.newArrayList(((UnionDataSource) dataSource).getDataSources()) + Lists.newArrayList(((UnionDataSource) dataSource).getDataSourcesAsTableDataSources()) ); Assert.assertEquals( ImmutableSet.of("ds1", "ds2"), diff --git a/processing/src/test/java/org/apache/druid/query/UnionDataSourceTest.java b/processing/src/test/java/org/apache/druid/query/UnionDataSourceTest.java index f408e71abf23..12522df08df3 100644 --- a/processing/src/test/java/org/apache/druid/query/UnionDataSourceTest.java +++ b/processing/src/test/java/org/apache/druid/query/UnionDataSourceTest.java @@ -123,7 +123,7 @@ public void test_withChildren_empty() @Test public void test_withChildren_sameNumber() { - final List newDataSources = ImmutableList.of( + final List newDataSources = ImmutableList.of( new TableDataSource("baz"), new TableDataSource("qux") ); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionDataSourceRel.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionDataSourceRel.java index 5e213de711cc..dbbcfa0f9a3b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionDataSourceRel.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionDataSourceRel.java @@ -118,7 +118,7 @@ public DruidUnionDataSourceRel withPartialQuery(final PartialDruidQuery newQuery @Override public DruidQuery toDruidQuery(final boolean finalizeAggregations) { - final List dataSources = new ArrayList<>(); + final List dataSources = new ArrayList<>(); RowSignature signature = null; for (final RelNode relNode : unionRel.getInputs()) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidRules.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidRules.java index 8ca4ab076d9c..dfcf1652c0de 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidRules.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidRules.java @@ -95,13 +95,16 @@ public static List rules(PlannerContext plannerContext) DruidOuterQueryRule.WHERE_FILTER, DruidOuterQueryRule.SELECT_PROJECT, DruidOuterQueryRule.SORT, - new DruidUnionRule(plannerContext), + new DruidUnionRule(plannerContext), // Add top level union rule since it helps in constructing a cleaner error message for the user new DruidUnionDataSourceRule(plannerContext), - DruidSortUnionRule.instance(), DruidJoinRule.instance(plannerContext) ) ); + if (plannerContext.featureAvailable(EngineFeature.ALLOW_TOP_LEVEL_UNION_ALL)) { + retVal.add(DruidSortUnionRule.instance()); + } + if (plannerContext.featureAvailable(EngineFeature.WINDOW_FUNCTIONS)) { retVal.add(new DruidQueryRule<>(Window.class, PartialDruidQuery.Stage.WINDOW, PartialDruidQuery::withWindow)); retVal.add( diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidSortUnionRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidSortUnionRule.java index daf1162ac44d..d06c39d72b5b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidSortUnionRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidSortUnionRule.java @@ -32,8 +32,9 @@ */ public class DruidSortUnionRule extends RelOptRule { - private static final DruidSortUnionRule INSTANCE = new DruidSortUnionRule(); + private static final DruidSortUnionRule INSTANCE = new DruidSortUnionRule(); + private DruidSortUnionRule() { super(operand(Sort.class, operand(DruidUnionRel.class, any()))); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java index 99f6248b37d5..e4a72776315d 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java @@ -112,7 +112,12 @@ public void onMatch(final RelOptRuleCall call) // Can only do UNION ALL of inputs that have compatible schemas (or schema mappings) and right side // is a simple table scan - public static boolean isCompatible(final Union unionRel, final DruidRel first, final DruidRel second, @Nullable PlannerContext plannerContext) + public static boolean isCompatible( + final Union unionRel, + final DruidRel first, + final DruidRel second, + @Nullable PlannerContext plannerContext + ) { if (!(second instanceof DruidQueryRel)) { return false; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java index 40cb2161c155..58fddbf933f3 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java @@ -26,6 +26,7 @@ import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.DruidRel; import org.apache.druid.sql.calcite.rel.DruidUnionRel; +import org.apache.druid.sql.calcite.run.EngineFeature; import java.util.List; @@ -51,6 +52,10 @@ public DruidUnionRule(PlannerContext plannerContext) @Override public boolean matches(RelOptRuleCall call) { + if (plannerContext != null && !plannerContext.featureAvailable(EngineFeature.ALLOW_TOP_LEVEL_UNION_ALL)) { + plannerContext.setPlanningError("Queries cannot be planned using top level union all"); + return false; + } // Make DruidUnionRule and DruidUnionDataSourceRule mutually exclusive. final Union unionRel = call.rel(0); final DruidRel firstDruidRel = call.rel(1); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/EngineFeature.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/EngineFeature.java index 94827c2955da..778c7ec03b6f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/EngineFeature.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/EngineFeature.java @@ -102,5 +102,21 @@ public enum EngineFeature * that it actually *does* generate correct results in native when the join is processed on the Broker. It is much * less likely that MSQ will plan in such a way that generates correct results. */ - ALLOW_BROADCAST_RIGHTY_JOIN; + ALLOW_BROADCAST_RIGHTY_JOIN, + + /** + * Planner is permitted to use {@link org.apache.druid.sql.calcite.rel.DruidUnionRel} to plan the top level UNION ALL. + * This is to dissuade planner from accepting and running the UNION ALL queries that are not supported by engines + * (primarily MSQ). + * + * Due to the nature of the exeuction of the top level UNION ALLs (we run the individual queries and concat the + * results), it only makes sense to enable this on engines where the queries return the results synchronously + * + * Planning queries with top level UNION_ALL leads to undesirable behaviour with asynchronous engines like MSQ. + * To enumerate this behaviour for MSQ, the broker attempts to run the individual queries as MSQ queries in succession, + * submits the first query correctly, fails on the rest of the queries (due to conflicting taskIds), + * and cannot concat the results together (as * the result for broker is the query id). Therefore, we don't get the + * correct result back, while the MSQ engine is executing the partial query + */ + ALLOW_TOP_LEVEL_UNION_ALL; } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java index d7fc7d043b6f..164e02a0ca8d 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java @@ -105,6 +105,7 @@ public boolean featureAvailable(EngineFeature feature, PlannerContext plannerCon case WINDOW_FUNCTIONS: case UNNEST: case ALLOW_BROADCAST_RIGHTY_JOIN: + case ALLOW_TOP_LEVEL_UNION_ALL: return true; case TIME_BOUNDARY_QUERY: return plannerContext.queryContext().isTimeBoundaryPlanningEnabled(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java b/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java index cd719d7f29fd..e2ce813a37f7 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java @@ -63,6 +63,7 @@ public boolean featureAvailable(EngineFeature feature, PlannerContext plannerCon case GROUPING_SETS: case WINDOW_FUNCTIONS: case UNNEST: + case ALLOW_TOP_LEVEL_UNION_ALL: return true; // Views can't sit on top of INSERT or REPLACE. case CAN_INSERT: diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 4042def2750d..d49f7de9dd5f 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -2874,476 +2874,6 @@ public void testUnionAllQueriesWithLimit() ); } - @DecoupledIgnore - @Test - public void testUnionAllDifferentTablesWithMapping() - { - msqIncompatible(); - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM numfoo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - ImmutableList.of( - GroupByQuery.builder() - .setDataSource( - new UnionDataSource( - ImmutableList.of( - new TableDataSource(CalciteTests.DATASOURCE1), - new TableDataSource(CalciteTests.DATASOURCE3) - ) - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) - .setDimensions( - new DefaultDimensionSpec("dim1", "d0"), - new DefaultDimensionSpec("dim2", "d1") - ) - .setAggregatorSpecs( - aggregators( - new DoubleSumAggregatorFactory("a0", "m1"), - new CountAggregatorFactory("a1") - ) - ) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of( - new Object[]{"", "a", 2.0, 2L}, - new Object[]{"1", "a", 8.0, 2L} - ) - ); - } - - @DecoupledIgnore(mode = Modes.NOT_ENOUGH_RULES) - @Test - public void testJoinUnionAllDifferentTablesWithMapping() - { - msqIncompatible(); - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM numfoo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - ImmutableList.of( - GroupByQuery.builder() - .setDataSource( - new UnionDataSource( - ImmutableList.of( - new TableDataSource(CalciteTests.DATASOURCE1), - new TableDataSource(CalciteTests.DATASOURCE3) - ) - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) - .setDimensions( - new DefaultDimensionSpec("dim1", "d0"), - new DefaultDimensionSpec("dim2", "d1") - ) - .setAggregatorSpecs( - aggregators( - new DoubleSumAggregatorFactory("a0", "m1"), - new CountAggregatorFactory("a1") - ) - ) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of( - new Object[]{"", "a", 2.0, 2L}, - new Object[]{"1", "a", 8.0, 2L} - ) - ); - } - - @Test - public void testUnionAllTablesColumnCountMismatch() - { - try { - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM numfoo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - ImmutableList.of(), - ImmutableList.of() - ); - Assert.fail("query execution should fail"); - } - catch (DruidException e) { - MatcherAssert.assertThat(e, invalidSqlIs("Column count mismatch in UNION ALL (line [3], column [42])")); - } - } - - @DecoupledIgnore(mode = Modes.NOT_ENOUGH_RULES) - @Test - public void testUnionAllTablesColumnTypeMismatchFloatLong() - { - msqIncompatible(); - // "m1" has a different type in foo and foo2 (float vs long), but this query is OK anyway because they can both - // be implicitly cast to double. - - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT dim1, dim2, m1 FROM foo2 UNION ALL SELECT dim1, dim2, m1 FROM foo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'en'\n" - + "GROUP BY 1, 2", - ImmutableList.of( - GroupByQuery.builder() - .setDataSource( - new UnionDataSource( - ImmutableList.of( - new TableDataSource(CalciteTests.DATASOURCE2), - new TableDataSource(CalciteTests.DATASOURCE1) - ) - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setDimFilter(in("dim2", ImmutableList.of("en", "a"), null)) - .setDimensions( - new DefaultDimensionSpec("dim1", "d0"), - new DefaultDimensionSpec("dim2", "d1") - ) - .setAggregatorSpecs( - aggregators( - new DoubleSumAggregatorFactory("a0", "m1"), - new CountAggregatorFactory("a1") - ) - ) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of( - new Object[]{"", "a", 1.0, 1L}, - new Object[]{"1", "a", 4.0, 1L}, - new Object[]{"druid", "en", 1.0, 1L} - ) - ); - } - - @DecoupledIgnore(mode = Modes.ERROR_HANDLING) - @Test - public void testUnionAllTablesColumnTypeMismatchStringLong() - { - // "dim3" has a different type in foo and foo2 (string vs long), which requires a casting subquery, so this - // query cannot be planned. - - assertQueryIsUnplannable( - "SELECT\n" - + "dim3, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT dim3, dim2, m1 FROM foo2 UNION ALL SELECT dim3, dim2, m1 FROM foo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'en'\n" - + "GROUP BY 1, 2", - "SQL requires union between inputs that are not simple table scans and involve a " + - "filter or aliasing. Or column types of tables being unioned are not of same type." - ); - } - - @DecoupledIgnore(mode = Modes.ERROR_HANDLING) - @Test - public void testUnionAllTablesWhenMappingIsRequired() - { - // Cannot plan this UNION ALL operation, because the column swap would require generating a subquery. - - assertQueryIsUnplannable( - "SELECT\n" - + "c, COUNT(*)\n" - + "FROM (SELECT dim1 AS c, m1 FROM foo UNION ALL SELECT dim2 AS c, m1 FROM numfoo)\n" - + "WHERE c = 'a' OR c = 'def'\n" - + "GROUP BY 1", - "SQL requires union between two tables " + - "and column names queried for each table are different Left: [dim1], Right: [dim2]." - ); - } - - @DecoupledIgnore(mode = Modes.ERROR_HANDLING) - @Test - public void testUnionIsUnplannable() - { - // Cannot plan this UNION operation - assertQueryIsUnplannable( - "SELECT dim2, dim1, m1 FROM foo2 UNION SELECT dim1, dim2, m1 FROM foo", - "SQL requires 'UNION' but only 'UNION ALL' is supported." - ); - } - - @DecoupledIgnore(mode = Modes.ERROR_HANDLING) - @Test - public void testUnionAllTablesWhenCastAndMappingIsRequired() - { - // Cannot plan this UNION ALL operation, because the column swap would require generating a subquery. - assertQueryIsUnplannable( - "SELECT\n" - + "c, COUNT(*)\n" - + "FROM (SELECT dim1 AS c, m1 FROM foo UNION ALL SELECT cnt AS c, m1 FROM numfoo)\n" - + "WHERE c = 'a' OR c = 'def'\n" - + "GROUP BY 1", - "SQL requires union between inputs that are not simple table scans and involve " + - "a filter or aliasing. Or column types of tables being unioned are not of same type." - ); - } - - @DecoupledIgnore - @Test - public void testUnionAllSameTableTwice() - { - msqIncompatible(); - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM foo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - ImmutableList.of( - GroupByQuery.builder() - .setDataSource( - new UnionDataSource( - ImmutableList.of( - new TableDataSource(CalciteTests.DATASOURCE1), - new TableDataSource(CalciteTests.DATASOURCE1) - ) - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) - .setDimensions( - new DefaultDimensionSpec("dim1", "d0"), - new DefaultDimensionSpec("dim2", "d1") - ) - .setAggregatorSpecs( - aggregators( - new DoubleSumAggregatorFactory("a0", "m1"), - new CountAggregatorFactory("a1") - ) - ) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of( - new Object[]{"", "a", 2.0, 2L}, - new Object[]{"1", "a", 8.0, 2L} - ) - ); - } - - @DecoupledIgnore(mode = Modes.NOT_ENOUGH_RULES) - @Test - public void testUnionAllSameTableTwiceWithSameMapping() - { - msqIncompatible(); - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM foo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - ImmutableList.of( - GroupByQuery.builder() - .setDataSource( - new UnionDataSource( - ImmutableList.of( - new TableDataSource(CalciteTests.DATASOURCE1), - new TableDataSource(CalciteTests.DATASOURCE1) - ) - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) - .setDimensions( - new DefaultDimensionSpec("dim1", "d0"), - new DefaultDimensionSpec("dim2", "d1") - ) - .setAggregatorSpecs( - aggregators( - new DoubleSumAggregatorFactory("a0", "m1"), - new CountAggregatorFactory("a1") - ) - ) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of( - new Object[]{"", "a", 2.0, 2L}, - new Object[]{"1", "a", 8.0, 2L} - ) - ); - } - - @DecoupledIgnore(mode = Modes.ERROR_HANDLING) - @Test - public void testUnionAllSameTableTwiceWithDifferentMapping() - { - // Cannot plan this UNION ALL operation, because the column swap would require generating a subquery. - assertQueryIsUnplannable( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim2, dim1, m1 FROM foo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - "SQL requires union between two tables and column names queried for each table are different Left: [dim1, dim2, m1], Right: [dim2, dim1, m1]." - ); - } - @DecoupledIgnore - @Test - public void testUnionAllSameTableThreeTimes() - { - msqIncompatible(); - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM foo UNION ALL SELECT * FROM foo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - ImmutableList.of( - GroupByQuery.builder() - .setDataSource( - new UnionDataSource( - ImmutableList.of( - new TableDataSource(CalciteTests.DATASOURCE1), - new TableDataSource(CalciteTests.DATASOURCE1), - new TableDataSource(CalciteTests.DATASOURCE1) - ) - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) - .setDimensions( - new DefaultDimensionSpec("dim1", "d0"), - new DefaultDimensionSpec("dim2", "d1") - ) - .setAggregatorSpecs( - aggregators( - new DoubleSumAggregatorFactory("a0", "m1"), - new CountAggregatorFactory("a1") - ) - ) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of( - new Object[]{"", "a", 3.0, 3L}, - new Object[]{"1", "a", 12.0, 3L} - ) - ); - } - - @Test - public void testUnionAllThreeTablesColumnCountMismatch1() - { - try { - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT * FROM numfoo UNION ALL SELECT * FROM foo UNION ALL SELECT * from foo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - ImmutableList.of(), - ImmutableList.of() - ); - Assert.fail("query execution should fail"); - } - catch (DruidException e) { - MatcherAssert.assertThat(e, invalidSqlIs("Column count mismatch in UNION ALL (line [3], column [45])")); - } - } - - @Test - public void testUnionAllThreeTablesColumnCountMismatch2() - { - try { - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT * FROM numfoo UNION ALL SELECT * FROM foo UNION ALL SELECT * from foo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - ImmutableList.of(), - ImmutableList.of() - ); - Assert.fail("query execution should fail"); - } - catch (DruidException e) { - MatcherAssert.assertThat(e, invalidSqlIs("Column count mismatch in UNION ALL (line [3], column [45])")); - } - } - - @Test - public void testUnionAllThreeTablesColumnCountMismatch3() - { - try { - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM foo UNION ALL SELECT * from numfoo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - ImmutableList.of(), - ImmutableList.of() - ); - Assert.fail("query execution should fail"); - } - catch (DruidException e) { - MatcherAssert.assertThat(e, invalidSqlIs("Column count mismatch in UNION ALL (line [3], column [70])")); - } - } - - @DecoupledIgnore - @Test - public void testUnionAllSameTableThreeTimesWithSameMapping() - { - msqIncompatible(); - testQuery( - "SELECT\n" - + "dim1, dim2, SUM(m1), COUNT(*)\n" - + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM foo)\n" - + "WHERE dim2 = 'a' OR dim2 = 'def'\n" - + "GROUP BY 1, 2", - ImmutableList.of( - GroupByQuery.builder() - .setDataSource( - new UnionDataSource( - ImmutableList.of( - new TableDataSource(CalciteTests.DATASOURCE1), - new TableDataSource(CalciteTests.DATASOURCE1), - new TableDataSource(CalciteTests.DATASOURCE1) - ) - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) - .setDimensions( - new DefaultDimensionSpec("dim1", "d0"), - new DefaultDimensionSpec("dim2", "d1") - ) - .setAggregatorSpecs( - aggregators( - new DoubleSumAggregatorFactory("a0", "m1"), - new CountAggregatorFactory("a1") - ) - ) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ), - ImmutableList.of( - new Object[]{"", "a", 3.0, 3L}, - new Object[]{"1", "a", 12.0, 3L} - ) - ); - } - @Test public void testPruneDeadAggregators() { @@ -3669,6 +3199,107 @@ public void testNullFloatFilter() ); } + /** + * This test case should be in {@link CalciteUnionQueryTest}. However, there's a bug in the test framework that + * doesn't reset framework once the merge buffers + */ + @DecoupledIgnore + @Test + public void testUnionAllSameTableThreeTimes() + { + testQuery( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM foo UNION ALL SELECT * FROM foo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'def'\n" + + "GROUP BY 1, 2", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + new UnionDataSource( + ImmutableList.of( + new TableDataSource(CalciteTests.DATASOURCE1), + new TableDataSource(CalciteTests.DATASOURCE1), + new TableDataSource(CalciteTests.DATASOURCE1) + ) + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) + .setDimensions( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("dim2", "d1") + ) + .setAggregatorSpecs( + aggregators( + new DoubleSumAggregatorFactory("a0", "m1"), + new CountAggregatorFactory("a1") + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"", "a", 3.0, 3L}, + new Object[]{"1", "a", 12.0, 3L} + ) + ); + } + + @DecoupledIgnore(mode = Modes.NOT_ENOUGH_RULES) + @Test + public void testExactCountDistinctUsingSubqueryOnUnionAllTables() + { + testQuery( + "SELECT\n" + + " SUM(cnt),\n" + + " COUNT(*)\n" + + "FROM (\n" + + " SELECT dim2, SUM(cnt) AS cnt\n" + + " FROM (SELECT * FROM druid.foo UNION ALL SELECT * FROM druid.foo)\n" + + " GROUP BY dim2\n" + + ")", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + new QueryDataSource( + GroupByQuery.builder() + .setDataSource( + new UnionDataSource( + ImmutableList.of( + new TableDataSource(CalciteTests.DATASOURCE1), + new TableDataSource(CalciteTests.DATASOURCE1) + ) + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0"))) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(aggregators( + new LongSumAggregatorFactory("_a0", "a0"), + new CountAggregatorFactory("_a1") + )) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.replaceWithDefault() ? + ImmutableList.of( + new Object[]{12L, 3L} + ) : + ImmutableList.of( + new Object[]{12L, 4L} + ) + ); + } + @Test public void testNullDoubleTopN() { @@ -7330,60 +6961,6 @@ public void testExactCountDistinctUsingSubquery() ); } - @DecoupledIgnore(mode = Modes.NOT_ENOUGH_RULES) - @Test - public void testExactCountDistinctUsingSubqueryOnUnionAllTables() - { - msqIncompatible(); - testQuery( - "SELECT\n" - + " SUM(cnt),\n" - + " COUNT(*)\n" - + "FROM (\n" - + " SELECT dim2, SUM(cnt) AS cnt\n" - + " FROM (SELECT * FROM druid.foo UNION ALL SELECT * FROM druid.foo)\n" - + " GROUP BY dim2\n" - + ")", - ImmutableList.of( - GroupByQuery.builder() - .setDataSource( - new QueryDataSource( - GroupByQuery.builder() - .setDataSource( - new UnionDataSource( - ImmutableList.of( - new TableDataSource(CalciteTests.DATASOURCE1), - new TableDataSource(CalciteTests.DATASOURCE1) - ) - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0"))) - .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ) - ) - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setAggregatorSpecs(aggregators( - new LongSumAggregatorFactory("_a0", "a0"), - new CountAggregatorFactory("_a1") - )) - .setContext(QUERY_CONTEXT_DEFAULT) - .build() - ), - NullHandling.replaceWithDefault() ? - ImmutableList.of( - new Object[]{12L, 3L} - ) : - ImmutableList.of( - new Object[]{12L, 4L} - ) - ); - } - @Test public void testAvgDailyCountDistinct() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteUnionQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteUnionQueryTest.java new file mode 100644 index 000000000000..773e1776857d --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteUnionQueryTest.java @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.UnionDataSource; +import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.groupby.GroupByQuery; +import org.apache.druid.sql.calcite.filtration.Filtration; +import org.apache.druid.sql.calcite.util.CalciteTests; +import org.hamcrest.MatcherAssert; +import org.junit.Assert; +import org.junit.Test; + +public class CalciteUnionQueryTest extends BaseCalciteQueryTest +{ + @Test + public void testUnionAllDifferentTablesWithMapping() + { + testQuery( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM numfoo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'def'\n" + + "GROUP BY 1, 2", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + new UnionDataSource( + ImmutableList.of( + new TableDataSource(CalciteTests.DATASOURCE1), + new TableDataSource(CalciteTests.DATASOURCE3) + ) + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) + .setDimensions( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("dim2", "d1") + ) + .setAggregatorSpecs( + aggregators( + new DoubleSumAggregatorFactory("a0", "m1"), + new CountAggregatorFactory("a1") + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"", "a", 2.0, 2L}, + new Object[]{"1", "a", 8.0, 2L} + ) + ); + } + + @Test + public void testJoinUnionAllDifferentTablesWithMapping() + { + testQuery( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM numfoo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'def'\n" + + "GROUP BY 1, 2", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + new UnionDataSource( + ImmutableList.of( + new TableDataSource(CalciteTests.DATASOURCE1), + new TableDataSource(CalciteTests.DATASOURCE3) + ) + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) + .setDimensions( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("dim2", "d1") + ) + .setAggregatorSpecs( + aggregators( + new DoubleSumAggregatorFactory("a0", "m1"), + new CountAggregatorFactory("a1") + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"", "a", 2.0, 2L}, + new Object[]{"1", "a", 8.0, 2L} + ) + ); + } + + @Test + public void testUnionAllTablesColumnCountMismatch() + { + try { + testQuery( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM numfoo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'def'\n" + + "GROUP BY 1, 2", + ImmutableList.of(), + ImmutableList.of() + ); + Assert.fail("query execution should fail"); + } + catch (DruidException e) { + MatcherAssert.assertThat(e, invalidSqlIs("Column count mismatch in UNION ALL (line [3], column [42])")); + } + } + + @Test + public void testUnionAllTablesColumnTypeMismatchFloatLong() + { + // "m1" has a different type in foo and foo2 (float vs long), but this query is OK anyway because they can both + // be implicitly cast to double. + + testQuery( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT dim1, dim2, m1 FROM foo2 UNION ALL SELECT dim1, dim2, m1 FROM foo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'en'\n" + + "GROUP BY 1, 2", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + new UnionDataSource( + ImmutableList.of( + new TableDataSource(CalciteTests.DATASOURCE2), + new TableDataSource(CalciteTests.DATASOURCE1) + ) + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(in("dim2", ImmutableList.of("en", "a"), null)) + .setDimensions( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("dim2", "d1") + ) + .setAggregatorSpecs( + aggregators( + new DoubleSumAggregatorFactory("a0", "m1"), + new CountAggregatorFactory("a1") + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"", "a", 1.0, 1L}, + new Object[]{"1", "a", 4.0, 1L}, + new Object[]{"druid", "en", 1.0, 1L} + ) + ); + } + + @Test + public void testUnionAllTablesColumnTypeMismatchStringLong() + { + // "dim3" has a different type in foo and foo2 (string vs long), which requires a casting subquery, so this + // query cannot be planned. + + assertQueryIsUnplannable( + "SELECT\n" + + "dim3, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT dim3, dim2, m1 FROM foo2 UNION ALL SELECT dim3, dim2, m1 FROM foo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'en'\n" + + "GROUP BY 1, 2", + "SQL requires union between inputs that are not simple table scans and involve a " + + "filter or aliasing. Or column types of tables being unioned are not of same type." + ); + } + + @Test + public void testUnionAllTablesWhenMappingIsRequired() + { + // Cannot plan this UNION ALL operation, because the column swap would require generating a subquery. + + assertQueryIsUnplannable( + "SELECT\n" + + "c, COUNT(*)\n" + + "FROM (SELECT dim1 AS c, m1 FROM foo UNION ALL SELECT dim2 AS c, m1 FROM numfoo)\n" + + "WHERE c = 'a' OR c = 'def'\n" + + "GROUP BY 1", + "SQL requires union between two tables " + + "and column names queried for each table are different Left: [dim1], Right: [dim2]." + ); + } + + @Test + public void testUnionIsUnplannable() + { + // Cannot plan this UNION operation + assertQueryIsUnplannable( + "SELECT dim2, dim1, m1 FROM foo2 UNION SELECT dim1, dim2, m1 FROM foo", + "SQL requires 'UNION' but only 'UNION ALL' is supported." + ); + } + + @Test + public void testUnionAllTablesWhenCastAndMappingIsRequired() + { + // Cannot plan this UNION ALL operation, because the column swap would require generating a subquery. + assertQueryIsUnplannable( + "SELECT\n" + + "c, COUNT(*)\n" + + "FROM (SELECT dim1 AS c, m1 FROM foo UNION ALL SELECT cnt AS c, m1 FROM numfoo)\n" + + "WHERE c = 'a' OR c = 'def'\n" + + "GROUP BY 1", + "SQL requires union between inputs that are not simple table scans and involve " + + "a filter or aliasing. Or column types of tables being unioned are not of same type." + ); + } + + @Test + public void testUnionAllSameTableTwice() + { + testQuery( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM foo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'def'\n" + + "GROUP BY 1, 2", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + new UnionDataSource( + ImmutableList.of( + new TableDataSource(CalciteTests.DATASOURCE1), + new TableDataSource(CalciteTests.DATASOURCE1) + ) + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) + .setDimensions( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("dim2", "d1") + ) + .setAggregatorSpecs( + aggregators( + new DoubleSumAggregatorFactory("a0", "m1"), + new CountAggregatorFactory("a1") + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"", "a", 2.0, 2L}, + new Object[]{"1", "a", 8.0, 2L} + ) + ); + } + + @Test + public void testUnionAllSameTableTwiceWithSameMapping() + { + testQuery( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM foo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'def'\n" + + "GROUP BY 1, 2", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + new UnionDataSource( + ImmutableList.of( + new TableDataSource(CalciteTests.DATASOURCE1), + new TableDataSource(CalciteTests.DATASOURCE1) + ) + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null)) + .setDimensions( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("dim2", "d1") + ) + .setAggregatorSpecs( + aggregators( + new DoubleSumAggregatorFactory("a0", "m1"), + new CountAggregatorFactory("a1") + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"", "a", 2.0, 2L}, + new Object[]{"1", "a", 8.0, 2L} + ) + ); + } + + @Test + public void testUnionAllSameTableTwiceWithDifferentMapping() + { + // Cannot plan this UNION ALL operation, because the column swap would require generating a subquery. + assertQueryIsUnplannable( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim2, dim1, m1 FROM foo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'def'\n" + + "GROUP BY 1, 2", + "SQL requires union between two tables and column names queried for each table are different Left: [dim1, dim2, m1], Right: [dim2, dim1, m1]." + ); + } + + @Test + public void testUnionAllThreeTablesColumnCountMismatch1() + { + try { + testQuery( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT * FROM numfoo UNION ALL SELECT * FROM foo UNION ALL SELECT * from foo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'def'\n" + + "GROUP BY 1, 2", + ImmutableList.of(), + ImmutableList.of() + ); + Assert.fail("query execution should fail"); + } + catch (DruidException e) { + MatcherAssert.assertThat(e, invalidSqlIs("Column count mismatch in UNION ALL (line [3], column [45])")); + } + } + + @Test + public void testUnionAllThreeTablesColumnCountMismatch2() + { + try { + testQuery( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT * FROM numfoo UNION ALL SELECT * FROM foo UNION ALL SELECT * from foo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'def'\n" + + "GROUP BY 1, 2", + ImmutableList.of(), + ImmutableList.of() + ); + Assert.fail("query execution should fail"); + } + catch (DruidException e) { + MatcherAssert.assertThat(e, invalidSqlIs("Column count mismatch in UNION ALL (line [3], column [45])")); + } + } + + @Test + public void testUnionAllThreeTablesColumnCountMismatch3() + { + try { + testQuery( + "SELECT\n" + + "dim1, dim2, SUM(m1), COUNT(*)\n" + + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM foo UNION ALL SELECT * from numfoo)\n" + + "WHERE dim2 = 'a' OR dim2 = 'def'\n" + + "GROUP BY 1, 2", + ImmutableList.of(), + ImmutableList.of() + ); + Assert.fail("query execution should fail"); + } + catch (DruidException e) { + MatcherAssert.assertThat(e, invalidSqlIs("Column count mismatch in UNION ALL (line [3], column [70])")); + } + } + +} diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java b/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java index 46fb40fddadb..b0bf0bd7b29d 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java @@ -90,6 +90,7 @@ public boolean featureAvailable(final EngineFeature feature, final PlannerContex case READ_EXTERNAL_DATA: case SCAN_ORDER_BY_NON_TIME: case ALLOW_BROADCAST_RIGHTY_JOIN: + case ALLOW_TOP_LEVEL_UNION_ALL: return true; default: throw SqlEngines.generateUnrecognizedFeatureException(IngestionTestSqlEngine.class.getSimpleName(), feature); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java index 48e7ee2423b3..12db32d4f019 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java @@ -19,6 +19,7 @@ package org.apache.druid.sql.calcite.planner; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableSet; import com.google.inject.Guice; import com.google.inject.Injector; @@ -42,7 +43,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; import org.apache.druid.sql.calcite.rule.ExtensionCalciteRuleProvider; -import org.apache.druid.sql.calcite.run.SqlEngine; +import org.apache.druid.sql.calcite.run.NativeSqlEngine; import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog; import org.apache.druid.sql.calcite.schema.DruidSchemaName; import org.apache.druid.sql.calcite.schema.NamedSchema; @@ -90,8 +91,6 @@ public class CalcitePlannerModuleTest extends CalciteTestBase @Mock private DruidSchemaCatalog rootSchema; - @Mock - private SqlEngine engine; private Set aggregators; private Set operatorConversions; @@ -175,10 +174,11 @@ public void testPlannerConfigIsInjected() @Test public void testExtensionCalciteRule() { + ObjectMapper mapper = new DefaultObjectMapper(); PlannerToolbox toolbox = new PlannerToolbox( injector.getInstance(DruidOperatorTable.class), macroTable, - new DefaultObjectMapper(), + mapper, injector.getInstance(PlannerConfig.class), rootSchema, joinableFactoryWrapper, @@ -189,11 +189,10 @@ public void testExtensionCalciteRule() AuthConfig.newBuilder().build() ); - PlannerContext context = PlannerContext.create( toolbox, "SELECT 1", - engine, + new NativeSqlEngine(queryLifecycleFactory, mapper), Collections.emptyMap(), null ); From 1fc8fb1b20080ae8c515dcbaa8f5d95f87f7a63e Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Mon, 9 Oct 2023 06:16:06 -0700 Subject: [PATCH 11/14] add a bunch of tests with array typed columns to CalciteArraysQueryTest (#15101) * add a bunch of tests with array typed columns to CalciteArraysQueryTest * fix a bug with unnest filter pushdown when filtering on unnested array columns --- .../druid/msq/test/CalciteMSQTestsHelper.java | 37 +- .../druid/segment/UnnestStorageAdapter.java | 31 +- .../segment/UnnestStorageAdapterTest.java | 190 +- .../sql/calcite/CalciteArraysQueryTest.java | 1831 ++++++++++++++++- .../sql/calcite/util/TestDataBuilder.java | 2 +- 5 files changed, 1988 insertions(+), 103 deletions(-) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java index c68b2331c7d9..5b49c649cc0c 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java @@ -27,6 +27,7 @@ import com.google.inject.TypeLiteral; import org.apache.druid.collections.ReferenceCountingResourceHolder; import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.data.input.ResourceInputSource; import org.apache.druid.data.input.impl.DimensionsSpec; import org.apache.druid.data.input.impl.LongDimensionSchema; import org.apache.druid.data.input.impl.StringDimensionSchema; @@ -47,6 +48,7 @@ import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.ForwardingQueryProcessingPool; +import org.apache.druid.query.NestedDataTestUtils; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; @@ -74,7 +76,9 @@ import org.apache.druid.server.SegmentManager; import org.apache.druid.server.coordination.DataSegmentAnnouncer; import org.apache.druid.server.coordination.NoopDataSegmentAnnouncer; +import org.apache.druid.sql.calcite.CalciteArraysQueryTest; import org.apache.druid.sql.calcite.util.CalciteTests; +import org.apache.druid.sql.calcite.util.TestDataBuilder; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.SegmentId; import org.easymock.EasyMock; @@ -83,7 +87,6 @@ import org.mockito.Mockito; import javax.annotation.Nullable; -import java.io.File; import java.io.IOException; import java.util.List; import java.util.Set; @@ -232,7 +235,7 @@ private static Supplier> getSupplierForSegment(SegmentId .build(); index = IndexBuilder .create() - .tmpDir(new File(temporaryFolder.newFolder(), "1")) + .tmpDir(temporaryFolder.newFolder()) .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) .schema(foo1Schema) .rows(ROWS1) @@ -259,7 +262,7 @@ private static Supplier> getSupplierForSegment(SegmentId .build(); index = IndexBuilder .create() - .tmpDir(new File(temporaryFolder.newFolder(), "2")) + .tmpDir(temporaryFolder.newFolder()) .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) .schema(indexSchemaDifferentDim3M1Types) .rows(ROWS2) @@ -269,7 +272,7 @@ private static Supplier> getSupplierForSegment(SegmentId case CalciteTests.BROADCAST_DATASOURCE: index = IndexBuilder .create() - .tmpDir(new File(temporaryFolder.newFolder(), "3")) + .tmpDir(temporaryFolder.newFolder()) .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) .schema(INDEX_SCHEMA_NUMERIC_DIMS) .rows(ROWS1_WITH_NUMERIC_DIMS) @@ -278,12 +281,36 @@ private static Supplier> getSupplierForSegment(SegmentId case DATASOURCE5: index = IndexBuilder .create() - .tmpDir(new File(temporaryFolder.newFolder(), "5")) + .tmpDir(temporaryFolder.newFolder()) .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) .schema(INDEX_SCHEMA_LOTS_O_COLUMNS) .rows(ROWS_LOTS_OF_COLUMNS) .buildMMappedIndex(); break; + case CalciteArraysQueryTest.DATA_SOURCE_ARRAYS: + index = IndexBuilder.create() + .tmpDir(temporaryFolder.newFolder()) + .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) + .schema( + new IncrementalIndexSchema.Builder() + .withTimestampSpec(NestedDataTestUtils.AUTO_SCHEMA.getTimestampSpec()) + .withDimensionsSpec(NestedDataTestUtils.AUTO_SCHEMA.getDimensionsSpec()) + .withMetrics( + new CountAggregatorFactory("cnt") + ) + .withRollup(false) + .build() + ) + .inputSource( + ResourceInputSource.of( + NestedDataTestUtils.class.getClassLoader(), + NestedDataTestUtils.ARRAY_TYPES_DATA_FILE + ) + ) + .inputFormat(TestDataBuilder.DEFAULT_JSON_INPUT_FORMAT) + .inputTmpDir(temporaryFolder.newFolder()) + .buildMMappedIndex(); + break; default: throw new ISE("Cannot query segment %s in test runner", segmentId); diff --git a/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java index 02f8c0064aa2..e9839a37818c 100644 --- a/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java +++ b/processing/src/main/java/org/apache/druid/segment/UnnestStorageAdapter.java @@ -322,19 +322,17 @@ to generate filters to be passed to base cursor (filtersPushedDownToBaseCursor) // outside filter contains unnested column // requires check for OR and And filters, disqualify rewrite for non-unnest filters if (queryFilter instanceof BooleanFilter) { - boolean isTopLevelAndFilter = queryFilter instanceof AndFilter; List preFilterList = recursiveRewriteOnUnnestFilters( (BooleanFilter) queryFilter, inputColumn, inputColumnCapabilites, - filterSplitter, - isTopLevelAndFilter + filterSplitter ); // If rewite on entire query filter is successful then add entire filter to preFilter else skip and only add to post filter. - if (filterSplitter.getPreFilterCount() == filterSplitter.getOriginalFilterCount()) { + if (!preFilterList.isEmpty()) { if (queryFilter instanceof AndFilter) { filterSplitter.addPreFilter(new AndFilter(preFilterList)); - } else if (queryFilter instanceof OrFilter) { + } else if (queryFilter instanceof OrFilter && filterSplitter.getPreFilterCount() == filterSplitter.getOriginalFilterCount()) { filterSplitter.addPreFilter(new OrFilter(preFilterList)); } } @@ -470,8 +468,7 @@ private List recursiveRewriteOnUnnestFilters( BooleanFilter queryFilter, final String inputColumn, final ColumnCapabilities inputColumnCapabilites, - final FilterSplitter filterSplitter, - final boolean isTopLevelAndFilter + final FilterSplitter filterSplitter ) { final List preFilterList = new ArrayList<>(); @@ -482,25 +479,26 @@ private List recursiveRewriteOnUnnestFilters( (BooleanFilter) filter, inputColumn, inputColumnCapabilites, - filterSplitter, - isTopLevelAndFilter + filterSplitter ); if (!andChildFilters.isEmpty()) { preFilterList.add(new AndFilter(andChildFilters)); } } else if (filter instanceof OrFilter) { - // in case of Or Fiters, we set isTopLevelAndFilter to false that prevents pushing down any child filters to base List orChildFilters = recursiveRewriteOnUnnestFilters( (BooleanFilter) filter, inputColumn, inputColumnCapabilites, - filterSplitter, - false + filterSplitter ); - preFilterList.add(new OrFilter(orChildFilters)); + if (orChildFilters.size() == ((OrFilter) filter).getFilters().size()) { + preFilterList.add(new OrFilter(orChildFilters)); + } } else if (filter instanceof NotFilter) { + // nothing to do here... continue; } else { + // can we rewrite final Filter newFilter = rewriteFilterOnUnnestColumnIfPossible( filter, inputColumn, @@ -511,13 +509,6 @@ private List recursiveRewriteOnUnnestFilters( preFilterList.add(newFilter); filterSplitter.addToPreFilterCount(1); } - /* - Push down the filters to base only if top level is And Filter - we can not push down if top level filter is OR or unnestColumn is derived expression like arrays - */ - if (isTopLevelAndFilter && getUnnestInputIfDirectAccess(unnestColumn) != null) { - filterSplitter.addPreFilter(newFilter != null ? newFilter : filter); - } filterSplitter.addToOriginalFilterCount(1); } } else { diff --git a/processing/src/test/java/org/apache/druid/segment/UnnestStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/UnnestStorageAdapterTest.java index 2139335b594a..286a636e89a3 100644 --- a/processing/src/test/java/org/apache/druid/segment/UnnestStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/UnnestStorageAdapterTest.java @@ -20,6 +20,8 @@ package org.apache.druid.segment; import com.google.common.collect.ImmutableList; +import org.apache.druid.data.input.InputSource; +import org.apache.druid.data.input.ResourceInputSource; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.granularity.Granularities; @@ -27,11 +29,14 @@ import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.query.NestedDataTestUtils; import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.filter.EqualityFilter; import org.apache.druid.query.filter.Filter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.filter.AndFilter; import org.apache.druid.segment.filter.OrFilter; @@ -40,8 +45,10 @@ import org.apache.druid.segment.generator.GeneratorSchemaInfo; import org.apache.druid.segment.generator.SegmentGenerator; import org.apache.druid.segment.incremental.IncrementalIndex; +import org.apache.druid.segment.incremental.IncrementalIndexSchema; import org.apache.druid.segment.incremental.IncrementalIndexStorageAdapter; import org.apache.druid.segment.join.PostJoinCursor; +import org.apache.druid.segment.transform.TransformSpec; import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.apache.druid.testing.InitializedNullHandlingTest; import org.apache.druid.timeline.DataSegment; @@ -53,9 +60,12 @@ import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; +import org.junit.ClassRule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import javax.annotation.Nullable; +import java.io.IOException; import java.util.Arrays; import java.util.List; @@ -66,18 +76,23 @@ public class UnnestStorageAdapterTest extends InitializedNullHandlingTest { + @ClassRule + public static TemporaryFolder tmp = new TemporaryFolder(); private static Closer CLOSER; private static IncrementalIndex INCREMENTAL_INDEX; private static IncrementalIndexStorageAdapter INCREMENTAL_INDEX_STORAGE_ADAPTER; + private static QueryableIndex QUERYABLE_INDEX; private static UnnestStorageAdapter UNNEST_STORAGE_ADAPTER; private static UnnestStorageAdapter UNNEST_STORAGE_ADAPTER1; + private static UnnestStorageAdapter UNNEST_ARRAYS; private static List ADAPTERS; - private static String COLUMNNAME = "multi-string1"; + private static String INPUT_COLUMN_NAME = "multi-string1"; private static String OUTPUT_COLUMN_NAME = "unnested-multi-string1"; private static String OUTPUT_COLUMN_NAME1 = "unnested-multi-string1-again"; + @BeforeClass - public static void setup() + public static void setup() throws IOException { CLOSER = Closer.create(); final GeneratorSchemaInfo schemaInfo = GeneratorBasicSchemas.SCHEMA_MAP.get("expression-testbench"); @@ -98,13 +113,40 @@ public static void setup() INCREMENTAL_INDEX_STORAGE_ADAPTER = new IncrementalIndexStorageAdapter(INCREMENTAL_INDEX); UNNEST_STORAGE_ADAPTER = new UnnestStorageAdapter( INCREMENTAL_INDEX_STORAGE_ADAPTER, - new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + COLUMNNAME + "\"", null, ExprMacroTable.nil()), + new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + INPUT_COLUMN_NAME + "\"", null, ExprMacroTable.nil()), null ); UNNEST_STORAGE_ADAPTER1 = new UnnestStorageAdapter( UNNEST_STORAGE_ADAPTER, - new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME1, "\"" + COLUMNNAME + "\"", null, ExprMacroTable.nil()), + new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME1, "\"" + INPUT_COLUMN_NAME + "\"", null, ExprMacroTable.nil()), + null + ); + + final InputSource inputSource = ResourceInputSource.of( + UnnestStorageAdapterTest.class.getClassLoader(), + NestedDataTestUtils.ALL_TYPES_TEST_DATA_FILE + ); + IndexBuilder bob = IndexBuilder.create() + .tmpDir(tmp.newFolder()) + .schema( + IncrementalIndexSchema.builder() + .withTimestampSpec(NestedDataTestUtils.TIMESTAMP_SPEC) + .withDimensionsSpec(NestedDataTestUtils.AUTO_DISCOVERY) + .withQueryGranularity(Granularities.DAY) + .withRollup(false) + .withMinTimestamp(0) + .build() + ) + .indexSpec(IndexSpec.DEFAULT) + .inputSource(inputSource) + .inputFormat(NestedDataTestUtils.DEFAULT_JSON_INPUT_FORMAT) + .transform(TransformSpec.NONE) + .inputTmpDir(tmp.newFolder()); + QUERYABLE_INDEX = CLOSER.register(bob.buildMMappedIndex()); + UNNEST_ARRAYS = new UnnestStorageAdapter( + new QueryableIndexStorageAdapter(QUERYABLE_INDEX), + new ExpressionVirtualColumn("u", "\"arrayLongNulls\"", ColumnType.LONG, ExprMacroTable.nil()), null ); @@ -269,7 +311,7 @@ public void test_pushdown_or_filters_unnested_and_original_dimension_with_unnest { final UnnestStorageAdapter unnestStorageAdapter = new UnnestStorageAdapter( new TestStorageAdapter(INCREMENTAL_INDEX), - new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + COLUMNNAME + "\"", null, ExprMacroTable.nil()), + new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + INPUT_COLUMN_NAME + "\"", null, ExprMacroTable.nil()), null ); @@ -313,7 +355,7 @@ public void test_nested_filters_unnested_and_original_dimension_with_unnest_adap { final UnnestStorageAdapter unnestStorageAdapter = new UnnestStorageAdapter( new TestStorageAdapter(INCREMENTAL_INDEX), - new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + COLUMNNAME + "\"", null, ExprMacroTable.nil()), + new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + INPUT_COLUMN_NAME + "\"", null, ExprMacroTable.nil()), null ); @@ -365,7 +407,7 @@ public void test_nested_filters_unnested_and_topLevel1And3filtersInOR() selector(OUTPUT_COLUMN_NAME, "3"), or(ImmutableList.of( selector("newcol", "2"), - selector(COLUMNNAME, "2"), + selector(INPUT_COLUMN_NAME, "2"), selector(OUTPUT_COLUMN_NAME, "1") )) )); @@ -383,10 +425,10 @@ public void test_nested_multiLevel_filters_unnested() or(ImmutableList.of( or(ImmutableList.of( selector("newcol", "2"), - selector(COLUMNNAME, "2"), + selector(INPUT_COLUMN_NAME, "2"), and(ImmutableList.of( selector("newcol", "3"), - selector(COLUMNNAME, "7") + selector(INPUT_COLUMN_NAME, "7") )) )), selector(OUTPUT_COLUMN_NAME, "1") @@ -406,11 +448,11 @@ public void test_nested_multiLevel_filters_unnested5Level() or(ImmutableList.of( or(ImmutableList.of( selector("newcol", "2"), - selector(COLUMNNAME, "2"), + selector(INPUT_COLUMN_NAME, "2"), and(ImmutableList.of( selector("newcol", "3"), and(ImmutableList.of( - selector(COLUMNNAME, "7"), + selector(INPUT_COLUMN_NAME, "7"), selector("newcol_1", "10") )) )) @@ -431,7 +473,7 @@ public void test_nested_filters_unnested_and_topLevelORAnd3filtersInOR() selector(OUTPUT_COLUMN_NAME, "3"), and(ImmutableList.of( selector("newcol", "2"), - selector(COLUMNNAME, "2"), + selector(INPUT_COLUMN_NAME, "2"), selector(OUTPUT_COLUMN_NAME, "1") )) )); @@ -449,11 +491,11 @@ public void test_nested_filters_unnested_and_topLevelAND3filtersInORWithNestedOr selector(OUTPUT_COLUMN_NAME, "3"), or(ImmutableList.of( selector("newcol", "2"), - selector(COLUMNNAME, "2") + selector(INPUT_COLUMN_NAME, "2") )), or(ImmutableList.of( selector("newcol", "4"), - selector(COLUMNNAME, "8"), + selector(INPUT_COLUMN_NAME, "8"), selector(OUTPUT_COLUMN_NAME, "6") )) )); @@ -469,7 +511,7 @@ public void test_nested_filters_unnested_and_topLevelAND2sdf() { final Filter testQueryFilter = and(ImmutableList.of( not(selector(OUTPUT_COLUMN_NAME, "3")), - selector(COLUMNNAME, "2") + selector(INPUT_COLUMN_NAME, "2") )); testComputeBaseAndPostUnnestFilters( testQueryFilter, @@ -483,7 +525,7 @@ public void test_nested_filters_unnested_and_topLevelOR2sdf() { final Filter testQueryFilter = or(ImmutableList.of( not(selector(OUTPUT_COLUMN_NAME, "3")), - selector(COLUMNNAME, "2") + selector(INPUT_COLUMN_NAME, "2") )); testComputeBaseAndPostUnnestFilters( testQueryFilter, @@ -500,10 +542,10 @@ public void test_not_pushdown_not_filter() or(ImmutableList.of( or(ImmutableList.of( selector("newcol", "2"), - selector(COLUMNNAME, "2"), + selector(INPUT_COLUMN_NAME, "2"), and(ImmutableList.of( selector("newcol", "3"), - selector(COLUMNNAME, "7") + selector(INPUT_COLUMN_NAME, "7") )) )), selector(OUTPUT_COLUMN_NAME, "1") @@ -516,12 +558,97 @@ public void test_not_pushdown_not_filter() ); } + @Test + public void testPartialArrayPushdown() + { + final Filter testQueryFilter = and( + ImmutableList.of( + new EqualityFilter("u", ColumnType.LONG, 1L, null), + new EqualityFilter("str", ColumnType.STRING, "a", null), + new EqualityFilter("long", ColumnType.LONG, 1L, null) + ) + ); + testComputeBaseAndPostUnnestFilters( + UNNEST_ARRAYS, + testQueryFilter, + "(str = a && long = 1 (LONG))", + "(u = 1 (LONG) && str = a && long = 1 (LONG))" + ); + } + + @Test + public void testPartialArrayPushdownNested() + { + final Filter testQueryFilter = and( + ImmutableList.of( + and( + ImmutableList.of( + new EqualityFilter("u", ColumnType.LONG, 1L, null), + new EqualityFilter("str", ColumnType.STRING, "a", null) + ) + ), + new EqualityFilter("long", ColumnType.LONG, 1L, null) + ) + ); + // this seems wrong since we should be able to push down str = a and long = 1 + testComputeBaseAndPostUnnestFilters( + UNNEST_ARRAYS, + testQueryFilter, + "(str = a && long = 1 (LONG))", + "(u = 1 (LONG) && str = a && long = 1 (LONG))" + ); + } + + @Test + public void testPartialArrayPushdown2() + { + final Filter testQueryFilter = and( + ImmutableList.of( + or( + ImmutableList.of( + new EqualityFilter("u", ColumnType.LONG, 1L, null), + new EqualityFilter("str", ColumnType.STRING, "a", null) + ) + ), + new EqualityFilter("long", ColumnType.LONG, 1L, null) + ) + ); + testComputeBaseAndPostUnnestFilters( + UNNEST_ARRAYS, + testQueryFilter, + "long = 1 (LONG)", + "((u = 1 (LONG) || str = a) && long = 1 (LONG))" + ); + } + + @Test + public void testArrayCannotPushdown2() + { + final Filter testQueryFilter = or( + ImmutableList.of( + or( + ImmutableList.of( + new EqualityFilter("u", ColumnType.LONG, 1L, null), + new EqualityFilter("str", ColumnType.STRING, "a", null) + ) + ), + new EqualityFilter("long", ColumnType.LONG, 1L, null) + ) + ); + testComputeBaseAndPostUnnestFilters( + UNNEST_ARRAYS, + testQueryFilter, + "", + "(u = 1 (LONG) || str = a || long = 1 (LONG))" + ); + } + @Test public void test_pushdown_filters_unnested_dimension_with_unnest_adapters() { final UnnestStorageAdapter unnestStorageAdapter = new UnnestStorageAdapter( new TestStorageAdapter(INCREMENTAL_INDEX), - new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + COLUMNNAME + "\"", null, ExprMacroTable.nil()), + new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + INPUT_COLUMN_NAME + "\"", null, ExprMacroTable.nil()), new SelectorDimFilter(OUTPUT_COLUMN_NAME, "1", null) ); @@ -567,7 +694,7 @@ public void test_pushdown_filters_unnested_dimension_outside() { final UnnestStorageAdapter unnestStorageAdapter = new UnnestStorageAdapter( new TestStorageAdapter(INCREMENTAL_INDEX), - new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + COLUMNNAME + "\"", null, ExprMacroTable.nil()), + new ExpressionVirtualColumn(OUTPUT_COLUMN_NAME, "\"" + INPUT_COLUMN_NAME + "\"", null, ExprMacroTable.nil()), null ); @@ -613,14 +740,29 @@ public void testComputeBaseAndPostUnnestFilters( String expectedPostUnnest ) { - final String inputColumn = UNNEST_STORAGE_ADAPTER.getUnnestInputIfDirectAccess(UNNEST_STORAGE_ADAPTER.getUnnestColumn()); - final VirtualColumn vc = UNNEST_STORAGE_ADAPTER.getUnnestColumn(); - Pair filterPair = UNNEST_STORAGE_ADAPTER.computeBaseAndPostUnnestFilters( + testComputeBaseAndPostUnnestFilters( + UNNEST_STORAGE_ADAPTER, + testQueryFilter, + expectedBasePushDown, + expectedPostUnnest + ); + } + + public void testComputeBaseAndPostUnnestFilters( + UnnestStorageAdapter adapter, + Filter testQueryFilter, + String expectedBasePushDown, + String expectedPostUnnest + ) + { + final String inputColumn = adapter.getUnnestInputIfDirectAccess(adapter.getUnnestColumn()); + final VirtualColumn vc = adapter.getUnnestColumn(); + Pair filterPair = adapter.computeBaseAndPostUnnestFilters( testQueryFilter, null, VirtualColumns.EMPTY, inputColumn, - vc.capabilities(UNNEST_STORAGE_ADAPTER, inputColumn) + vc.capabilities(adapter, inputColumn) ); Filter actualPushDownFilter = filterPair.lhs; Filter actualPostUnnestFilter = filterPair.rhs; diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java index e6a669b9c28b..e70224274b13 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java @@ -22,19 +22,27 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.inject.Injector; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.data.input.ResourceInputSource; +import org.apache.druid.guice.DruidInjectorBuilder; +import org.apache.druid.guice.NestedDataModule; import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.query.DataSource; import org.apache.druid.query.Druids; import org.apache.druid.query.FilteredDataSource; +import org.apache.druid.query.FrameBasedInlineDataSource; import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.LookupDataSource; +import org.apache.druid.query.NestedDataTestUtils; import org.apache.druid.query.Query; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryDataSource; +import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.UnnestDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; @@ -54,18 +62,35 @@ import org.apache.druid.query.groupby.orderby.DefaultLimitSpec; import org.apache.druid.query.groupby.orderby.NoopLimitSpec; import org.apache.druid.query.groupby.orderby.OrderByColumnSpec; +import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; import org.apache.druid.query.ordering.StringComparators; import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.topn.DimensionTopNMetricSpec; import org.apache.druid.query.topn.TopNQueryBuilder; +import org.apache.druid.segment.FrameBasedInlineSegmentWrangler; +import org.apache.druid.segment.IndexBuilder; +import org.apache.druid.segment.InlineSegmentWrangler; +import org.apache.druid.segment.LookupSegmentWrangler; +import org.apache.druid.segment.MapSegmentWrangler; +import org.apache.druid.segment.QueryableIndex; +import org.apache.druid.segment.SegmentWrangler; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.incremental.IncrementalIndexSchema; import org.apache.druid.segment.join.JoinType; +import org.apache.druid.segment.join.JoinableFactoryWrapper; +import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; +import org.apache.druid.server.QueryStackTests; import org.apache.druid.sql.calcite.filtration.Filtration; import org.apache.druid.sql.calcite.util.CalciteTests; +import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker; +import org.apache.druid.sql.calcite.util.TestDataBuilder; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.partition.LinearShardSpec; import org.junit.Assert; import org.junit.Test; +import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -82,6 +107,154 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest .put(QueryContexts.CTX_SQL_STRINGIFY_ARRAYS, false) .build(); + + public static final String DATA_SOURCE_ARRAYS = "arrays"; + + public static void assertResultsDeepEquals(String sql, List expected, List results) + { + for (int row = 0; row < results.size(); row++) { + for (int col = 0; col < results.get(row).length; col++) { + final String rowString = StringUtils.format("result #%d: %s", row + 1, sql); + assertDeepEquals(rowString + " - column: " + col + ":", expected.get(row)[col], results.get(row)[col]); + } + } + } + + public static void assertDeepEquals(String path, Object expected, Object actual) + { + if (expected instanceof List && actual instanceof List) { + List expectedList = (List) expected; + List actualList = (List) actual; + Assert.assertEquals(path + " arrays length mismatch", expectedList.size(), actualList.size()); + for (int i = 0; i < expectedList.size(); i++) { + assertDeepEquals(path + "[" + i + "]", expectedList.get(i), actualList.get(i)); + } + } else { + Assert.assertEquals(path, expected, actual); + } + } + + @Override + public void configureGuice(DruidInjectorBuilder builder) + { + super.configureGuice(builder); + builder.addModule(new NestedDataModule()); + } + + @SuppressWarnings("resource") + @Override + public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker( + final QueryRunnerFactoryConglomerate conglomerate, + final JoinableFactoryWrapper joinableFactory, + final Injector injector + ) throws IOException + { + NestedDataModule.registerHandlersAndSerde(); + + final QueryableIndex foo = IndexBuilder + .create() + .tmpDir(temporaryFolder.newFolder()) + .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) + .schema(TestDataBuilder.INDEX_SCHEMA) + .rows(TestDataBuilder.ROWS1) + .buildMMappedIndex(); + + final QueryableIndex numfoo = IndexBuilder + .create() + .tmpDir(temporaryFolder.newFolder()) + .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) + .schema(TestDataBuilder.INDEX_SCHEMA_NUMERIC_DIMS) + .rows(TestDataBuilder.ROWS1_WITH_NUMERIC_DIMS) + .buildMMappedIndex(); + + final QueryableIndex indexLotsOfColumns = IndexBuilder + .create() + .tmpDir(temporaryFolder.newFolder()) + .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) + .schema(TestDataBuilder.INDEX_SCHEMA_LOTS_O_COLUMNS) + .rows(TestDataBuilder.ROWS_LOTS_OF_COLUMNS) + .buildMMappedIndex(); + + final QueryableIndex indexArrays = + IndexBuilder.create() + .tmpDir(temporaryFolder.newFolder()) + .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance()) + .schema( + new IncrementalIndexSchema.Builder() + .withTimestampSpec(NestedDataTestUtils.AUTO_SCHEMA.getTimestampSpec()) + .withDimensionsSpec(NestedDataTestUtils.AUTO_SCHEMA.getDimensionsSpec()) + .withMetrics( + new CountAggregatorFactory("cnt") + ) + .withRollup(false) + .build() + ) + .inputSource( + ResourceInputSource.of( + NestedDataTestUtils.class.getClassLoader(), + NestedDataTestUtils.ARRAY_TYPES_DATA_FILE + ) + ) + .inputFormat(TestDataBuilder.DEFAULT_JSON_INPUT_FORMAT) + .inputTmpDir(temporaryFolder.newFolder()) + .buildMMappedIndex(); + + SpecificSegmentsQuerySegmentWalker walker = new SpecificSegmentsQuerySegmentWalker( + conglomerate, + new MapSegmentWrangler( + ImmutableMap., SegmentWrangler>builder() + .put(InlineDataSource.class, new InlineSegmentWrangler()) + .put(FrameBasedInlineDataSource.class, new FrameBasedInlineSegmentWrangler()) + .put( + LookupDataSource.class, + new LookupSegmentWrangler(injector.getInstance(LookupExtractorFactoryContainerProvider.class)) + ) + .build() + ), + joinableFactory, + QueryStackTests.DEFAULT_NOOP_SCHEDULER + ); + walker.add( + DataSegment.builder() + .dataSource(CalciteTests.DATASOURCE1) + .interval(foo.getDataInterval()) + .version("1") + .shardSpec(new LinearShardSpec(0)) + .size(0) + .build(), + foo + ).add( + DataSegment.builder() + .dataSource(CalciteTests.DATASOURCE3) + .interval(numfoo.getDataInterval()) + .version("1") + .shardSpec(new LinearShardSpec(0)) + .size(0) + .build(), + numfoo + ).add( + DataSegment.builder() + .dataSource(CalciteTests.DATASOURCE5) + .interval(indexLotsOfColumns.getDataInterval()) + .version("1") + .shardSpec(new LinearShardSpec(0)) + .size(0) + .build(), + indexLotsOfColumns + ).add( + DataSegment.builder() + .dataSource(DATA_SOURCE_ARRAYS) + .version("1") + .interval(indexArrays.getDataInterval()) + .shardSpec(new LinearShardSpec(1)) + .size(0) + .build(), + indexArrays + ); + + return walker; + } + // test some query stuffs, sort of limited since no native array column types so either need to use constructor or // array aggregator @Test @@ -135,6 +308,36 @@ public void testGroupByArrayFromCase() ); } + @Test + public void testGroupByArrayColumnFromCase() + { + cannotVectorize(); + testQuery( + "SELECT CASE WHEN arrayStringNulls = ARRAY['a', 'b'] THEN arrayLongNulls END as arr, count(1) from arrays GROUP BY 1", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(DATA_SOURCE_ARRAYS) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "case_searched((\"arrayStringNulls\" == array('a','b')),\"arrayLongNulls\",null)", + ColumnType.LONG_ARRAY + )) + .setDimensions(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG_ARRAY)) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{null, 11L}, + new Object[]{Arrays.asList(1L, null, 3L), 1L}, + new Object[]{Arrays.asList(2L, 3L), 2L} + ) + ); + } + @Test public void testSelectNonConstantArrayExpressionFromTable() { @@ -206,9 +409,6 @@ public void testSelectNonConstantArrayExpressionFromTableForMultival() @Test public void testSomeArrayFunctionsWithScanQuery() { - // Yes these outputs are strange sometimes, arrays are in a partial state of existence so end up a bit - // stringy for now this is because virtual column selectors are coercing values back to stringish so that - // multi-valued string dimensions can be grouped on. List expectedResults; if (useDefault) { expectedResults = ImmutableList.of( @@ -380,6 +580,136 @@ public void testSomeArrayFunctionsWithScanQuery() ); } + @Test + public void testSomeArrayFunctionsWithScanQueryArrayColumns() + { + List expectedResults; + if (useDefault) { + expectedResults = ImmutableList.of( + new Object[]{null, "[]", null, null, null, "[1]", "[2]", null, null, null, "[1,2,3]", null, "", null, null, "", null, null}, + new Object[]{"[\"a\",\"b\"]", "[2,3]", "[null]", "[\"a\",\"b\",\"foo\"]", "[\"foo\",\"a\",\"b\"]", "[2,3,1]", "[2,2,3]", "[null,1.1]", "[2.2,null]", null, null, null, "a", 2L, 0.0D, "a", 2L, 0.0D}, + new Object[]{"[\"b\",\"b\"]", "[1]", null, "[\"b\",\"b\",\"foo\"]", "[\"foo\",\"b\",\"b\"]", "[1,1]", "[2,1]", null, null, "[\"d\",\"e\",\"b\",\"b\"]", "[1,4,1]", null, "b", 1L, null, "b", 1L, null}, + new Object[]{null, "[null,2,9]", "[999.0,5.5,null]", null, null, "[null,2,9,1]", "[2,null,2,9]", "[999.0,5.5,null,1.1]", "[2.2,999.0,5.5,null]", null, null, null, "", 0L, 999.0D, "", 0L, 999.0D}, + new Object[]{"[\"a\",\"b\"]", "[1,null,3]", "[1.1,2.2,null]", "[\"a\",\"b\",\"foo\"]", "[\"foo\",\"a\",\"b\"]", "[1,null,3,1]", "[2,1,null,3]", "[1.1,2.2,null,1.1]", "[2.2,1.1,2.2,null]", "[\"a\",\"b\",\"a\",\"b\"]", "[1,2,3,1,null,3]", "[1.1,2.2,3.3,1.1,2.2,null]", "a", 1L, 1.1D, "a", 1L, 1.1D}, + new Object[]{"[\"d\",null,\"b\"]", "[1,2,3]", "[null,2.2,null]", "[\"d\",null,\"b\",\"foo\"]", "[\"foo\",\"d\",null,\"b\"]", "[1,2,3,1]", "[2,1,2,3]", "[null,2.2,null,1.1]", "[2.2,null,2.2,null]", "[\"b\",\"c\",\"d\",null,\"b\"]", "[1,2,3,4,1,2,3]", "[1.1,3.3,null,2.2,null]", "d", 1L, 0.0D, "d", 1L, 0.0D}, + new Object[]{"[null,\"b\"]", null, "[999.0,null,5.5]", "[null,\"b\",\"foo\"]", "[\"foo\",null,\"b\"]", null, null, "[999.0,null,5.5,1.1]", "[2.2,999.0,null,5.5]", "[\"a\",\"b\",\"c\",null,\"b\"]", null, "[3.3,4.4,5.5,999.0,null,5.5]", "", null, 999.0D, "", null, 999.0D}, + new Object[]{null, null, "[]", null, null, null, null, "[1.1]", "[2.2]", null, null, "[1.1,2.2,3.3]", "", null, null, "", null, null}, + new Object[]{"[\"a\",\"b\"]", "[2,3]", "[null,1.1]", "[\"a\",\"b\",\"foo\"]", "[\"foo\",\"a\",\"b\"]", "[2,3,1]", "[2,2,3]", "[null,1.1,1.1]", "[2.2,null,1.1]", null, null, null, "a", 2L, 0.0D, "a", 2L, 0.0D}, + new Object[]{"[\"b\",\"b\"]", "[null]", null, "[\"b\",\"b\",\"foo\"]", "[\"foo\",\"b\",\"b\"]", "[null,1]", "[2,null]", null, null, "[\"d\",\"e\",\"b\",\"b\"]", "[1,4,null]", null, "b", 0L, null, "b", 0L, null}, + new Object[]{"[null]", "[null,2,9]", "[999.0,5.5,null]", "[null,\"foo\"]", "[\"foo\",null]", "[null,2,9,1]", "[2,null,2,9]", "[999.0,5.5,null,1.1]", "[2.2,999.0,5.5,null]", "[\"a\",\"b\",null]", null, null, "", 0L, 999.0D, "", 0L, 999.0D}, + new Object[]{"[]", "[1,null,3]", "[1.1,2.2,null]", "[\"foo\"]", "[\"foo\"]", "[1,null,3,1]", "[2,1,null,3]", "[1.1,2.2,null,1.1]", "[2.2,1.1,2.2,null]", "[\"a\",\"b\"]", "[1,2,3,1,null,3]", "[1.1,2.2,3.3,1.1,2.2,null]", "", 1L, 1.1D, "", 1L, 1.1D}, + new Object[]{"[\"d\",null,\"b\"]", "[1,2,3]", "[null,2.2,null]", "[\"d\",null,\"b\",\"foo\"]", "[\"foo\",\"d\",null,\"b\"]", "[1,2,3,1]", "[2,1,2,3]", "[null,2.2,null,1.1]", "[2.2,null,2.2,null]", "[\"b\",\"c\",\"d\",null,\"b\"]", "[1,2,3,4,1,2,3]", "[1.1,3.3,null,2.2,null]", "d", 1L, 0.0D, "d", 1L, 0.0D}, + new Object[]{"[null,\"b\"]", null, "[999.0,null,5.5]", "[null,\"b\",\"foo\"]", "[\"foo\",null,\"b\"]", null, null, "[999.0,null,5.5,1.1]", "[2.2,999.0,null,5.5]", "[\"a\",\"b\",\"c\",null,\"b\"]", null, "[3.3,4.4,5.5,999.0,null,5.5]", "", null, 999.0D, "", null, 999.0D} + ); + } else { + expectedResults = ImmutableList.of( + new Object[]{null, "[]", null, null, null, "[1]", "[2]", null, null, null, "[1,2,3]", null, null, null, null, null, null, null}, + new Object[]{"[\"a\",\"b\"]", "[2,3]", "[null]", "[\"a\",\"b\",\"foo\"]", "[\"foo\",\"a\",\"b\"]", "[2,3,1]", "[2,2,3]", "[null,1.1]", "[2.2,null]", null, null, null, "a", 2L, null, "a", 2L, null}, + new Object[]{"[\"b\",\"b\"]", "[1]", null, "[\"b\",\"b\",\"foo\"]", "[\"foo\",\"b\",\"b\"]", "[1,1]", "[2,1]", null, null, "[\"d\",\"e\",\"b\",\"b\"]", "[1,4,1]", null, "b", 1L, null, "b", 1L, null}, + new Object[]{null, "[null,2,9]", "[999.0,5.5,null]", null, null, "[null,2,9,1]", "[2,null,2,9]", "[999.0,5.5,null,1.1]", "[2.2,999.0,5.5,null]", null, null, null, null, null, 999.0D, null, null, 999.0D}, + new Object[]{"[\"a\",\"b\"]", "[1,null,3]", "[1.1,2.2,null]", "[\"a\",\"b\",\"foo\"]", "[\"foo\",\"a\",\"b\"]", "[1,null,3,1]", "[2,1,null,3]", "[1.1,2.2,null,1.1]", "[2.2,1.1,2.2,null]", "[\"a\",\"b\",\"a\",\"b\"]", "[1,2,3,1,null,3]", "[1.1,2.2,3.3,1.1,2.2,null]", "a", 1L, 1.1D, "a", 1L, 1.1D}, + new Object[]{"[\"d\",null,\"b\"]", "[1,2,3]", "[null,2.2,null]", "[\"d\",null,\"b\",\"foo\"]", "[\"foo\",\"d\",null,\"b\"]", "[1,2,3,1]", "[2,1,2,3]", "[null,2.2,null,1.1]", "[2.2,null,2.2,null]", "[\"b\",\"c\",\"d\",null,\"b\"]", "[1,2,3,4,1,2,3]", "[1.1,3.3,null,2.2,null]", "d", 1L, null, "d", 1L, null}, + new Object[]{"[null,\"b\"]", null, "[999.0,null,5.5]", "[null,\"b\",\"foo\"]", "[\"foo\",null,\"b\"]", null, null, "[999.0,null,5.5,1.1]", "[2.2,999.0,null,5.5]", "[\"a\",\"b\",\"c\",null,\"b\"]", null, "[3.3,4.4,5.5,999.0,null,5.5]", null, null, 999.0D, null, null, 999.0D}, + new Object[]{null, null, "[]", null, null, null, null, "[1.1]", "[2.2]", null, null, "[1.1,2.2,3.3]", null, null, null, null, null, null}, + new Object[]{"[\"a\",\"b\"]", "[2,3]", "[null,1.1]", "[\"a\",\"b\",\"foo\"]", "[\"foo\",\"a\",\"b\"]", "[2,3,1]", "[2,2,3]", "[null,1.1,1.1]", "[2.2,null,1.1]", null, null, null, "a", 2L, null, "a", 2L, null}, + new Object[]{"[\"b\",\"b\"]", "[null]", null, "[\"b\",\"b\",\"foo\"]", "[\"foo\",\"b\",\"b\"]", "[null,1]", "[2,null]", null, null, "[\"d\",\"e\",\"b\",\"b\"]", "[1,4,null]", null, "b", null, null, "b", null, null}, + new Object[]{"[null]", "[null,2,9]", "[999.0,5.5,null]", "[null,\"foo\"]", "[\"foo\",null]", "[null,2,9,1]", "[2,null,2,9]", "[999.0,5.5,null,1.1]", "[2.2,999.0,5.5,null]", "[\"a\",\"b\",null]", null, null, null, null, 999.0D, null, null, 999.0D}, + new Object[]{"[]", "[1,null,3]", "[1.1,2.2,null]", "[\"foo\"]", "[\"foo\"]", "[1,null,3,1]", "[2,1,null,3]", "[1.1,2.2,null,1.1]", "[2.2,1.1,2.2,null]", "[\"a\",\"b\"]", "[1,2,3,1,null,3]", "[1.1,2.2,3.3,1.1,2.2,null]", null, 1L, 1.1D, null, 1L, 1.1D}, + new Object[]{"[\"d\",null,\"b\"]", "[1,2,3]", "[null,2.2,null]", "[\"d\",null,\"b\",\"foo\"]", "[\"foo\",\"d\",null,\"b\"]", "[1,2,3,1]", "[2,1,2,3]", "[null,2.2,null,1.1]", "[2.2,null,2.2,null]", "[\"b\",\"c\",\"d\",null,\"b\"]", "[1,2,3,4,1,2,3]", "[1.1,3.3,null,2.2,null]", "d", 1L, null, "d", 1L, null}, + new Object[]{"[null,\"b\"]", null, "[999.0,null,5.5]", "[null,\"b\",\"foo\"]", "[\"foo\",null,\"b\"]", null, null, "[999.0,null,5.5,1.1]", "[2.2,999.0,null,5.5]", "[\"a\",\"b\",\"c\",null,\"b\"]", null, "[3.3,4.4,5.5,999.0,null,5.5]", null, null, 999.0D, null, null, 999.0D} + ); + } + testQuery( + "SELECT" + + " arrayStringNulls," + + " arrayLongNulls," + + " arrayDoubleNulls," + + " ARRAY_APPEND(arrayStringNulls, 'foo')," + + " ARRAY_PREPEND('foo', arrayStringNulls)," + + " ARRAY_APPEND(arrayLongNulls, 1)," + + " ARRAY_PREPEND(2, arrayLongNulls)," + + " ARRAY_APPEND(arrayDoubleNulls, 1.1)," + + " ARRAY_PREPEND(2.2, arrayDoubleNulls)," + + " ARRAY_CONCAT(arrayString,arrayStringNulls)," + + " ARRAY_CONCAT(arrayLong,arrayLongNulls)," + + " ARRAY_CONCAT(arrayDouble,arrayDoubleNulls)," + + " ARRAY_OFFSET(arrayStringNulls,0)," + + " ARRAY_OFFSET(arrayLongNulls,0)," + + " ARRAY_OFFSET(arrayDoubleNulls,0)," + + " ARRAY_ORDINAL(arrayStringNulls,1)," + + " ARRAY_ORDINAL(arrayLongNulls,1)," + + " ARRAY_ORDINAL(arrayDoubleNulls,1)" + + " FROM druid.arrays", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + // these report as strings even though they are not, someday this will not be so + expressionVirtualColumn("v0", "array_append(\"arrayStringNulls\",'foo')", ColumnType.STRING_ARRAY), + expressionVirtualColumn("v1", "array_prepend('foo',\"arrayStringNulls\")", ColumnType.STRING_ARRAY), + expressionVirtualColumn("v10", "array_offset(\"arrayLongNulls\",0)", ColumnType.LONG), + expressionVirtualColumn("v11", "array_offset(\"arrayDoubleNulls\",0)", ColumnType.DOUBLE), + expressionVirtualColumn("v12", "array_ordinal(\"arrayStringNulls\",1)", ColumnType.STRING), + expressionVirtualColumn("v13", "array_ordinal(\"arrayLongNulls\",1)", ColumnType.LONG), + expressionVirtualColumn("v14", "array_ordinal(\"arrayDoubleNulls\",1)", ColumnType.DOUBLE), + expressionVirtualColumn("v2", "array_append(\"arrayLongNulls\",1)", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v3", "array_prepend(2,\"arrayLongNulls\")", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v4", "array_append(\"arrayDoubleNulls\",1.1)", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v5", "array_prepend(2.2,\"arrayDoubleNulls\")", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v6", "array_concat(\"arrayString\",\"arrayStringNulls\")", ColumnType.STRING_ARRAY), + expressionVirtualColumn("v7", "array_concat(\"arrayLong\",\"arrayLongNulls\")", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v8", "array_concat(\"arrayDouble\",\"arrayDoubleNulls\")", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v9", "array_offset(\"arrayStringNulls\",0)", ColumnType.STRING) + ) + .columns( + "arrayDoubleNulls", + "arrayLongNulls", + "arrayStringNulls", + "v0", + "v1", + "v10", + "v11", + "v12", + "v13", + "v14", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9" + ) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + expectedResults, + RowSignature.builder() + .add("arrayStringNulls", ColumnType.STRING_ARRAY) + .add("arrayLongNulls", ColumnType.LONG_ARRAY) + .add("arrayDoubleNulls", ColumnType.DOUBLE_ARRAY) + .add("EXPR$3", ColumnType.STRING_ARRAY) + .add("EXPR$4", ColumnType.STRING_ARRAY) + .add("EXPR$5", ColumnType.LONG_ARRAY) + .add("EXPR$6", ColumnType.LONG_ARRAY) + .add("EXPR$7", ColumnType.DOUBLE_ARRAY) + .add("EXPR$8", ColumnType.DOUBLE_ARRAY) + .add("EXPR$9", ColumnType.STRING_ARRAY) + .add("EXPR$10", ColumnType.LONG_ARRAY) + .add("EXPR$11", ColumnType.DOUBLE_ARRAY) + .add("EXPR$12", ColumnType.STRING) + .add("EXPR$13", ColumnType.LONG) + .add("EXPR$14", ColumnType.DOUBLE) + .add("EXPR$15", ColumnType.STRING) + .add("EXPR$16", ColumnType.LONG) + .add("EXPR$17", ColumnType.DOUBLE) + .build() + ); + } + @Test public void testSomeArrayFunctionsWithScanQueryNoStringify() { @@ -521,6 +851,84 @@ public void testArrayOverlapFilter() ); } + @Test + public void testArrayOverlapFilterStringArrayColumn() + { + testQuery( + "SELECT arrayStringNulls FROM druid.arrays WHERE ARRAY_OVERLAP(arrayStringNulls, ARRAY['a','b']) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(expressionFilter("array_overlap(\"arrayStringNulls\",array('a','b'))")) + .columns("arrayStringNulls") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"b\",\"b\"]"}, + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"d\",null,\"b\"]"}, + new Object[]{"[null,\"b\"]"} + ) + ); + } + + @Test + public void testArrayOverlapFilterLongArrayColumn() + { + testQuery( + "SELECT arrayLongNulls FROM druid.arrays WHERE ARRAY_OVERLAP(arrayLongNulls, ARRAY[1, 2]) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(expressionFilter("array_overlap(\"arrayLongNulls\",array(1,2))")) + .columns("arrayLongNulls") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[2,3]"}, + new Object[]{"[1]"}, + new Object[]{"[null,2,9]"}, + new Object[]{"[1,null,3]"}, + new Object[]{"[1,2,3]"} + ) + ); + } + + @Test + public void testArrayOverlapFilterDoubleArrayColumn() + { + testQuery( + "SELECT arrayDoubleNulls FROM druid.arrays WHERE ARRAY_OVERLAP(arrayDoubleNulls, ARRAY[1.1, 2.2]) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(expressionFilter("array_overlap(\"arrayDoubleNulls\",array(1.1,2.2))")) + .columns("arrayDoubleNulls") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[1.1,2.2,null]"}, + new Object[]{"[null,2.2,null]"}, + new Object[]{"[null,1.1]"}, + new Object[]{"[1.1,2.2,null]"}, + new Object[]{"[null,2.2,null]"} + ) + ); + } + @Test public void testArrayOverlapFilterWithExtractionFn() { @@ -570,6 +978,83 @@ public void testArrayOverlapFilterNonLiteral() ); } + @Test + public void testArrayOverlapFilterArrayStringColumns() + { + testQuery( + "SELECT arrayStringNulls, arrayString FROM druid.arrays WHERE ARRAY_OVERLAP(arrayStringNulls, arrayString) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(expressionFilter("array_overlap(\"arrayStringNulls\",\"arrayString\")")) + .columns("arrayString", "arrayStringNulls") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "[\"a\",\"b\"]"}, + new Object[]{"[\"d\",null,\"b\"]", "[\"b\",\"c\"]"}, + new Object[]{"[null,\"b\"]", "[\"a\",\"b\",\"c\"]"}, + new Object[]{"[\"d\",null,\"b\"]", "[\"b\",\"c\"]"}, + new Object[]{"[null,\"b\"]", "[\"a\",\"b\",\"c\"]"} + ) + ); + } + + @Test + public void testArrayOverlapFilterArrayLongColumns() + { + testQuery( + "SELECT arrayLongNulls, arrayLong FROM druid.arrays WHERE ARRAY_OVERLAP(arrayLongNulls, arrayLong) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(expressionFilter("array_overlap(\"arrayLongNulls\",\"arrayLong\")")) + .columns("arrayLong", "arrayLongNulls") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[1]", "[1,4]"}, + new Object[]{"[1,null,3]", "[1,2,3]"}, + new Object[]{"[1,2,3]", "[1,2,3,4]"}, + new Object[]{"[1,null,3]", "[1,2,3]"}, + new Object[]{"[1,2,3]", "[1,2,3,4]"} + ) + ); + } + + @Test + public void testArrayOverlapFilterArrayDoubleColumns() + { + testQuery( + "SELECT arrayDoubleNulls, arrayDouble FROM druid.arrays WHERE ARRAY_OVERLAP(arrayDoubleNulls, arrayDouble) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(expressionFilter("array_overlap(\"arrayDoubleNulls\",\"arrayDouble\")")) + .columns("arrayDouble", "arrayDoubleNulls") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[1.1,2.2,null]", "[1.1,2.2,3.3]"}, + new Object[]{"[999.0,null,5.5]", "[3.3,4.4,5.5]"}, + new Object[]{"[1.1,2.2,null]", "[1.1,2.2,3.3]"}, + new Object[]{"[999.0,null,5.5]", "[3.3,4.4,5.5]"} + ) + ); + } + @Test public void testArrayContainsFilter() { @@ -597,6 +1082,83 @@ public void testArrayContainsFilter() ); } + @Test + public void testArrayContainsFilterArrayStringColumn() + { + testQuery( + "SELECT arrayStringNulls FROM druid.arrays WHERE ARRAY_CONTAINS(arrayStringNulls, ARRAY['a','b']) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters( + expressionFilter("array_contains(\"arrayStringNulls\",array('a','b'))") + ) + .columns("arrayStringNulls") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"a\",\"b\"]"} + ) + ); + } + + @Test + public void testArrayContainsFilterArrayLongColumn() + { + testQuery( + "SELECT arrayLongNulls FROM druid.arrays WHERE ARRAY_CONTAINS(arrayLongNulls, ARRAY[1, null]) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters( + expressionFilter("array_contains(\"arrayLongNulls\",array(1,null))") + ) + .columns("arrayLongNulls") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[1,null,3]"}, + new Object[]{"[1,null,3]"} + ) + ); + } + + @Test + public void testArrayContainsFilterArrayDoubleColumn() + { + testQuery( + "SELECT arrayDoubleNulls FROM druid.arrays WHERE ARRAY_CONTAINS(arrayDoubleNulls, ARRAY[1.1, null]) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters( + expressionFilter("array_contains(\"arrayDoubleNulls\",array(1.1,null))") + ) + .columns("arrayDoubleNulls") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[1.1,2.2,null]"}, + new Object[]{"[null,1.1]"}, + new Object[]{"[1.1,2.2,null]"} + ) + ); + } + @Test public void testArrayContainsFilterWithExtractionFn() { @@ -627,47 +1189,120 @@ public void testArrayContainsFilterWithExtractionFn() } @Test - public void testArrayContainsArrayOfOneElement() + public void testArrayContainsArrayOfOneElement() + { + testQuery( + "SELECT dim3 FROM druid.numfoo WHERE ARRAY_CONTAINS(dim3, ARRAY['a']) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(equality("dim3", "a", ColumnType.STRING)) + .columns("dim3") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"} + ) + ); + } + + @Test + public void testArrayContainsArrayOfNonLiteral() + { + testQuery( + "SELECT dim3 FROM druid.numfoo WHERE ARRAY_CONTAINS(dim3, ARRAY[dim2]) LIMIT 5", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(expressionFilter("array_contains(\"dim3\",array(\"dim2\"))")) + .columns("dim3") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .build() + ), + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"} + ) + ); + } + + @Test + public void testArrayContainsFilterArrayStringColumns() + { + testQuery( + "SELECT arrayStringNulls, arrayString FROM druid.arrays WHERE ARRAY_CONTAINS(arrayStringNulls, arrayString) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters( + expressionFilter("array_contains(\"arrayStringNulls\",\"arrayString\")") + ) + .columns("arrayString", "arrayStringNulls") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]", "[\"a\",\"b\"]"} + ) + ); + } + + @Test + public void testArrayContainsFilterArrayLongColumns() { testQuery( - "SELECT dim3 FROM druid.numfoo WHERE ARRAY_CONTAINS(dim3, ARRAY['a']) LIMIT 5", + "SELECT arrayLong, arrayLongNulls FROM druid.arrays WHERE ARRAY_CONTAINS(arrayLong, arrayLongNulls) LIMIT 5", ImmutableList.of( newScanQueryBuilder() - .dataSource(CalciteTests.DATASOURCE3) + .dataSource(DATA_SOURCE_ARRAYS) .intervals(querySegmentSpec(Filtration.eternity())) - .filters(equality("dim3", "a", ColumnType.STRING)) - .columns("dim3") + .filters( + expressionFilter("array_contains(\"arrayLong\",\"arrayLongNulls\")") + ) + .columns("arrayLong", "arrayLongNulls") .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .limit(5) .context(QUERY_CONTEXT_DEFAULT) .build() ), ImmutableList.of( - new Object[]{"[\"a\",\"b\"]"} + new Object[]{"[1,2,3]", "[]"}, + new Object[]{"[1,4]", "[1]"}, + new Object[]{"[1,2,3,4]", "[1,2,3]"}, + new Object[]{"[1,2,3,4]", "[1,2,3]"} ) ); } @Test - public void testArrayContainsArrayOfNonLiteral() + public void testArrayContainsFilterArrayDoubleColumns() { testQuery( - "SELECT dim3 FROM druid.numfoo WHERE ARRAY_CONTAINS(dim3, ARRAY[dim2]) LIMIT 5", - QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + "SELECT arrayDoubleNulls, arrayDouble FROM druid.arrays WHERE ARRAY_CONTAINS(arrayDoubleNulls, arrayDouble) LIMIT 5", ImmutableList.of( newScanQueryBuilder() - .dataSource(CalciteTests.DATASOURCE3) + .dataSource(DATA_SOURCE_ARRAYS) .intervals(querySegmentSpec(Filtration.eternity())) - .filters(expressionFilter("array_contains(\"dim3\",array(\"dim2\"))")) - .columns("dim3") + .filters( + expressionFilter("array_contains(\"arrayDoubleNulls\",\"arrayDouble\")") + ) + .columns("arrayDouble", "arrayDoubleNulls") .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .limit(5) - .context(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .context(QUERY_CONTEXT_DEFAULT) .build() ), - ImmutableList.of( - new Object[]{"[\"a\",\"b\"]"} - ) + ImmutableList.of() ); } @@ -699,6 +1334,46 @@ public void testArraySlice() ); } + @Test + public void testArraySliceArrayColumns() + { + testQuery( + "SELECT ARRAY_SLICE(arrayString, 1), ARRAY_SLICE(arrayLong, 2), ARRAY_SLICE(arrayDoubleNulls, 1) FROM druid.arrays", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + new Druids.ScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + expressionVirtualColumn("v0", "array_slice(\"arrayString\",1)", ColumnType.STRING_ARRAY), + expressionVirtualColumn("v1", "array_slice(\"arrayLong\",2)", ColumnType.LONG_ARRAY), + expressionVirtualColumn("v2", "array_slice(\"arrayDoubleNulls\",1)", ColumnType.DOUBLE_ARRAY) + ) + .columns("v0", "v1", "v2") + .context(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .build() + ), + ImmutableList.of( + new Object[]{null, Collections.singletonList(3L), null}, + new Object[]{null, null, Collections.emptyList()}, + new Object[]{ImmutableList.of("e"), Collections.emptyList(), null}, + new Object[]{ImmutableList.of("b"), null, Arrays.asList(5.5D, null)}, + new Object[]{ImmutableList.of("b"), Collections.singletonList(3L), Arrays.asList(2.2D, null)}, + new Object[]{ImmutableList.of("c"), Arrays.asList(3L, 4L), Arrays.asList(2.2D, null)}, + new Object[]{ImmutableList.of("b", "c"), Collections.emptyList(), Arrays.asList(null, 5.5D)}, + new Object[]{null, Collections.singletonList(3L), null}, + new Object[]{null, null, Collections.singletonList(1.1D)}, + new Object[]{ImmutableList.of("e"), Collections.emptyList(), null}, + new Object[]{ImmutableList.of("b"), null, Arrays.asList(5.5D, null)}, + new Object[]{ImmutableList.of("b"), Collections.singletonList(3L), Arrays.asList(2.2D, null)}, + new Object[]{ImmutableList.of("c"), Arrays.asList(3L, 4L), Arrays.asList(2.2D, null)}, + new Object[]{ImmutableList.of("b", "c"), Collections.emptyList(), Arrays.asList(null, 5.5D)} + ) + ); + } + @Test public void testArrayLength() { @@ -742,6 +1417,64 @@ public void testArrayLength() ); } + @Test + public void testArrayLengthArrayColumn() + { + // Cannot vectorize due to usage of expressions. + cannotVectorize(); + + testQuery( + "SELECT arrayStringNulls, ARRAY_LENGTH(arrayStringNulls), SUM(cnt) FROM druid.arrays GROUP BY 1, 2 ORDER BY 2 DESC", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(DATA_SOURCE_ARRAYS) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns(expressionVirtualColumn("v0", "array_length(\"arrayStringNulls\")", ColumnType.LONG)) + .setDimensions( + dimensions( + new DefaultDimensionSpec("arrayStringNulls", "d0", ColumnType.STRING_ARRAY), + new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec( + new DefaultLimitSpec( + ImmutableList.of( + new OrderByColumnSpec( + "d1", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + ) + ), + Integer.MAX_VALUE + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{"[\"d\",null,\"b\"]", 3, 2L}, + new Object[]{"[null,\"b\"]", 2, 2L}, + new Object[]{"[\"a\",\"b\"]", 2, 3L}, + new Object[]{"[\"b\",\"b\"]", 2, 2L}, + new Object[]{"[null]", 1, 1L}, + new Object[]{"[]", 0, 1L}, + new Object[]{null, null, 3L} + ) + : ImmutableList.of( + new Object[]{"[\"d\",null,\"b\"]", 3, 2L}, + new Object[]{"[null,\"b\"]", 2, 2L}, + new Object[]{"[\"a\",\"b\"]", 2, 3L}, + new Object[]{"[\"b\",\"b\"]", 2, 2L}, + new Object[]{"[null]", 1, 1L}, + new Object[]{null, 0, 3L}, + new Object[]{"[]", 0, 1L} + ) + ); + } + @Test public void testArrayAppend() { @@ -1064,6 +1797,53 @@ public void testArrayGroupAsLongArray() ); } + @Test + public void testArrayGroupAsLongArrayColumn() + { + // Cannot vectorize as we donot have support in native query subsytem for grouping on arrays + cannotVectorize(); + testQuery( + "SELECT arrayLongNulls, SUM(cnt) FROM druid.arrays GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(DATA_SOURCE_ARRAYS) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions( + dimensions( + new DefaultDimensionSpec("arrayLongNulls", "d0", ColumnType.LONG_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec( + new DefaultLimitSpec( + ImmutableList.of( + new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + ) + ), + Integer.MAX_VALUE + ) + ) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .build() + ), + ImmutableList.of( + new Object[]{null, 3L}, + new Object[]{Arrays.asList(null, 2L, 9L), 2L}, + new Object[]{Arrays.asList(1L, null, 3L), 2L}, + new Object[]{Arrays.asList(1L, 2L, 3L), 2L}, + new Object[]{Arrays.asList(2L, 3L), 2L}, + new Object[]{Collections.emptyList(), 1L}, + new Object[]{Collections.singletonList(null), 1L}, + new Object[]{Collections.singletonList(1L), 1L} + ) + ); + } + @Test public void testArrayGroupAsDoubleArray() @@ -1114,6 +1894,53 @@ public void testArrayGroupAsDoubleArray() ); } + @Test + public void testArrayGroupAsDoubleArrayColumn() + { + // Cannot vectorize as we donot have support in native query subsytem for grouping on arrays + cannotVectorize(); + testQuery( + "SELECT arrayDoubleNulls, SUM(cnt) FROM druid.arrays GROUP BY 1 ORDER BY 2 DESC", + QUERY_CONTEXT_NO_STRINGIFY_ARRAY, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(DATA_SOURCE_ARRAYS) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions( + dimensions( + new DefaultDimensionSpec("arrayDoubleNulls", "d0", ColumnType.DOUBLE_ARRAY) + ) + ) + .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setLimitSpec( + new DefaultLimitSpec( + ImmutableList.of( + new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + ) + ), + Integer.MAX_VALUE + ) + ) + .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY) + .build() + ), + ImmutableList.of( + new Object[]{null, 3L}, + new Object[]{Arrays.asList(null, 2.2D, null), 2L}, + new Object[]{Arrays.asList(1.1D, 2.2D, null), 2L}, + new Object[]{Arrays.asList(999.0D, null, 5.5D), 2L}, + new Object[]{Arrays.asList(999.0D, 5.5D, null), 2L}, + new Object[]{Collections.emptyList(), 1L}, + new Object[]{Collections.singletonList(null), 1L}, + new Object[]{Arrays.asList(null, 1.1D), 1L} + ) + ); + } + @Test public void testArrayGroupAsFloatArray() { @@ -1943,6 +2770,177 @@ public void testArrayConcatAggArrays() ); } + + @Test + public void testArrayAggArrayColumns() + { + msqIncompatible(); + // nested array party + cannotVectorize(); + if (NullHandling.replaceWithDefault()) { + // default value mode plans to selector filters for equality, which do not support array filtering + return; + } + testQuery( + "SELECT ARRAY_AGG(arrayLongNulls), ARRAY_AGG(DISTINCT arrayDouble), ARRAY_AGG(DISTINCT arrayStringNulls) FILTER(WHERE arrayLong = ARRAY[2,3]) FROM arrays WHERE arrayDoubleNulls is not null", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(notNull("arrayDoubleNulls")) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("arrayLongNulls"), + "__acc", + "ARRAY>[]", + "ARRAY>[]", + true, + true, + false, + "array_append(\"__acc\", \"arrayLongNulls\")", + "array_concat(\"__acc\", \"a0\")", + null, + null, + ExpressionLambdaAggregatorFactory.DEFAULT_MAX_SIZE_BYTES, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("arrayDouble"), + "__acc", + "ARRAY>[]", + "ARRAY>[]", + true, + true, + false, + "array_set_add(\"__acc\", \"arrayDouble\")", + "array_set_add_all(\"__acc\", \"a1\")", + null, + null, + ExpressionLambdaAggregatorFactory.DEFAULT_MAX_SIZE_BYTES, + TestExprMacroTable.INSTANCE + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a2", + ImmutableSet.of("arrayStringNulls"), + "__acc", + "ARRAY>[]", + "ARRAY>[]", + true, + true, + false, + "array_set_add(\"__acc\", \"arrayStringNulls\")", + "array_set_add_all(\"__acc\", \"a2\")", + null, + null, + ExpressionLambdaAggregatorFactory.DEFAULT_MAX_SIZE_BYTES, + TestExprMacroTable.INSTANCE + ), + equality("arrayLong", ImmutableList.of(2, 3), ColumnType.LONG_ARRAY) + ) + ) + ) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{ + "[[2,3],[null,2,9],[1,null,3],[1,2,3],null,null,[2,3],[null,2,9],[1,null,3],[1,2,3],null]", + "[null,[1.1,2.2,3.3],[1.1,3.3],[3.3,4.4,5.5]]", + "[[null,\"b\"]]" + } + ) + ); + } + + @Test + public void testArrayConcatAggArrayColumns() + { + cannotVectorize(); + if (NullHandling.replaceWithDefault()) { + // default value mode plans to selector filters for equality, which do not support array filtering + return; + } + testQuery( + "SELECT ARRAY_CONCAT_AGG(arrayLongNulls), ARRAY_CONCAT_AGG(DISTINCT arrayDouble), ARRAY_CONCAT_AGG(DISTINCT arrayStringNulls) FILTER(WHERE arrayLong = ARRAY[2,3]) FROM arrays WHERE arrayDoubleNulls is not null", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(notNull("arrayDoubleNulls")) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("arrayLongNulls"), + "__acc", + "ARRAY[]", + "ARRAY[]", + true, + false, + false, + "array_concat(\"__acc\", \"arrayLongNulls\")", + "array_concat(\"__acc\", \"a0\")", + null, + null, + ExpressionLambdaAggregatorFactory.DEFAULT_MAX_SIZE_BYTES, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("arrayDouble"), + "__acc", + "ARRAY[]", + "ARRAY[]", + true, + false, + false, + "array_set_add_all(\"__acc\", \"arrayDouble\")", + "array_set_add_all(\"__acc\", \"a1\")", + null, + null, + ExpressionLambdaAggregatorFactory.DEFAULT_MAX_SIZE_BYTES, + TestExprMacroTable.INSTANCE + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a2", + ImmutableSet.of("arrayStringNulls"), + "__acc", + "ARRAY[]", + "ARRAY[]", + true, + false, + false, + "array_set_add_all(\"__acc\", \"arrayStringNulls\")", + "array_set_add_all(\"__acc\", \"a2\")", + null, + null, + ExpressionLambdaAggregatorFactory.DEFAULT_MAX_SIZE_BYTES, + TestExprMacroTable.INSTANCE + ), + equality("arrayLong", ImmutableList.of(2, 3), ColumnType.LONG_ARRAY) + ) + ) + ) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{ + "[2,3,null,2,9,1,null,3,1,2,3,2,3,null,2,9,1,null,3,1,2,3]", + "[1.1,2.2,3.3,4.4,5.5]", + "[null,\"b\"]" + } + ) + ); + } + @Test public void testArrayAggToString() { @@ -2601,30 +3599,6 @@ public void testArrayAggGroupByArrayContainsSubquery() } - public static void assertResultsDeepEquals(String sql, List expected, List results) - { - for (int row = 0; row < results.size(); row++) { - for (int col = 0; col < results.get(row).length; col++) { - final String rowString = StringUtils.format("result #%d: %s", row + 1, sql); - assertDeepEquals(rowString + " - column: " + col + ":", expected.get(row)[col], results.get(row)[col]); - } - } - } - - public static void assertDeepEquals(String path, Object expected, Object actual) - { - if (expected instanceof List && actual instanceof List) { - List expectedList = (List) expected; - List actualList = (List) actual; - Assert.assertEquals(path + " arrays length mismatch", expectedList.size(), actualList.size()); - for (int i = 0; i < expectedList.size(); i++) { - assertDeepEquals(path + "[" + i + "]", expectedList.get(i), actualList.get(i)); - } - } else { - Assert.assertEquals(path, expected, actual); - } - } - @Test public void testUnnestInline() { @@ -2743,6 +3717,312 @@ public void testUnnest() ); } + @Test + public void testUnnestArrayColumnsString() + { + cannotVectorize(); + testQuery( + "SELECT a FROM druid.arrays, UNNEST(arrayString) as unnested (a)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn("j0.unnest", "\"arrayString\"", ColumnType.STRING_ARRAY), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("j0.unnest")) + .build() + ), + ImmutableList.of( + new Object[]{"d"}, + new Object[]{"e"}, + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"b"}, + new Object[]{"c"}, + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"c"}, + new Object[]{"d"}, + new Object[]{"e"}, + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"b"}, + new Object[]{"c"}, + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"c"} + ) + ); + } + + @Test + public void testUnnestArrayColumnsStringNulls() + { + cannotVectorize(); + testQuery( + "SELECT a FROM druid.arrays, UNNEST(arrayStringNulls) as unnested (a)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn("j0.unnest", "\"arrayStringNulls\"", ColumnType.STRING_ARRAY), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("j0.unnest")) + .build() + ), + ImmutableList.of( + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"b"}, + new Object[]{"b"}, + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"d"}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{"b"}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{"b"}, + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"b"}, + new Object[]{"b"}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{"d"}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{"b"}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{"b"} + ) + ); + } + + @Test + public void testUnnestArrayColumnsLong() + { + cannotVectorize(); + testQuery( + "SELECT a FROM druid.arrays, UNNEST(arrayLong) as unnested (a)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn("j0.unnest", "\"arrayLong\"", ColumnType.LONG_ARRAY), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("j0.unnest")) + .build() + ), + ImmutableList.of( + new Object[]{1L}, + new Object[]{2L}, + new Object[]{3L}, + new Object[]{1L}, + new Object[]{4L}, + new Object[]{1L}, + new Object[]{2L}, + new Object[]{3L}, + new Object[]{1L}, + new Object[]{2L}, + new Object[]{3L}, + new Object[]{4L}, + new Object[]{2L}, + new Object[]{3L}, + new Object[]{1L}, + new Object[]{2L}, + new Object[]{3L}, + new Object[]{1L}, + new Object[]{4L}, + new Object[]{1L}, + new Object[]{2L}, + new Object[]{3L}, + new Object[]{1L}, + new Object[]{2L}, + new Object[]{3L}, + new Object[]{4L}, + new Object[]{2L}, + new Object[]{3L} + ) + ); + } + + @Test + public void testUnnestArrayColumnsLongNulls() + { + cannotVectorize(); + testQuery( + "SELECT a FROM druid.arrays, UNNEST(arrayLongNulls) as unnested (a)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn("j0.unnest", "\"arrayLongNulls\"", ColumnType.LONG_ARRAY), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("j0.unnest")) + .build() + ), + ImmutableList.of( + new Object[]{2L}, + new Object[]{3L}, + new Object[]{1L}, + new Object[]{null}, + new Object[]{2L}, + new Object[]{9L}, + new Object[]{1L}, + new Object[]{null}, + new Object[]{3L}, + new Object[]{1L}, + new Object[]{2L}, + new Object[]{3L}, + new Object[]{2L}, + new Object[]{3L}, + new Object[]{null}, + new Object[]{null}, + new Object[]{2L}, + new Object[]{9L}, + new Object[]{1L}, + new Object[]{null}, + new Object[]{3L}, + new Object[]{1L}, + new Object[]{2L}, + new Object[]{3L} + ) + ); + } + + @Test + public void testUnnestArrayColumnsDouble() + { + cannotVectorize(); + testQuery( + "SELECT a FROM druid.arrays, UNNEST(arrayDouble) as unnested (a)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn("j0.unnest", "\"arrayDouble\"", ColumnType.DOUBLE_ARRAY), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("j0.unnest")) + .build() + ), + ImmutableList.of( + new Object[]{1.1D}, + new Object[]{2.2D}, + new Object[]{3.3D}, + new Object[]{2.2D}, + new Object[]{3.3D}, + new Object[]{4.0D}, + new Object[]{1.1D}, + new Object[]{2.2D}, + new Object[]{3.3D}, + new Object[]{1.1D}, + new Object[]{3.3D}, + new Object[]{3.3D}, + new Object[]{4.4D}, + new Object[]{5.5D}, + new Object[]{1.1D}, + new Object[]{2.2D}, + new Object[]{3.3D}, + new Object[]{2.2D}, + new Object[]{3.3D}, + new Object[]{4.0D}, + new Object[]{1.1D}, + new Object[]{2.2D}, + new Object[]{3.3D}, + new Object[]{1.1D}, + new Object[]{3.3D}, + new Object[]{3.3D}, + new Object[]{4.4D}, + new Object[]{5.5D} + ) + ); + } + + @Test + public void testUnnestArrayColumnsDoubleNulls() + { + cannotVectorize(); + testQuery( + "SELECT a FROM druid.arrays, UNNEST(arrayDoubleNulls) as unnested (a)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn("j0.unnest", "\"arrayDoubleNulls\"", ColumnType.DOUBLE_ARRAY), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("j0.unnest")) + .build() + ), + ImmutableList.of( + new Object[]{null}, + new Object[]{999.0D}, + new Object[]{5.5D}, + new Object[]{null}, + new Object[]{1.1D}, + new Object[]{2.2D}, + new Object[]{null}, + new Object[]{null}, + new Object[]{2.2D}, + new Object[]{null}, + new Object[]{999.0D}, + new Object[]{null}, + new Object[]{5.5D}, + new Object[]{null}, + new Object[]{1.1D}, + new Object[]{999.0D}, + new Object[]{5.5D}, + new Object[]{null}, + new Object[]{1.1D}, + new Object[]{2.2D}, + new Object[]{null}, + new Object[]{null}, + new Object[]{2.2D}, + new Object[]{null}, + new Object[]{999.0D}, + new Object[]{null}, + new Object[]{5.5D} + ) + ); + } + @Test public void testUnnestTwice() { @@ -2806,16 +4086,99 @@ public void testUnnestTwice() new Object[]{"abc", null, ImmutableList.of("abc"), "abc", NullHandling.defaultStringValue()} ) : ImmutableList.of( - new Object[]{"", ImmutableList.of("a", "b"), ImmutableList.of(""), "", "a"}, - new Object[]{"", ImmutableList.of("a", "b"), ImmutableList.of(""), "", "b"}, - new Object[]{"10.1", ImmutableList.of("b", "c"), ImmutableList.of("10", "1"), "10", "b"}, - new Object[]{"10.1", ImmutableList.of("b", "c"), ImmutableList.of("10", "1"), "10", "c"}, - new Object[]{"10.1", ImmutableList.of("b", "c"), ImmutableList.of("10", "1"), "1", "b"}, - new Object[]{"10.1", ImmutableList.of("b", "c"), ImmutableList.of("10", "1"), "1", "c"}, - new Object[]{"2", ImmutableList.of("d"), ImmutableList.of("2"), "2", "d"}, - new Object[]{"1", ImmutableList.of(""), ImmutableList.of("1"), "1", ""}, - new Object[]{"def", null, ImmutableList.of("def"), "def", null}, - new Object[]{"abc", null, ImmutableList.of("abc"), "abc", null} + new Object[]{"", ImmutableList.of("a", "b"), ImmutableList.of(""), "", "a"}, + new Object[]{"", ImmutableList.of("a", "b"), ImmutableList.of(""), "", "b"}, + new Object[]{"10.1", ImmutableList.of("b", "c"), ImmutableList.of("10", "1"), "10", "b"}, + new Object[]{"10.1", ImmutableList.of("b", "c"), ImmutableList.of("10", "1"), "10", "c"}, + new Object[]{"10.1", ImmutableList.of("b", "c"), ImmutableList.of("10", "1"), "1", "b"}, + new Object[]{"10.1", ImmutableList.of("b", "c"), ImmutableList.of("10", "1"), "1", "c"}, + new Object[]{"2", ImmutableList.of("d"), ImmutableList.of("2"), "2", "d"}, + new Object[]{"1", ImmutableList.of(""), ImmutableList.of("1"), "1", ""}, + new Object[]{"def", null, ImmutableList.of("def"), "def", null}, + new Object[]{"abc", null, ImmutableList.of("abc"), "abc", null} + ) + ); + } + + @Test + public void testUnnestTwiceArrayColumns() + { + cannotVectorize(); + testQuery( + "SELECT arrayStringNulls, arrayLongNulls, usn, uln" + + " FROM\n" + + " druid.arrays,\n" + + " UNNEST(arrayStringNulls) as t2 (usn),\n" + + " UNNEST(arrayLongNulls) as t3 (uln)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource( + UnnestDataSource.create( + UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn( + "j0.unnest", + "\"arrayStringNulls\"", + ColumnType.STRING_ARRAY + ), + null + ), + expressionVirtualColumn( + "_j0.unnest", + "\"arrayLongNulls\"", + ColumnType.LONG_ARRAY + ), + null + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("_j0.unnest", "arrayLongNulls", "arrayStringNulls", "j0.unnest")) + .build() + ), + ImmutableList.of( + new Object[]{Arrays.asList("a", "b"), Arrays.asList(2L, 3L), "a", 2L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(2L, 3L), "a", 3L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(2L, 3L), "b", 2L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(2L, 3L), "b", 3L}, + new Object[]{Arrays.asList("b", "b"), Collections.singletonList(1L), "b", 1L}, + new Object[]{Arrays.asList("b", "b"), Collections.singletonList(1L), "b", 1L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(1L, null, 3L), "a", 1L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(1L, null, 3L), "a", null}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(1L, null, 3L), "a", 3L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(1L, null, 3L), "b", 1L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(1L, null, 3L), "b", null}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(1L, null, 3L), "b", 3L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "d", 1L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "d", 2L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "d", 3L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), NullHandling.defaultStringValue(), 1L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), NullHandling.defaultStringValue(), 2L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), NullHandling.defaultStringValue(), 3L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "b", 1L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "b", 2L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "b", 3L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(2L, 3L), "a", 2L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(2L, 3L), "a", 3L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(2L, 3L), "b", 2L}, + new Object[]{Arrays.asList("a", "b"), Arrays.asList(2L, 3L), "b", 3L}, + new Object[]{Arrays.asList("b", "b"), Collections.singletonList(null), "b", null}, + new Object[]{Arrays.asList("b", "b"), Collections.singletonList(null), "b", null}, + new Object[]{Collections.singletonList(null), Arrays.asList(null, 2L, 9L), NullHandling.defaultStringValue(), null}, + new Object[]{Collections.singletonList(null), Arrays.asList(null, 2L, 9L), NullHandling.defaultStringValue(), 2L}, + new Object[]{Collections.singletonList(null), Arrays.asList(null, 2L, 9L), NullHandling.defaultStringValue(), 9L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "d", 1L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "d", 2L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "d", 3L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), NullHandling.defaultStringValue(), 1L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), NullHandling.defaultStringValue(), 2L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), NullHandling.defaultStringValue(), 3L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "b", 1L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "b", 2L}, + new Object[]{Arrays.asList("d", null, "b"), Arrays.asList(1L, 2L, 3L), "b", 3L} ) ); } @@ -2885,6 +4248,7 @@ public void testUnnestTwiceWithFiltersAndExpressions() ); } + @Test public void testUnnestThriceWithFiltersOnDimAndUnnestCol() { @@ -3052,6 +4416,74 @@ public void testUnnestThriceWithFiltersOnDimAndAllUnnestColumns() ); } + @Test + public void testUnnestThriceWithFiltersOnDimAndAllUnnestColumnsArrayColumns() + { + cannotVectorize(); + String sql = " SELECT arrayString, uln, udn, usn FROM \n" + + " ( SELECT * FROM \n" + + " ( SELECT * FROM arrays, UNNEST(arrayLongNulls) as ut(uln))" + + " ,UNNEST(arrayDoubleNulls) as ut(udn) \n" + + " ), UNNEST(arrayStringNulls) as ut(usn) " + + " WHERE arrayString = ARRAY['a','b'] AND uln = 1 AND udn = 2.2 AND usn = 'a'"; + List> expectedQuerySc = ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource( + UnnestDataSource.create( + UnnestDataSource.create( + FilteredDataSource.create( + UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn( + "j0.unnest", + "\"arrayLongNulls\"", + ColumnType.LONG_ARRAY + ), + null + ), + and( + NullHandling.sqlCompatible() + ? equality("arrayString", ImmutableList.of("a", "b"), ColumnType.STRING_ARRAY) + : expressionFilter("(\"arrayString\" == array('a','b'))"), + equality("j0.unnest", 1, ColumnType.LONG) + ) + ), + expressionVirtualColumn( + "_j0.unnest", + "\"arrayDoubleNulls\"", + ColumnType.DOUBLE_ARRAY + ), + equality("_j0.unnest", 2.2, ColumnType.DOUBLE) + ), + expressionVirtualColumn( + "__j0.unnest", + "\"arrayStringNulls\"", + ColumnType.STRING_ARRAY + ), + equality("__j0.unnest", "a", ColumnType.STRING) + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .virtualColumns( + expressionVirtualColumn("v0", "array('a','b')", ColumnType.STRING_ARRAY), + expressionVirtualColumn("v1", "1", ColumnType.LONG) + ) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("__j0.unnest", "_j0.unnest", "v0", "v1")) + .build() + ); + testQuery( + sql, + QUERY_CONTEXT_UNNEST, + expectedQuerySc, + ImmutableList.of( + new Object[]{ImmutableList.of("a", "b"), 1L, 2.2D, "a"} + ) + ); + } + @Test public void testUnnestThriceWithFiltersOnDimAndUnnestColumnsORCombinations() { @@ -3132,6 +4564,81 @@ public void testUnnestThriceWithFiltersOnDimAndUnnestColumnsORCombinations() ) ); } + + @Test + public void testUnnestThriceWithFiltersOnDimAndAllUnnestColumnsArrayColumnsOrFilters() + { + cannotVectorize(); + String sql = " SELECT arrayString, uln, udn, usn FROM \n" + + " ( SELECT * FROM \n" + + " ( SELECT * FROM arrays, UNNEST(arrayLongNulls) as ut(uln))" + + " ,UNNEST(arrayDoubleNulls) as ut(udn) \n" + + " ), UNNEST(arrayStringNulls) as ut(usn) " + + " WHERE arrayString = ARRAY['a','b'] AND (uln = 1 OR udn = 2.2) AND usn = 'a'"; + List> expectedQuerySc = ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource( + UnnestDataSource.create( + FilteredDataSource.create( + UnnestDataSource.create( + FilteredDataSource.create( + UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn( + "j0.unnest", + "\"arrayLongNulls\"", + ColumnType.LONG_ARRAY + ), + null + ), + NullHandling.sqlCompatible() + ? equality("arrayString", ImmutableList.of("a", "b"), ColumnType.STRING_ARRAY) + : expressionFilter("(\"arrayString\" == array('a','b'))") + ), + expressionVirtualColumn( + "_j0.unnest", + "\"arrayDoubleNulls\"", + ColumnType.DOUBLE_ARRAY + ), + null + ), + or( + equality("j0.unnest", 1, ColumnType.LONG), + equality("_j0.unnest", 2.2, ColumnType.DOUBLE) + ) + ), + expressionVirtualColumn( + "__j0.unnest", + "\"arrayStringNulls\"", + ColumnType.STRING_ARRAY + ), + equality("__j0.unnest", "a", ColumnType.STRING) + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .virtualColumns( + expressionVirtualColumn("v0", "array('a','b')", ColumnType.STRING_ARRAY) + ) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("__j0.unnest", "_j0.unnest", "j0.unnest", "v0")) + .build() + ); + testQuery( + sql, + QUERY_CONTEXT_UNNEST, + expectedQuerySc, + ImmutableList.of( + new Object[]{ImmutableList.of("a", "b"), 1L, 1.1D, "a"}, + new Object[]{ImmutableList.of("a", "b"), 1L, 2.2D, "a"}, + new Object[]{ImmutableList.of("a", "b"), 1L, null, "a"}, + new Object[]{ImmutableList.of("a", "b"), null, 2.2D, "a"}, + new Object[]{ImmutableList.of("a", "b"), 3L, 2.2D, "a"} + ) + ); + } + @Test public void testUnnestWithGroupBy() { @@ -3177,6 +4684,36 @@ public void testUnnestWithGroupBy() ); } + @Test + public void testUnnestWithGroupByArrayColumn() + { + cannotVectorize(); + testQuery( + "SELECT usn FROM druid.arrays, UNNEST(arrayStringNulls) as u (usn) GROUP BY usn ", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn("j0.unnest", "\"arrayStringNulls\"", ColumnType.STRING_ARRAY), + null + )) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setContext(QUERY_CONTEXT_UNNEST) + .setDimensions(new DefaultDimensionSpec("j0.unnest", "d0", ColumnType.STRING)) + .setGranularity(Granularities.ALL) + .setContext(QUERY_CONTEXT_UNNEST) + .build() + ), + ImmutableList.of( + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"d"} + ) + ); + } + @Test public void testUnnestWithGroupByOrderBy() { @@ -4752,6 +6289,32 @@ public void testUnnestWithSumOnUnnestedColumn() ); } + @Test + public void testUnnestWithSumOnUnnestedArrayColumn() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "select sum(c) col from druid.arrays, unnest(arrayDoubleNulls) as u(c)", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn("j0.unnest", "\"arrayDoubleNulls\"", ColumnType.DOUBLE_ARRAY), + null + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .context(QUERY_CONTEXT_UNNEST) + .aggregators(aggregators(new DoubleSumAggregatorFactory("a0", "j0.unnest"))) + .build() + ), + ImmutableList.of( + new Object[]{4030.0999999999995} + ) + ); + } + @Test public void testUnnestWithGroupByHavingWithWhereOnAggCol() { @@ -4813,6 +6376,79 @@ public void testUnnestWithGroupByHavingWithWhereOnUnnestCol() ); } + @Test + public void testUnnestWithGroupByWithWhereOnUnnestArrayCol() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "SELECT uln, COUNT(*) FROM druid.arrays, UNNEST(arrayLongNulls) AS unnested(uln) WHERE uln IN (1, 2, 3) GROUP BY uln", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn("j0.unnest", "\"arrayLongNulls\"", ColumnType.LONG_ARRAY), + NullHandling.sqlCompatible() + ? or( + equality("j0.unnest", 1L, ColumnType.LONG), + equality("j0.unnest", 2L, ColumnType.LONG), + equality("j0.unnest", 3L, ColumnType.LONG) + ) + : in("j0.unnest", ImmutableList.of("1", "2", "3"), null) + )) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setContext(QUERY_CONTEXT_UNNEST) + .setDimensions(new DefaultDimensionSpec("j0.unnest", "d0", ColumnType.LONG)) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setContext(QUERY_CONTEXT_UNNEST) + .build() + ), + ImmutableList.of( + new Object[]{1L, 5L}, + new Object[]{2L, 6L}, + new Object[]{3L, 6L} + ) + ); + } + + @Test + public void testUnnestWithGroupByHavingWithWhereOnUnnestArrayCol() + { + skipVectorize(); + cannotVectorize(); + testQuery( + "SELECT uln, COUNT(*) FROM druid.arrays, UNNEST(arrayLongNulls) AS unnested(uln) WHERE uln IN (1, 2, 3) GROUP BY uln HAVING uln=1", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(UnnestDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + expressionVirtualColumn("j0.unnest", "\"arrayLongNulls\"", ColumnType.LONG_ARRAY), + NullHandling.sqlCompatible() + ? or( + equality("j0.unnest", 1L, ColumnType.LONG), + equality("j0.unnest", 2L, ColumnType.LONG), + equality("j0.unnest", 3L, ColumnType.LONG) + ) + : in("j0.unnest", ImmutableList.of("1", "2", "3"), null) + )) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setContext(QUERY_CONTEXT_UNNEST) + .setDimensions(new DefaultDimensionSpec("j0.unnest", "d0", ColumnType.LONG)) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .setDimFilter(equality("j0.unnest", 1L, ColumnType.LONG)) + .setContext(QUERY_CONTEXT_UNNEST) + .build() + ), + ImmutableList.of( + new Object[]{1L, 5L} + ) + ); + } + @Test public void testUnnestVirtualWithColumnsAndNullIf() { @@ -4893,6 +6529,45 @@ public void testUnnestWithTimeFilterOnly() ); } + @Test + public void testUnnestWithTimeFilterOnlyArrayColumn() + { + testQuery( + "select c from arrays, unnest(arrayStringNulls) as u(c)" + + " where __time >= TIMESTAMP '2023-01-02 00:00:00' and __time <= TIMESTAMP '2023-01-03 00:10:00'", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource(UnnestDataSource.create( + FilteredDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + range("__time", ColumnType.LONG, 1672617600000L, 1672704600000L, false, false) + ), + expressionVirtualColumn("j0.unnest", "\"arrayStringNulls\"", ColumnType.STRING_ARRAY), + null + )) + .intervals(querySegmentSpec(Intervals.of("2023-01-02T00:00:00.000Z/2023-01-03T00:10:00.001Z"))) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("j0.unnest")) + .build() + ), + ImmutableList.of( + new Object[]{"a"}, + new Object[]{"b"}, + new Object[]{"b"}, + new Object[]{"b"}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{"d"}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{"b"}, + new Object[]{NullHandling.defaultStringValue()}, + new Object[]{"b"} + ) + ); + } + @Test public void testUnnestWithTimeFilterAndAnotherFilter() { @@ -5108,6 +6783,56 @@ public void testUnnestWithTimeFilterInsideSubquery() ); } + @Test + public void testUnnestWithTimeFilterInsideSubqueryArrayColumns() + { + testQuery( + "select uln from (select * from arrays, UNNEST(arrayLongNulls) as u(uln)" + + " where __time >= TIMESTAMP '2023-01-02 00:00:00' and __time <= TIMESTAMP '2023-01-03 00:10:00' LIMIT 2) \n" + + " where ARRAY_CONTAINS(arrayLongNulls, ARRAY[2])", + QUERY_CONTEXT_UNNEST, + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource( + new QueryDataSource( + newScanQueryBuilder() + .dataSource( + UnnestDataSource.create( + FilteredDataSource.create( + new TableDataSource(DATA_SOURCE_ARRAYS), + range("__time", ColumnType.LONG, 1672617600000L, 1672704600000L, false, false) + ), + expressionVirtualColumn("j0.unnest", "\"arrayLongNulls\"", ColumnType.LONG_ARRAY), + null + ) + ) + .intervals(querySegmentSpec(Intervals.of( + "2023-01-02T00:00:00.000Z/2023-01-03T00:10:00.001Z"))) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .columns("arrayLongNulls", "j0.unnest") + .limit(2) + .context(QUERY_CONTEXT_UNNEST) + .build() + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters( + expressionFilter("array_contains(\"arrayLongNulls\",array(2))") + ) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .context(QUERY_CONTEXT_UNNEST) + .columns(ImmutableList.of("j0.unnest")) + .build() + ), + ImmutableList.of( + new Object[]{2L}, + new Object[]{3L} + ) + ); + } + @Test public void testUnnestWithFilterAndUnnestNestedBackToBack() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/TestDataBuilder.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/TestDataBuilder.java index c6f05697026f..13288f0caa0b 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/util/TestDataBuilder.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/TestDataBuilder.java @@ -172,7 +172,7 @@ public Optional build( ); - private static final IncrementalIndexSchema INDEX_SCHEMA = new IncrementalIndexSchema.Builder() + public static final IncrementalIndexSchema INDEX_SCHEMA = new IncrementalIndexSchema.Builder() .withMetrics( new CountAggregatorFactory("cnt"), new FloatSumAggregatorFactory("m1", "m1"), From 36edbce03667bdcb58911455b74b8e64a169c14c Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 9 Oct 2023 20:05:48 +0530 Subject: [PATCH 12/14] Fix compilation failure in master (#15111) Merging since it's a dev blocker. --- .../org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java index 8ee9e78c8388..6ec17687c45e 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java @@ -95,7 +95,8 @@ public SqlEngine createEngine( queryJsonMapper, injector, new MSQTestTaskActionClient(queryJsonMapper), - workerMemoryParameters + workerMemoryParameters, + ImmutableList.of() ); return new MSQTaskSqlEngine(indexingServiceClient, queryJsonMapper); } From b0edbc3d912628e936ec2af06549e1b5b8f11898 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 9 Oct 2023 20:31:07 +0530 Subject: [PATCH 13/14] MSQ writes out string arrays instead of MVDs by default (#15093) MSQ uses the string dimension schema for ARRAY typed columns, which creates MVDs instead of string arrays as required. Therefore someone trying to ingest columns of type ARRAY from an external data source or another data source would get STRING columns in the newly generated segments. This patch changes the following: - Use auto dimension schema to ingest the ARRAY columns, which will create columns with the desired type. - Add an undocumented flag ingestStringArraysAsMVDs to preserve the legacy behavior. Legacy behaviour is turned on by default. - Create MSQArraysInsertTest and refactor some of the tests in MSQInsertTest. --- .../apache/druid/msq/exec/ControllerImpl.java | 18 +- .../external/ExternalInputSliceReader.java | 5 +- .../druid/msq/util/ArrayIngestMode.java | 45 ++ .../druid/msq/util/DimensionSchemaUtils.java | 93 ++- .../msq/util/MultiStageQueryContext.java | 16 +- .../apache/druid/msq/exec/MSQArraysTest.java | 727 ++++++++++++++++++ .../apache/druid/msq/exec/MSQInsertTest.java | 289 ------- .../msq/util/MultiStageQueryContextTest.java | 159 ++-- 8 files changed, 965 insertions(+), 387 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/ArrayIngestMode.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQArraysTest.java diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 58768644bf69..6f46007d93c0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -167,6 +167,7 @@ import org.apache.druid.msq.shuffle.input.DurableStorageInputChannelFactory; import org.apache.druid.msq.shuffle.input.WorkerInputChannelFactory; import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; +import org.apache.druid.msq.util.ArrayIngestMode; import org.apache.druid.msq.util.DimensionSchemaUtils; import org.apache.druid.msq.util.IntervalUtils; import org.apache.druid.msq.util.MSQFutureUtils; @@ -1999,6 +2000,17 @@ private static Pair, List> makeDimensio final Query query ) { + // Log a warning unconditionally if arrayIngestMode is MVD, since the behaviour is incorrect, and is subject to + // deprecation and removal in future + if (MultiStageQueryContext.getArrayIngestMode(query.context()) == ArrayIngestMode.MVD) { + log.warn( + "'%s' is set to 'mvd' in the query's context. This ingests the string arrays as multi-value " + + "strings instead of arrays, and is preserved for legacy reasons when MVDs were the only way to ingest string " + + "arrays in Druid. It is incorrect behaviour and will likely be removed in the future releases of Druid", + MultiStageQueryContext.CTX_ARRAY_INGEST_MODE + ); + } + final List dimensions = new ArrayList<>(); final List aggregators = new ArrayList<>(); @@ -2076,7 +2088,8 @@ private static Pair, List> makeDimensio DimensionSchemaUtils.createDimensionSchema( outputColumnName, type, - MultiStageQueryContext.useAutoColumnSchemas(query.context()) + MultiStageQueryContext.useAutoColumnSchemas(query.context()), + MultiStageQueryContext.getArrayIngestMode(query.context()) ) ); } else if (!isRollupQuery) { @@ -2125,7 +2138,8 @@ private static void populateDimensionsAndAggregators( DimensionSchemaUtils.createDimensionSchema( outputColumn, type, - MultiStageQueryContext.useAutoColumnSchemas(context) + MultiStageQueryContext.useAutoColumnSchemas(context), + MultiStageQueryContext.getArrayIngestMode(context) ) ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java index 084d58e217d6..714e8dc3a639 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java @@ -119,10 +119,9 @@ private static Iterator inputSourceSegmentIterator( new DimensionsSpec( signature.getColumnNames().stream().map( column -> - DimensionSchemaUtils.createDimensionSchema( + DimensionSchemaUtils.createDimensionSchemaForExtern( column, - signature.getColumnType(column).orElse(null), - false + signature.getColumnType(column).orElse(null) ) ).collect(Collectors.toList()) ), diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/ArrayIngestMode.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/ArrayIngestMode.java new file mode 100644 index 000000000000..ff6b4718ad85 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/ArrayIngestMode.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.util; + +/** + * Values that the query context flag 'arrayIngestMode' can take to specify the behaviour of ingestion of arrays via + * MSQ's INSERT queries + */ +public enum ArrayIngestMode +{ + /** + * Disables the ingestion of arrays via MSQ's INSERT queries. + */ + NONE, + + /** + * String arrays are ingested as MVDs. This is to preserve the legacy behaviour of Druid and will be removed in the + * future, since MVDs are not true array types and the behaviour is incorrect. + * This also disables the ingestion of numeric arrays + */ + MVD, + + /** + * Allows numeric and string arrays to be ingested as arrays. This should be the preferred method of ingestion, + * unless bound by compatibility reasons to use 'mvd' + */ + ARRAY +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/DimensionSchemaUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/DimensionSchemaUtils.java index 2efc94740ac7..98d94518bde8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/DimensionSchemaUtils.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/DimensionSchemaUtils.java @@ -24,7 +24,9 @@ import org.apache.druid.data.input.impl.FloatDimensionSchema; import org.apache.druid.data.input.impl.LongDimensionSchema; import org.apache.druid.data.input.impl.StringDimensionSchema; +import org.apache.druid.error.InvalidInput; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.segment.AutoTypeColumnSchema; import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.column.ColumnCapabilities; @@ -40,15 +42,31 @@ */ public class DimensionSchemaUtils { + + /** + * Creates a dimension schema for creating {@link org.apache.druid.data.input.InputSourceReader}. + */ + public static DimensionSchema createDimensionSchemaForExtern(final String column, @Nullable final ColumnType type) + { + return createDimensionSchema( + column, + type, + false, + // Least restrictive mode since we do not have any type restrictions while reading the extern files. + ArrayIngestMode.ARRAY + ); + } + public static DimensionSchema createDimensionSchema( final String column, @Nullable final ColumnType type, - boolean useAutoType + boolean useAutoType, + ArrayIngestMode arrayIngestMode ) { if (useAutoType) { // for complex types that are not COMPLEX, we still want to use the handler since 'auto' typing - // only works for the 'standard' built-in typesg + // only works for the 'standard' built-in types if (type != null && type.is(ValueType.COMPLEX) && !ColumnType.NESTED_DATA.equals(type)) { final ColumnCapabilities capabilities = ColumnCapabilitiesImpl.createDefault().setType(type); return DimensionHandlerUtils.getHandlerFromCapabilities(column, capabilities, null) @@ -57,35 +75,54 @@ public static DimensionSchema createDimensionSchema( return new AutoTypeColumnSchema(column); } else { - // if schema information not available, create a string dimension + // if schema information is not available, create a string dimension if (type == null) { return new StringDimensionSchema(column); - } - - switch (type.getType()) { - case STRING: - return new StringDimensionSchema(column); - case LONG: - return new LongDimensionSchema(column); - case FLOAT: - return new FloatDimensionSchema(column); - case DOUBLE: - return new DoubleDimensionSchema(column); - case ARRAY: - switch (type.getElementType().getType()) { - case STRING: - return new StringDimensionSchema(column, DimensionSchema.MultiValueHandling.ARRAY, null); - case LONG: - case FLOAT: - case DOUBLE: - return new AutoTypeColumnSchema(column); - default: - throw new ISE("Cannot create dimension for type [%s]", type.toString()); + } else if (type.getType() == ValueType.STRING) { + return new StringDimensionSchema(column); + } else if (type.getType() == ValueType.LONG) { + return new LongDimensionSchema(column); + } else if (type.getType() == ValueType.FLOAT) { + return new FloatDimensionSchema(column); + } else if (type.getType() == ValueType.DOUBLE) { + return new DoubleDimensionSchema(column); + } else if (type.getType() == ValueType.ARRAY) { + ValueType elementType = type.getElementType().getType(); + if (elementType == ValueType.STRING) { + if (arrayIngestMode == ArrayIngestMode.NONE) { + throw InvalidInput.exception( + "String arrays can not be ingested when '%s' is set to '%s'. Either set '%s' in query context " + + "to 'array' to ingest the string array as an array, or ingest it as an MVD by explicitly casting the " + + "array to an MVD with ARRAY_TO_MV function.", + MultiStageQueryContext.CTX_ARRAY_INGEST_MODE, + StringUtils.toLowerCase(arrayIngestMode.name()), + MultiStageQueryContext.CTX_ARRAY_INGEST_MODE + ); + } else if (arrayIngestMode == ArrayIngestMode.MVD) { + return new StringDimensionSchema(column, DimensionSchema.MultiValueHandling.ARRAY, null); + } else { + // arrayIngestMode == ArrayIngestMode.ARRAY would be true + return new AutoTypeColumnSchema(column); + } + } else if (elementType.isNumeric()) { + // ValueType == LONG || ValueType == FLOAT || ValueType == DOUBLE + if (arrayIngestMode == ArrayIngestMode.ARRAY) { + return new AutoTypeColumnSchema(column); + } else { + throw InvalidInput.exception( + "Numeric arrays can only be ingested when '%s' is set to 'array' in the MSQ query's context. " + + "Current value of the parameter [%s]", + MultiStageQueryContext.CTX_ARRAY_INGEST_MODE, + StringUtils.toLowerCase(arrayIngestMode.name()) + ); } - default: - final ColumnCapabilities capabilities = ColumnCapabilitiesImpl.createDefault().setType(type); - return DimensionHandlerUtils.getHandlerFromCapabilities(column, capabilities, null) - .getDimensionSchema(capabilities); + } else { + throw new ISE("Cannot create dimension for type [%s]", type.toString()); + } + } else { + final ColumnCapabilities capabilities = ColumnCapabilitiesImpl.createDefault().setType(type); + return DimensionHandlerUtils.getHandlerFromCapabilities(column, capabilities, null) + .getDimensionSchema(capabilities); } } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java index 6e477d0c364b..613fac6203c2 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java @@ -74,6 +74,10 @@ * {@link org.apache.druid.segment.AutoTypeColumnSchema} for all 'standard' type columns during segment generation, * see {@link DimensionSchemaUtils#createDimensionSchema} for more details. * + *
  • arrayIngestMode: Tri-state query context that controls the behaviour and support of arrays that are + * ingested via MSQ. If set to 'none', arrays are not allowed to be ingested in MSQ. If set to 'array', array types + * can be ingested as expected. If set to 'mvd', numeric arrays can not be ingested, and string arrays will be + * ingested as MVDs (this is kept for legacy purpose). * **/ public class MultiStageQueryContext @@ -127,6 +131,11 @@ public class MultiStageQueryContext public static final String CTX_INDEX_SPEC = "indexSpec"; public static final String CTX_USE_AUTO_SCHEMAS = "useAutoColumnSchemas"; + public static final boolean DEFAULT_USE_AUTO_SCHEMAS = false; + + public static final String CTX_ARRAY_INGEST_MODE = "arrayIngestMode"; + public static final ArrayIngestMode DEFAULT_ARRAY_INGEST_MODE = ArrayIngestMode.MVD; + private static final Pattern LOOKS_LIKE_JSON_ARRAY = Pattern.compile("^\\s*\\[.*", Pattern.DOTALL); @@ -266,7 +275,12 @@ public static IndexSpec getIndexSpec(final QueryContext queryContext, final Obje public static boolean useAutoColumnSchemas(final QueryContext queryContext) { - return queryContext.getBoolean(CTX_USE_AUTO_SCHEMAS, false); + return queryContext.getBoolean(CTX_USE_AUTO_SCHEMAS, DEFAULT_USE_AUTO_SCHEMAS); + } + + public static ArrayIngestMode getArrayIngestMode(final QueryContext queryContext) + { + return queryContext.getEnum(CTX_ARRAY_INGEST_MODE, ArrayIngestMode.class, DEFAULT_ARRAY_INGEST_MODE); } /** diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQArraysTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQArraysTest.java new file mode 100644 index 000000000000..d2696f232820 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQArraysTest.java @@ -0,0 +1,727 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.apache.druid.data.input.impl.JsonInputFormat; +import org.apache.druid.data.input.impl.LocalInputSource; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.msq.indexing.MSQSpec; +import org.apache.druid.msq.indexing.MSQTuningConfig; +import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; +import org.apache.druid.msq.test.MSQTestBase; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.NestedDataTestUtils; +import org.apache.druid.query.Query; +import org.apache.druid.query.expression.TestExprMacroTable; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; +import org.apache.druid.sql.calcite.external.ExternalDataSource; +import org.apache.druid.sql.calcite.filtration.Filtration; +import org.apache.druid.sql.calcite.planner.ColumnMapping; +import org.apache.druid.sql.calcite.planner.ColumnMappings; +import org.apache.druid.timeline.SegmentId; +import org.apache.druid.utils.CompressionUtils; +import org.hamcrest.CoreMatchers; +import org.junit.Test; +import org.junit.internal.matchers.ThrowableMessageMatcher; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Tests INSERT and SELECT behaviour of MSQ with arrays and MVDs + */ +@RunWith(Parameterized.class) +public class MSQArraysTest extends MSQTestBase +{ + + @Parameterized.Parameters(name = "{index}:with context {0}") + public static Collection data() + { + Object[][] data = new Object[][]{ + {DEFAULT, DEFAULT_MSQ_CONTEXT}, + {DURABLE_STORAGE, DURABLE_STORAGE_MSQ_CONTEXT}, + {FAULT_TOLERANCE, FAULT_TOLERANCE_MSQ_CONTEXT}, + {PARALLEL_MERGE, PARALLEL_MERGE_MSQ_CONTEXT} + }; + return Arrays.asList(data); + } + + @Parameterized.Parameter(0) + public String contextName; + + @Parameterized.Parameter(1) + public Map context; + + /** + * Tests the behaviour of INSERT query when arrayIngestMode is set to none (default) and the user tries to ingest + * string arrays + */ + @Test + public void testInsertStringArrayWithArrayIngestModeNone() + { + + final Map adjustedContext = new HashMap<>(context); + adjustedContext.put(MultiStageQueryContext.CTX_ARRAY_INGEST_MODE, "none"); + + testIngestQuery().setSql( + "INSERT INTO foo1 SELECT MV_TO_ARRAY(dim3) AS dim3 FROM foo GROUP BY 1 PARTITIONED BY ALL TIME") + .setQueryContext(adjustedContext) + .setExpectedExecutionErrorMatcher(CoreMatchers.allOf( + CoreMatchers.instanceOf(ISE.class), + ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString( + "String arrays can not be ingested when 'arrayIngestMode' is set to 'none'")) + )) + .verifyExecutionError(); + } + + + /** + * Tests the behaviour of INSERT query when arrayIngestMode is set to mvd (default) and the only array type to be + * ingested is string array + */ + @Test + public void testInsertOnFoo1WithMultiValueToArrayGroupByWithDefaultContext() + { + RowSignature rowSignature = RowSignature.builder() + .add("__time", ColumnType.LONG) + .add("dim3", ColumnType.STRING) + .build(); + + testIngestQuery().setSql( + "INSERT INTO foo1 SELECT MV_TO_ARRAY(dim3) AS dim3 FROM foo GROUP BY 1 PARTITIONED BY ALL TIME") + .setExpectedDataSource("foo1") + .setExpectedRowSignature(rowSignature) + .setQueryContext(context) + .setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", Intervals.ETERNITY, "test", 0))) + .setExpectedResultRows(expectedMultiValueFooRowsToArray()) + .verifyResults(); + } + + /** + * Tests the INSERT query when 'auto' type is set + */ + @Test + public void testInsertArraysAutoType() throws IOException + { + List expectedRows = Arrays.asList( + new Object[]{1672531200000L, null, null, null}, + new Object[]{1672531200000L, null, new Object[]{1L, 2L, 3L}, new Object[]{1.1, 2.2, 3.3}}, + new Object[]{1672531200000L, new Object[]{"d", "e"}, new Object[]{1L, 4L}, new Object[]{2.2, 3.3, 4.0}}, + new Object[]{1672531200000L, new Object[]{"a", "b"}, null, null}, + new Object[]{1672531200000L, new Object[]{"a", "b"}, new Object[]{1L, 2L, 3L}, new Object[]{1.1, 2.2, 3.3}}, + new Object[]{1672531200000L, new Object[]{"b", "c"}, new Object[]{1L, 2L, 3L, 4L}, new Object[]{1.1, 3.3}}, + new Object[]{1672531200000L, new Object[]{"a", "b", "c"}, new Object[]{2L, 3L}, new Object[]{3.3, 4.4, 5.5}}, + new Object[]{1672617600000L, null, null, null}, + new Object[]{1672617600000L, null, new Object[]{1L, 2L, 3L}, new Object[]{1.1, 2.2, 3.3}}, + new Object[]{1672617600000L, new Object[]{"d", "e"}, new Object[]{1L, 4L}, new Object[]{2.2, 3.3, 4.0}}, + new Object[]{1672617600000L, new Object[]{"a", "b"}, null, null}, + new Object[]{1672617600000L, new Object[]{"a", "b"}, new Object[]{1L, 2L, 3L}, new Object[]{1.1, 2.2, 3.3}}, + new Object[]{1672617600000L, new Object[]{"b", "c"}, new Object[]{1L, 2L, 3L, 4L}, new Object[]{1.1, 3.3}}, + new Object[]{1672617600000L, new Object[]{"a", "b", "c"}, new Object[]{2L, 3L}, new Object[]{3.3, 4.4, 5.5}} + ); + + RowSignature rowSignature = RowSignature.builder() + .add("__time", ColumnType.LONG) + .add("arrayString", ColumnType.STRING_ARRAY) + .add("arrayLong", ColumnType.LONG_ARRAY) + .add("arrayDouble", ColumnType.DOUBLE_ARRAY) + .build(); + + final Map adjustedContext = new HashMap<>(context); + adjustedContext.put(MultiStageQueryContext.CTX_USE_AUTO_SCHEMAS, true); + + final File tmpFile = temporaryFolder.newFile(); + final InputStream resourceStream = NestedDataTestUtils.class.getClassLoader() + .getResourceAsStream(NestedDataTestUtils.ARRAY_TYPES_DATA_FILE); + final InputStream decompressing = CompressionUtils.decompress( + resourceStream, + NestedDataTestUtils.ARRAY_TYPES_DATA_FILE + ); + Files.copy(decompressing, tmpFile.toPath(), StandardCopyOption.REPLACE_EXISTING); + decompressing.close(); + + final String toReadFileNameAsJson = queryFramework().queryJsonMapper().writeValueAsString(tmpFile); + + testIngestQuery().setSql(" INSERT INTO foo1 SELECT\n" + + " TIME_PARSE(\"timestamp\") as __time,\n" + + " arrayString,\n" + + " arrayLong,\n" + + " arrayDouble\n" + + "FROM TABLE(\n" + + " EXTERN(\n" + + " '{ \"files\": [" + toReadFileNameAsJson + "],\"type\":\"local\"}',\n" + + " '{\"type\": \"json\"}',\n" + + " '[{\"name\": \"timestamp\", \"type\": \"STRING\"}, {\"name\": \"arrayString\", \"type\": \"COMPLEX\"}, {\"name\": \"arrayLong\", \"type\": \"COMPLEX\"}, {\"name\": \"arrayDouble\", \"type\": \"COMPLEX\"}]'\n" + + " )\n" + + ") PARTITIONED BY ALL") + .setQueryContext(adjustedContext) + .setExpectedResultRows(expectedRows) + .setExpectedDataSource("foo1") + .setExpectedRowSignature(rowSignature) + .verifyResults(); + } + + /** + * Tests the behaviour of INSERT query when arrayIngestMode is set to mvd and the user tries to ingest numeric array + * types as well + */ + @Test + public void testInsertArraysWithStringArraysAsMVDs() throws IOException + { + RowSignature rowSignatureWithoutTimeAndStringColumns = + RowSignature.builder() + .add("arrayLong", ColumnType.LONG_ARRAY) + .add("arrayLongNulls", ColumnType.LONG_ARRAY) + .add("arrayDouble", ColumnType.DOUBLE_ARRAY) + .add("arrayDoubleNulls", ColumnType.DOUBLE_ARRAY) + .build(); + + + RowSignature fileSignature = RowSignature.builder() + .add("timestamp", ColumnType.STRING) + .add("arrayString", ColumnType.STRING_ARRAY) + .add("arrayStringNulls", ColumnType.STRING_ARRAY) + .addAll(rowSignatureWithoutTimeAndStringColumns) + .build(); + + final Map adjustedContext = new HashMap<>(context); + adjustedContext.put(MultiStageQueryContext.CTX_ARRAY_INGEST_MODE, "mvd"); + + final File tmpFile = temporaryFolder.newFile(); + final InputStream resourceStream = NestedDataTestUtils.class.getClassLoader() + .getResourceAsStream(NestedDataTestUtils.ARRAY_TYPES_DATA_FILE); + final InputStream decompressing = CompressionUtils.decompress( + resourceStream, + NestedDataTestUtils.ARRAY_TYPES_DATA_FILE + ); + Files.copy(decompressing, tmpFile.toPath(), StandardCopyOption.REPLACE_EXISTING); + decompressing.close(); + + final String toReadFileNameAsJson = queryFramework().queryJsonMapper().writeValueAsString(tmpFile); + + testIngestQuery().setSql(" INSERT INTO foo1 SELECT\n" + + " TIME_PARSE(\"timestamp\") as __time,\n" + + " arrayString,\n" + + " arrayStringNulls,\n" + + " arrayLong,\n" + + " arrayLongNulls,\n" + + " arrayDouble,\n" + + " arrayDoubleNulls\n" + + "FROM TABLE(\n" + + " EXTERN(\n" + + " '{ \"files\": [" + toReadFileNameAsJson + "],\"type\":\"local\"}',\n" + + " '{\"type\": \"json\"}',\n" + + " '" + queryFramework().queryJsonMapper().writeValueAsString(fileSignature) + "'\n" + + " )\n" + + ") PARTITIONED BY ALL") + .setQueryContext(adjustedContext) + .setExpectedExecutionErrorMatcher(CoreMatchers.allOf( + CoreMatchers.instanceOf(ISE.class), + ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString( + "Numeric arrays can only be ingested when")) + )) + .verifyExecutionError(); + } + + /** + * Tests the behaviour of INSERT query when arrayIngestMode is set to array and the user tries to ingest all + * array types + */ + @Test + public void testInsertArraysAsArrays() throws IOException + { + final List expectedRows = Arrays.asList( + new Object[]{ + 1672531200000L, + null, + null, + new Object[]{1L, 2L, 3L}, + new Object[]{}, + new Object[]{1.1d, 2.2d, 3.3d}, + null + }, + new Object[]{ + 1672531200000L, + null, + new Object[]{"a", "b"}, + null, + new Object[]{2L, 3L}, + null, + new Object[]{null} + }, + new Object[]{ + 1672531200000L, + new Object[]{"d", "e"}, + new Object[]{"b", "b"}, + new Object[]{1L, 4L}, + new Object[]{1L}, + new Object[]{2.2d, 3.3d, 4.0d}, + null + }, + new Object[]{ + 1672531200000L, + new Object[]{"a", "b"}, + null, + null, + new Object[]{null, 2L, 9L}, + null, + new Object[]{999.0d, 5.5d, null} + }, + new Object[]{ + 1672531200000L, + new Object[]{"a", "b"}, + new Object[]{"a", "b"}, + new Object[]{1L, 2L, 3L}, + new Object[]{1L, null, 3L}, + new Object[]{1.1d, 2.2d, 3.3d}, + new Object[]{1.1d, 2.2d, null} + }, + new Object[]{ + 1672531200000L, + new Object[]{"b", "c"}, + new Object[]{"d", null, "b"}, + new Object[]{1L, 2L, 3L, 4L}, + new Object[]{1L, 2L, 3L}, + new Object[]{1.1d, 3.3d}, + new Object[]{null, 2.2d, null} + }, + new Object[]{ + 1672531200000L, + new Object[]{"a", "b", "c"}, + new Object[]{null, "b"}, + new Object[]{2L, 3L}, + null, + new Object[]{3.3d, 4.4d, 5.5d}, + new Object[]{999.0d, null, 5.5d} + }, + new Object[]{ + 1672617600000L, + null, + null, + new Object[]{1L, 2L, 3L}, + null, + new Object[]{1.1d, 2.2d, 3.3d}, + new Object[]{} + }, + new Object[]{ + 1672617600000L, + null, + new Object[]{"a", "b"}, + null, + new Object[]{2L, 3L}, + null, + new Object[]{null, 1.1d} + }, + new Object[]{ + 1672617600000L, + new Object[]{"d", "e"}, + new Object[]{"b", "b"}, + new Object[]{1L, 4L}, + new Object[]{null}, + new Object[]{2.2d, 3.3d, 4.0}, + null + }, + new Object[]{ + 1672617600000L, + new Object[]{"a", "b"}, + new Object[]{null}, + null, + new Object[]{null, 2L, 9L}, + null, + new Object[]{999.0d, 5.5d, null} + }, + new Object[]{ + 1672617600000L, + new Object[]{"a", "b"}, + new Object[]{}, + new Object[]{1L, 2L, 3L}, + new Object[]{1L, null, 3L}, + new Object[]{1.1d, 2.2d, 3.3d}, + new Object[]{1.1d, 2.2d, null} + }, + new Object[]{ + 1672617600000L, + new Object[]{"b", "c"}, + new Object[]{"d", null, "b"}, + new Object[]{1L, 2L, 3L, 4L}, + new Object[]{1L, 2L, 3L}, + new Object[]{1.1d, 3.3d}, + new Object[]{null, 2.2d, null} + }, + new Object[]{ + 1672617600000L, + new Object[]{"a", "b", "c"}, + new Object[]{null, "b"}, + new Object[]{2L, 3L}, + null, + new Object[]{3.3d, 4.4d, 5.5d}, + new Object[]{999.0d, null, 5.5d} + } + ); + + RowSignature rowSignatureWithoutTimeColumn = + RowSignature.builder() + .add("arrayString", ColumnType.STRING_ARRAY) + .add("arrayStringNulls", ColumnType.STRING_ARRAY) + .add("arrayLong", ColumnType.LONG_ARRAY) + .add("arrayLongNulls", ColumnType.LONG_ARRAY) + .add("arrayDouble", ColumnType.DOUBLE_ARRAY) + .add("arrayDoubleNulls", ColumnType.DOUBLE_ARRAY) + .build(); + + RowSignature fileSignature = RowSignature.builder() + .add("timestamp", ColumnType.STRING) + .addAll(rowSignatureWithoutTimeColumn) + .build(); + + RowSignature rowSignature = RowSignature.builder() + .add("__time", ColumnType.LONG) + .addAll(rowSignatureWithoutTimeColumn) + .build(); + + final Map adjustedContext = new HashMap<>(context); + adjustedContext.put(MultiStageQueryContext.CTX_ARRAY_INGEST_MODE, "array"); + + final File tmpFile = temporaryFolder.newFile(); + final InputStream resourceStream = NestedDataTestUtils.class.getClassLoader() + .getResourceAsStream(NestedDataTestUtils.ARRAY_TYPES_DATA_FILE); + final InputStream decompressing = CompressionUtils.decompress( + resourceStream, + NestedDataTestUtils.ARRAY_TYPES_DATA_FILE + ); + Files.copy(decompressing, tmpFile.toPath(), StandardCopyOption.REPLACE_EXISTING); + decompressing.close(); + + final String toReadFileNameAsJson = queryFramework().queryJsonMapper().writeValueAsString(tmpFile); + + testIngestQuery().setSql(" INSERT INTO foo1 SELECT\n" + + " TIME_PARSE(\"timestamp\") as __time,\n" + + " arrayString,\n" + + " arrayStringNulls,\n" + + " arrayLong,\n" + + " arrayLongNulls,\n" + + " arrayDouble,\n" + + " arrayDoubleNulls\n" + + "FROM TABLE(\n" + + " EXTERN(\n" + + " '{ \"files\": [" + toReadFileNameAsJson + "],\"type\":\"local\"}',\n" + + " '{\"type\": \"json\"}',\n" + + " '" + queryFramework().queryJsonMapper().writeValueAsString(fileSignature) + "'\n" + + " )\n" + + ") PARTITIONED BY ALL") + .setQueryContext(adjustedContext) + .setExpectedResultRows(expectedRows) + .setExpectedDataSource("foo1") + .setExpectedRowSignature(rowSignature) + .verifyResults(); + } + + @Test + public void testSelectOnArraysWithArrayIngestModeAsNone() throws IOException + { + testSelectOnArrays("none"); + } + + @Test + public void testSelectOnArraysWithArrayIngestModeAsMVD() throws IOException + { + testSelectOnArrays("mvd"); + } + + @Test + public void testSelectOnArraysWithArrayIngestModeAsArray() throws IOException + { + testSelectOnArrays("array"); + } + + // Tests the behaviour of the select with the given arrayIngestMode. The expectation should be the same, since the + // arrayIngestMode should only determine how the array gets ingested at the end. + public void testSelectOnArrays(String arrayIngestMode) throws IOException + { + final List expectedRows = Arrays.asList( + new Object[]{ + 1672531200000L, + Arrays.asList("a", "b"), + Arrays.asList("a", "b"), + Arrays.asList(1L, 2L, 3L), + Arrays.asList(1L, null, 3L), + Arrays.asList(1.1d, 2.2d, 3.3d), + Arrays.asList(1.1d, 2.2d, null) + }, + new Object[]{ + 1672531200000L, + Arrays.asList("a", "b", "c"), + Arrays.asList(null, "b"), + Arrays.asList(2L, 3L), + null, + Arrays.asList(3.3d, 4.4d, 5.5d), + Arrays.asList(999.0d, null, 5.5d), + }, + new Object[]{ + 1672531200000L, + Arrays.asList("b", "c"), + Arrays.asList("d", null, "b"), + Arrays.asList(1L, 2L, 3L, 4L), + Arrays.asList(1L, 2L, 3L), + Arrays.asList(1.1d, 3.3d), + Arrays.asList(null, 2.2d, null) + }, + new Object[]{ + 1672531200000L, + Arrays.asList("d", "e"), + Arrays.asList("b", "b"), + Arrays.asList(1L, 4L), + Collections.singletonList(1L), + Arrays.asList(2.2d, 3.3d, 4.0d), + null + }, + new Object[]{ + 1672531200000L, + null, + null, + Arrays.asList(1L, 2L, 3L), + Collections.emptyList(), + Arrays.asList(1.1d, 2.2d, 3.3d), + null + }, + new Object[]{ + 1672531200000L, + Arrays.asList("a", "b"), + null, + null, + Arrays.asList(null, 2L, 9L), + null, + Arrays.asList(999.0d, 5.5d, null) + }, + new Object[]{ + 1672531200000L, + null, + Arrays.asList("a", "b"), + null, + Arrays.asList(2L, 3L), + null, + Collections.singletonList(null) + }, + new Object[]{ + 1672617600000L, + Arrays.asList("a", "b"), + Collections.emptyList(), + Arrays.asList(1L, 2L, 3L), + Arrays.asList(1L, null, 3L), + Arrays.asList(1.1d, 2.2d, 3.3d), + Arrays.asList(1.1d, 2.2d, null) + }, + new Object[]{ + 1672617600000L, + Arrays.asList("a", "b", "c"), + Arrays.asList(null, "b"), + Arrays.asList(2L, 3L), + null, + Arrays.asList(3.3d, 4.4d, 5.5d), + Arrays.asList(999.0d, null, 5.5d) + }, + new Object[]{ + 1672617600000L, + Arrays.asList("b", "c"), + Arrays.asList("d", null, "b"), + Arrays.asList(1L, 2L, 3L, 4L), + Arrays.asList(1L, 2L, 3L), + Arrays.asList(1.1d, 3.3d), + Arrays.asList(null, 2.2d, null) + }, + new Object[]{ + 1672617600000L, + Arrays.asList("d", "e"), + Arrays.asList("b", "b"), + Arrays.asList(1L, 4L), + Collections.singletonList(null), + Arrays.asList(2.2d, 3.3d, 4.0), + null + }, + new Object[]{ + 1672617600000L, + null, + null, + Arrays.asList(1L, 2L, 3L), + null, + Arrays.asList(1.1d, 2.2d, 3.3d), + Collections.emptyList() + }, + new Object[]{ + 1672617600000L, + Arrays.asList("a", "b"), + Collections.singletonList(null), + null, + Arrays.asList(null, 2L, 9L), + null, + Arrays.asList(999.0d, 5.5d, null) + }, + new Object[]{ + 1672617600000L, + null, + Arrays.asList("a", "b"), + null, + Arrays.asList(2L, 3L), + null, + Arrays.asList(null, 1.1d), + } + ); + + RowSignature rowSignatureWithoutTimeColumn = + RowSignature.builder() + .add("arrayString", ColumnType.STRING_ARRAY) + .add("arrayStringNulls", ColumnType.STRING_ARRAY) + .add("arrayLong", ColumnType.LONG_ARRAY) + .add("arrayLongNulls", ColumnType.LONG_ARRAY) + .add("arrayDouble", ColumnType.DOUBLE_ARRAY) + .add("arrayDoubleNulls", ColumnType.DOUBLE_ARRAY) + .build(); + + RowSignature fileSignature = RowSignature.builder() + .add("timestamp", ColumnType.STRING) + .addAll(rowSignatureWithoutTimeColumn) + .build(); + + RowSignature rowSignature = RowSignature.builder() + .add("__time", ColumnType.LONG) + .addAll(rowSignatureWithoutTimeColumn) + .build(); + + RowSignature scanSignature = RowSignature.builder() + .add("arrayDouble", ColumnType.DOUBLE_ARRAY) + .add("arrayDoubleNulls", ColumnType.DOUBLE_ARRAY) + .add("arrayLong", ColumnType.LONG_ARRAY) + .add("arrayLongNulls", ColumnType.LONG_ARRAY) + .add("arrayString", ColumnType.STRING_ARRAY) + .add("arrayStringNulls", ColumnType.STRING_ARRAY) + .add("v0", ColumnType.LONG) + .build(); + + final Map adjustedContext = new HashMap<>(context); + adjustedContext.put(MultiStageQueryContext.CTX_ARRAY_INGEST_MODE, arrayIngestMode); + + final File tmpFile = temporaryFolder.newFile(); + final InputStream resourceStream = NestedDataTestUtils.class.getClassLoader() + .getResourceAsStream(NestedDataTestUtils.ARRAY_TYPES_DATA_FILE); + final InputStream decompressing = CompressionUtils.decompress( + resourceStream, + NestedDataTestUtils.ARRAY_TYPES_DATA_FILE + ); + Files.copy(decompressing, tmpFile.toPath(), StandardCopyOption.REPLACE_EXISTING); + decompressing.close(); + + final String toReadFileNameAsJson = queryFramework().queryJsonMapper().writeValueAsString(tmpFile); + + Query expectedQuery = newScanQueryBuilder() + .dataSource(new ExternalDataSource( + new LocalInputSource(null, null, ImmutableList.of(tmpFile)), + new JsonInputFormat(null, null, null, null, null), + fileSignature + )) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns( + "arrayDouble", + "arrayDoubleNulls", + "arrayLong", + "arrayLongNulls", + "arrayString", + "arrayStringNulls", + "v0" + ) + .virtualColumns(new ExpressionVirtualColumn( + "v0", + "timestamp_parse(\"timestamp\",null,'UTC')", + ColumnType.LONG, + TestExprMacroTable.INSTANCE + )) + .context(defaultScanQueryContext(adjustedContext, scanSignature)) + .build(); + + testSelectQuery().setSql("SELECT\n" + + " TIME_PARSE(\"timestamp\") as __time,\n" + + " arrayString,\n" + + " arrayStringNulls,\n" + + " arrayLong,\n" + + " arrayLongNulls,\n" + + " arrayDouble,\n" + + " arrayDoubleNulls\n" + + "FROM TABLE(\n" + + " EXTERN(\n" + + " '{ \"files\": [" + toReadFileNameAsJson + "],\"type\":\"local\"}',\n" + + " '{\"type\": \"json\"}',\n" + + " '" + queryFramework().queryJsonMapper().writeValueAsString(fileSignature) + "'\n" + + " )\n" + + ")") + .setQueryContext(adjustedContext) + .setExpectedMSQSpec(MSQSpec + .builder() + .query(expectedQuery) + .columnMappings(new ColumnMappings(ImmutableList.of( + new ColumnMapping("v0", "__time"), + new ColumnMapping("arrayString", "arrayString"), + new ColumnMapping("arrayStringNulls", "arrayStringNulls"), + new ColumnMapping("arrayLong", "arrayLong"), + new ColumnMapping("arrayLongNulls", "arrayLongNulls"), + new ColumnMapping("arrayDouble", "arrayDouble"), + new ColumnMapping("arrayDoubleNulls", "arrayDoubleNulls") + ))) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .destination(TaskReportMSQDestination.INSTANCE) + .build() + ) + .setExpectedRowSignature(rowSignature) + .setExpectedResultRows(expectedRows) + .verifyResults(); + } + + + private List expectedMultiValueFooRowsToArray() + { + List expectedRows = new ArrayList<>(); + expectedRows.add(new Object[]{0L, null}); + if (!useDefault) { + expectedRows.add(new Object[]{0L, ""}); + } + + expectedRows.addAll(ImmutableList.of( + new Object[]{0L, ImmutableList.of("a", "b")}, + new Object[]{0L, ImmutableList.of("b", "c")}, + new Object[]{0L, "d"} + )); + return expectedRows; + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java index e54027c2449b..b43dd72e88c8 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java @@ -38,7 +38,6 @@ import org.apache.druid.msq.test.MSQTestBase; import org.apache.druid.msq.test.MSQTestFileUtils; import org.apache.druid.msq.util.MultiStageQueryContext; -import org.apache.druid.query.NestedDataTestUtils; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory; @@ -46,7 +45,6 @@ import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; import org.apache.druid.timeline.SegmentId; -import org.apache.druid.utils.CompressionUtils; import org.hamcrest.CoreMatchers; import org.junit.Test; import org.junit.internal.matchers.ThrowableMessageMatcher; @@ -54,16 +52,11 @@ import org.junit.runners.Parameterized; import org.mockito.Mockito; -import javax.annotation.Nonnull; import java.io.File; import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Files; -import java.nio.file.StandardCopyOption; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -735,22 +728,6 @@ public void testInsertOnFoo1WithMultiValueMeasureGroupBy() } - @Test - public void testInsertOnFoo1WithMultiValueToArrayGroupBy() - { - RowSignature rowSignature = RowSignature.builder() - .add("__time", ColumnType.LONG) - .add("dim3", ColumnType.STRING).build(); - - testIngestQuery().setSql( - "INSERT INTO foo1 SELECT MV_TO_ARRAY(dim3) AS dim3 FROM foo GROUP BY 1 PARTITIONED BY ALL TIME") - .setExpectedDataSource("foo1") - .setExpectedRowSignature(rowSignature) - .setQueryContext(context) - .setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", Intervals.ETERNITY, "test", 0))) - .setExpectedResultRows(expectedMultiValueFooRowsToArray()) - .verifyResults(); - } @Test public void testInsertOnFoo1WithAutoTypeArrayGroupBy() @@ -1407,251 +1384,6 @@ public void testCorrectNumberOfWorkersUsedAutoModeWithBytesLimit() throws IOExce .verifyResults(); } - @Test - public void testInsertArraysAutoType() throws IOException - { - List expectedRows = Arrays.asList( - new Object[]{1672531200000L, null, null, null}, - new Object[]{1672531200000L, null, new Object[]{1L, 2L, 3L}, new Object[]{1.1, 2.2, 3.3}}, - new Object[]{1672531200000L, new Object[]{"d", "e"}, new Object[]{1L, 4L}, new Object[]{2.2, 3.3, 4.0}}, - new Object[]{1672531200000L, new Object[]{"a", "b"}, null, null}, - new Object[]{1672531200000L, new Object[]{"a", "b"}, new Object[]{1L, 2L, 3L}, new Object[]{1.1, 2.2, 3.3}}, - new Object[]{1672531200000L, new Object[]{"b", "c"}, new Object[]{1L, 2L, 3L, 4L}, new Object[]{1.1, 3.3}}, - new Object[]{1672531200000L, new Object[]{"a", "b", "c"}, new Object[]{2L, 3L}, new Object[]{3.3, 4.4, 5.5}}, - new Object[]{1672617600000L, null, null, null}, - new Object[]{1672617600000L, null, new Object[]{1L, 2L, 3L}, new Object[]{1.1, 2.2, 3.3}}, - new Object[]{1672617600000L, new Object[]{"d", "e"}, new Object[]{1L, 4L}, new Object[]{2.2, 3.3, 4.0}}, - new Object[]{1672617600000L, new Object[]{"a", "b"}, null, null}, - new Object[]{1672617600000L, new Object[]{"a", "b"}, new Object[]{1L, 2L, 3L}, new Object[]{1.1, 2.2, 3.3}}, - new Object[]{1672617600000L, new Object[]{"b", "c"}, new Object[]{1L, 2L, 3L, 4L}, new Object[]{1.1, 3.3}}, - new Object[]{1672617600000L, new Object[]{"a", "b", "c"}, new Object[]{2L, 3L}, new Object[]{3.3, 4.4, 5.5}} - ); - - RowSignature rowSignature = RowSignature.builder() - .add("__time", ColumnType.LONG) - .add("arrayString", ColumnType.STRING_ARRAY) - .add("arrayLong", ColumnType.LONG_ARRAY) - .add("arrayDouble", ColumnType.DOUBLE_ARRAY) - .build(); - - final Map adjustedContext = new HashMap<>(context); - adjustedContext.put(MultiStageQueryContext.CTX_USE_AUTO_SCHEMAS, true); - - final File tmpFile = temporaryFolder.newFile(); - final InputStream resourceStream = NestedDataTestUtils.class.getClassLoader().getResourceAsStream(NestedDataTestUtils.ARRAY_TYPES_DATA_FILE); - final InputStream decompressing = CompressionUtils.decompress(resourceStream, NestedDataTestUtils.ARRAY_TYPES_DATA_FILE); - Files.copy(decompressing, tmpFile.toPath(), StandardCopyOption.REPLACE_EXISTING); - decompressing.close(); - - final String toReadFileNameAsJson = queryFramework().queryJsonMapper().writeValueAsString(tmpFile); - - testIngestQuery().setSql(" INSERT INTO foo1 SELECT\n" - + " TIME_PARSE(\"timestamp\") as __time,\n" - + " arrayString,\n" - + " arrayLong,\n" - + " arrayDouble\n" - + "FROM TABLE(\n" - + " EXTERN(\n" - + " '{ \"files\": [" + toReadFileNameAsJson + "],\"type\":\"local\"}',\n" - + " '{\"type\": \"json\"}',\n" - + " '[{\"name\": \"timestamp\", \"type\": \"STRING\"}, {\"name\": \"arrayString\", \"type\": \"COMPLEX\"}, {\"name\": \"arrayLong\", \"type\": \"COMPLEX\"}, {\"name\": \"arrayDouble\", \"type\": \"COMPLEX\"}]'\n" - + " )\n" - + ") PARTITIONED BY day") - .setQueryContext(adjustedContext) - .setExpectedResultRows(expectedRows) - .setExpectedDataSource("foo1") - .setExpectedRowSignature(rowSignature) - .verifyResults(); - } - - @Test - public void testInsertArrays() throws IOException - { - List expectedRows = Arrays.asList( - new Object[]{ - 1672531200000L, - null, - null, - new Object[]{1L, 2L, 3L}, - new Object[]{}, - new Object[]{1.1d, 2.2d, 3.3d}, - null - }, - new Object[]{ - 1672531200000L, - null, - Arrays.asList("a", "b"), - null, - new Object[]{2L, 3L}, - null, - new Object[]{null} - }, - new Object[]{ - 1672531200000L, - Arrays.asList("a", "b"), - null, - null, - new Object[]{null, 2L, 9L}, - null, - new Object[]{999.0d, 5.5d, null} - }, - new Object[]{ - 1672531200000L, - Arrays.asList("a", "b"), - Arrays.asList("a", "b"), - new Object[]{1L, 2L, 3L}, - new Object[]{1L, null, 3L}, - new Object[]{1.1d, 2.2d, 3.3d}, - new Object[]{1.1d, 2.2d, null} - }, - new Object[]{ - 1672531200000L, - Arrays.asList("a", "b", "c"), - Arrays.asList(null, "b"), - new Object[]{2L, 3L}, - null, - new Object[]{3.3d, 4.4d, 5.5d}, - new Object[]{999.0d, null, 5.5d} - }, - new Object[]{ - 1672531200000L, - Arrays.asList("b", "c"), - Arrays.asList("d", null, "b"), - new Object[]{1L, 2L, 3L, 4L}, - new Object[]{1L, 2L, 3L}, - new Object[]{1.1d, 3.3d}, - new Object[]{null, 2.2d, null} - }, - new Object[]{ - 1672531200000L, - Arrays.asList("d", "e"), - Arrays.asList("b", "b"), - new Object[]{1L, 4L}, - new Object[]{1L}, - new Object[]{2.2d, 3.3d, 4.0d}, - null - }, - new Object[]{ - 1672617600000L, - null, - null, - new Object[]{1L, 2L, 3L}, - null, - new Object[]{1.1d, 2.2d, 3.3d}, - new Object[]{} - }, - new Object[]{ - 1672617600000L, - null, - Arrays.asList("a", "b"), - null, - new Object[]{2L, 3L}, - null, - new Object[]{null, 1.1d} - }, - new Object[]{ - 1672617600000L, - Arrays.asList("a", "b"), - null, - null, - new Object[]{null, 2L, 9L}, - null, - new Object[]{999.0d, 5.5d, null} - }, - new Object[]{ - 1672617600000L, - Arrays.asList("a", "b"), - Collections.emptyList(), - new Object[]{1L, 2L, 3L}, - new Object[]{1L, null, 3L}, - new Object[]{1.1d, 2.2d, 3.3d}, - new Object[]{1.1d, 2.2d, null} - }, - new Object[]{ - 1672617600000L, - Arrays.asList("a", "b", "c"), - Arrays.asList(null, "b"), - new Object[]{2L, 3L}, - null, - new Object[]{3.3d, 4.4d, 5.5d}, - new Object[]{999.0d, null, 5.5d} - }, - new Object[]{ - 1672617600000L, - Arrays.asList("b", "c"), - Arrays.asList("d", null, "b"), - new Object[]{1L, 2L, 3L, 4L}, - new Object[]{1L, 2L, 3L}, - new Object[]{1.1d, 3.3d}, - new Object[]{null, 2.2d, null} - }, - new Object[]{ - 1672617600000L, - Arrays.asList("d", "e"), - Arrays.asList("b", "b"), - new Object[]{1L, 4L}, - new Object[]{null}, - new Object[]{2.2d, 3.3d, 4.0}, - null - } - ); - - RowSignature rowSignatureWithoutTimeAndStringColumns = - RowSignature.builder() - .add("arrayLong", ColumnType.LONG_ARRAY) - .add("arrayLongNulls", ColumnType.LONG_ARRAY) - .add("arrayDouble", ColumnType.DOUBLE_ARRAY) - .add("arrayDoubleNulls", ColumnType.DOUBLE_ARRAY) - .build(); - - - RowSignature fileSignature = RowSignature.builder() - .add("timestamp", ColumnType.STRING) - .add("arrayString", ColumnType.STRING_ARRAY) - .add("arrayStringNulls", ColumnType.STRING_ARRAY) - .addAll(rowSignatureWithoutTimeAndStringColumns) - .build(); - - // MSQ writes strings instead of string arrays - RowSignature rowSignature = RowSignature.builder() - .add("__time", ColumnType.LONG) - .add("arrayString", ColumnType.STRING) - .add("arrayStringNulls", ColumnType.STRING) - .addAll(rowSignatureWithoutTimeAndStringColumns) - .build(); - - final Map adjustedContext = new HashMap<>(context); - final File tmpFile = temporaryFolder.newFile(); - final InputStream resourceStream = NestedDataTestUtils.class.getClassLoader().getResourceAsStream(NestedDataTestUtils.ARRAY_TYPES_DATA_FILE); - final InputStream decompressing = CompressionUtils.decompress(resourceStream, NestedDataTestUtils.ARRAY_TYPES_DATA_FILE); - Files.copy(decompressing, tmpFile.toPath(), StandardCopyOption.REPLACE_EXISTING); - decompressing.close(); - - final String toReadFileNameAsJson = queryFramework().queryJsonMapper().writeValueAsString(tmpFile); - - testIngestQuery().setSql(" INSERT INTO foo1 SELECT\n" - + " TIME_PARSE(\"timestamp\") as __time,\n" - + " arrayString,\n" - + " arrayStringNulls,\n" - + " arrayLong,\n" - + " arrayLongNulls,\n" - + " arrayDouble,\n" - + " arrayDoubleNulls\n" - + "FROM TABLE(\n" - + " EXTERN(\n" - + " '{ \"files\": [" + toReadFileNameAsJson + "],\"type\":\"local\"}',\n" - + " '{\"type\": \"json\"}',\n" - + " '" + queryFramework().queryJsonMapper().writeValueAsString(fileSignature) + "'\n" - + " )\n" - + ") PARTITIONED BY day") - .setQueryContext(adjustedContext) - .setExpectedResultRows(expectedRows) - .setExpectedDataSource("foo1") - .setExpectedRowSignature(rowSignature) - .verifyResults(); - } - - @Nonnull private List expectedFooRows() { List expectedRows = new ArrayList<>(); @@ -1668,7 +1400,6 @@ private List expectedFooRows() return expectedRows; } - @Nonnull private List expectedFooRowsWithAggregatedComplexColumn() { List expectedRows = new ArrayList<>(); @@ -1687,7 +1418,6 @@ private List expectedFooRowsWithAggregatedComplexColumn() return expectedRows; } - @Nonnull private List expectedMultiValueFooRows() { List expectedRows = new ArrayList<>(); @@ -1704,24 +1434,6 @@ private List expectedMultiValueFooRows() return expectedRows; } - @Nonnull - private List expectedMultiValueFooRowsToArray() - { - List expectedRows = new ArrayList<>(); - expectedRows.add(new Object[]{0L, null}); - if (!useDefault) { - expectedRows.add(new Object[]{0L, ""}); - } - - expectedRows.addAll(ImmutableList.of( - new Object[]{0L, ImmutableList.of("a", "b")}, - new Object[]{0L, ImmutableList.of("b", "c")}, - new Object[]{0L, "d"} - )); - return expectedRows; - } - - @Nonnull private List expectedMultiValueFooRowsGroupBy() { List expectedRows = new ArrayList<>(); @@ -1737,7 +1449,6 @@ private List expectedMultiValueFooRowsGroupBy() return expectedRows; } - @Nonnull private Set expectedFooSegments() { Set expectedSegments = new TreeSet<>(); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/MultiStageQueryContextTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/MultiStageQueryContextTest.java index 830b414daedb..5bfb4d2eb279 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/MultiStageQueryContextTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/MultiStageQueryContextTest.java @@ -39,6 +39,7 @@ import java.util.List; import java.util.Map; +import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_ARRAY_INGEST_MODE; import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE; import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_FAULT_TOLERANCE; import static org.apache.druid.msq.util.MultiStageQueryContext.CTX_FINALIZE_AGGREGATIONS; @@ -54,46 +55,46 @@ public class MultiStageQueryContextTest { @Test - public void isDurableShuffleStorageEnabled_noParameterSetReturnsDefaultValue() + public void isDurableShuffleStorageEnabled_unset_returnsDefaultValue() { Assert.assertFalse(MultiStageQueryContext.isDurableStorageEnabled(QueryContext.empty())); } @Test - public void isDurableShuffleStorageEnabled_parameterSetReturnsCorrectValue() + public void isDurableShuffleStorageEnabled_set_returnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_DURABLE_SHUFFLE_STORAGE, "true"); Assert.assertTrue(MultiStageQueryContext.isDurableStorageEnabled(QueryContext.of(propertyMap))); } @Test - public void isFaultToleranceEnabled_noParameterSetReturnsDefaultValue() + public void isFaultToleranceEnabled_unset_returnsDefaultValue() { Assert.assertFalse(MultiStageQueryContext.isFaultToleranceEnabled(QueryContext.empty())); } @Test - public void isFaultToleranceEnabled_parameterSetReturnsCorrectValue() + public void isFaultToleranceEnabled_set_returnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_FAULT_TOLERANCE, "true"); Assert.assertTrue(MultiStageQueryContext.isFaultToleranceEnabled(QueryContext.of(propertyMap))); } @Test - public void isFinalizeAggregations_noParameterSetReturnsDefaultValue() + public void isFinalizeAggregations_unset_returnsDefaultValue() { Assert.assertTrue(MultiStageQueryContext.isFinalizeAggregations(QueryContext.empty())); } @Test - public void isFinalizeAggregations_parameterSetReturnsCorrectValue() + public void isFinalizeAggregations_set_returnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_FINALIZE_AGGREGATIONS, "false"); Assert.assertFalse(MultiStageQueryContext.isFinalizeAggregations(QueryContext.of(propertyMap))); } @Test - public void getAssignmentStrategy_noParameterSetReturnsDefaultValue() + public void getAssignmentStrategy_unset_returnsDefaultValue() { Assert.assertEquals( WorkerAssignmentStrategy.MAX, @@ -102,7 +103,7 @@ public void getAssignmentStrategy_noParameterSetReturnsDefaultValue() } @Test - public void testGetMaxInputBytesPerWorker() + public void getMaxInputBytesPerWorker_set_returnsCorrectValue() { Map propertyMap = ImmutableMap.of(MultiStageQueryContext.CTX_MAX_INPUT_BYTES_PER_WORKER, 1024); @@ -112,7 +113,7 @@ public void testGetMaxInputBytesPerWorker() } @Test - public void getAssignmentStrategy_parameterSetReturnsCorrectValue() + public void getAssignmentStrategy_set_returnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_TASK_ASSIGNMENT_STRATEGY, "AUTO"); Assert.assertEquals( @@ -122,27 +123,20 @@ public void getAssignmentStrategy_parameterSetReturnsCorrectValue() } @Test - public void getMaxNumTasks_noParameterSetReturnsDefaultValue() + public void getMaxNumTasks_unset_returnsDefaultValue() { Assert.assertEquals(DEFAULT_MAX_NUM_TASKS, MultiStageQueryContext.getMaxNumTasks(QueryContext.empty())); } @Test - public void getMaxNumTasks_parameterSetReturnsCorrectValue() + public void getMaxNumTasks_set_returnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_MAX_NUM_TASKS, 101); Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(QueryContext.of(propertyMap))); } @Test - public void getMaxNumTasks_legacyParameterSetReturnsCorrectValue() - { - Map propertyMap = ImmutableMap.of(CTX_MAX_NUM_TASKS, 101); - Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(QueryContext.of(propertyMap))); - } - - @Test - public void getRowsPerSegment_noParameterSetReturnsDefaultValue() + public void getRowsPerSegment_unset_returnsDefaultValue() { Assert.assertEquals( MultiStageQueryContext.DEFAULT_ROWS_PER_SEGMENT, @@ -151,14 +145,14 @@ public void getRowsPerSegment_noParameterSetReturnsDefaultValue() } @Test - public void getRowsPerSegment_parameterSetReturnsCorrectValue() + public void getRowsPerSegment_set_returnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_ROWS_PER_SEGMENT, 10); Assert.assertEquals(10, MultiStageQueryContext.getRowsPerSegment(QueryContext.of(propertyMap))); } @Test - public void getRowsInMemory_noParameterSetReturnsDefaultValue() + public void getRowsInMemory_unset_returnsDefaultValue() { Assert.assertEquals( MultiStageQueryContext.DEFAULT_ROWS_IN_MEMORY, @@ -167,12 +161,91 @@ public void getRowsInMemory_noParameterSetReturnsDefaultValue() } @Test - public void getRowsInMemory_parameterSetReturnsCorrectValue() + public void getRowsInMemory_set_returnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_ROWS_IN_MEMORY, 10); Assert.assertEquals(10, MultiStageQueryContext.getRowsInMemory(QueryContext.of(propertyMap))); } + @Test + public void getSortOrder_unset_returnsDefaultValue() + { + Assert.assertEquals(Collections.emptyList(), MultiStageQueryContext.getSortOrder(QueryContext.empty())); + } + + @Test + public void getSortOrder_set_returnsCorrectValue() + { + Map propertyMap = ImmutableMap.of(CTX_SORT_ORDER, "a, b,\"c,d\""); + Assert.assertEquals( + ImmutableList.of("a", "b", "c,d"), + MultiStageQueryContext.getSortOrder(QueryContext.of(propertyMap)) + ); + } + + @Test + public void getMSQMode_unset_returnsDefaultValue() + { + Assert.assertEquals("strict", MultiStageQueryContext.getMSQMode(QueryContext.empty())); + } + + @Test + public void getMSQMode_set_returnsCorrectValue() + { + Map propertyMap = ImmutableMap.of(CTX_MSQ_MODE, "nonStrict"); + Assert.assertEquals("nonStrict", MultiStageQueryContext.getMSQMode(QueryContext.of(propertyMap))); + } + + @Test + public void getSelectDestination_unset_returnsDefaultValue() + { + Assert.assertEquals(MSQSelectDestination.TASKREPORT, MultiStageQueryContext.getSelectDestination(QueryContext.empty())); + } + + @Test + public void useAutoColumnSchemes_unset_returnsDefaultValue() + { + Assert.assertFalse(MultiStageQueryContext.useAutoColumnSchemas(QueryContext.empty())); + } + + @Test + public void useAutoColumnSchemes_set_returnsCorrectValue() + { + Map propertyMap = ImmutableMap.of(CTX_USE_AUTO_SCHEMAS, true); + Assert.assertTrue(MultiStageQueryContext.useAutoColumnSchemas(QueryContext.of(propertyMap))); + } + + @Test + public void arrayIngestMode_unset_returnsDefaultValue() + { + Assert.assertEquals(ArrayIngestMode.MVD, MultiStageQueryContext.getArrayIngestMode(QueryContext.empty())); + } + + @Test + public void arrayIngestMode_set_returnsCorrectValue() + { + Assert.assertEquals( + ArrayIngestMode.NONE, + MultiStageQueryContext.getArrayIngestMode(QueryContext.of(ImmutableMap.of(CTX_ARRAY_INGEST_MODE, "none"))) + ); + + Assert.assertEquals( + ArrayIngestMode.MVD, + MultiStageQueryContext.getArrayIngestMode(QueryContext.of(ImmutableMap.of(CTX_ARRAY_INGEST_MODE, "mvd"))) + ); + + Assert.assertEquals( + ArrayIngestMode.ARRAY, + MultiStageQueryContext.getArrayIngestMode(QueryContext.of(ImmutableMap.of(CTX_ARRAY_INGEST_MODE, "array"))) + ); + + Assert.assertThrows( + BadQueryContextException.class, + () -> + MultiStageQueryContext.getArrayIngestMode(QueryContext.of(ImmutableMap.of(CTX_ARRAY_INGEST_MODE, "dummy"))) + ); + } + @Test public void testDecodeSortOrder() { @@ -221,48 +294,6 @@ public void testGetIndexSpec() ); } - @Test - public void getSortOrderNoParameterSetReturnsDefaultValue() - { - Assert.assertEquals(Collections.emptyList(), MultiStageQueryContext.getSortOrder(QueryContext.empty())); - } - - @Test - public void getSortOrderParameterSetReturnsCorrectValue() - { - Map propertyMap = ImmutableMap.of(CTX_SORT_ORDER, "a, b,\"c,d\""); - Assert.assertEquals( - ImmutableList.of("a", "b", "c,d"), - MultiStageQueryContext.getSortOrder(QueryContext.of(propertyMap)) - ); - } - - @Test - public void getMSQModeNoParameterSetReturnsDefaultValue() - { - Assert.assertEquals("strict", MultiStageQueryContext.getMSQMode(QueryContext.empty())); - } - - @Test - public void getMSQModeParameterSetReturnsCorrectValue() - { - Map propertyMap = ImmutableMap.of(CTX_MSQ_MODE, "nonStrict"); - Assert.assertEquals("nonStrict", MultiStageQueryContext.getMSQMode(QueryContext.of(propertyMap))); - } - - @Test - public void limitSelectResultReturnsDefaultValue() - { - Assert.assertEquals(MSQSelectDestination.TASKREPORT, MultiStageQueryContext.getSelectDestination(QueryContext.empty())); - } - - @Test - public void testUseAutoSchemas() - { - Map propertyMap = ImmutableMap.of(CTX_USE_AUTO_SCHEMAS, true); - Assert.assertTrue(MultiStageQueryContext.useAutoColumnSchemas(QueryContext.of(propertyMap))); - } - private static List decodeSortOrder(@Nullable final String input) { return MultiStageQueryContext.decodeSortOrder(input); From 90a1458ac9b81bd3bac443790242353bb701aadf Mon Sep 17 00:00:00 2001 From: Abhishek Agarwal <1477457+abhishekagarwal87@users.noreply.github.com> Date: Mon, 9 Oct 2023 20:45:10 +0530 Subject: [PATCH 14/14] Parse passwords containing colon correctly (#15109) --- .../BasicHTTPAuthenticator.java | 20 +++- .../BasicHTTPAuthenticatorTest.java | 109 ++++++++++-------- 2 files changed, 79 insertions(+), 50 deletions(-) diff --git a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authentication/BasicHTTPAuthenticator.java b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authentication/BasicHTTPAuthenticator.java index 600af931f031..85cc60d2e76b 100644 --- a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authentication/BasicHTTPAuthenticator.java +++ b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authentication/BasicHTTPAuthenticator.java @@ -182,15 +182,27 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo return; } - String[] splits = decodedUserSecret.split(":"); - if (splits.length != 2) { + /* From https://www.rfc-editor.org/rfc/rfc7617.html, we can assume that userid won't include a colon but password + can. + + The user-id and password MUST NOT contain any control characters (see + "CTL" in Appendix B.1 of [RFC5234]). + + Furthermore, a user-id containing a colon character is invalid, as + the first colon in a user-pass string separates user-id and password + from one another; text after the first colon is part of the password. + User-ids containing colons cannot be encoded in user-pass strings. + + */ + int split = decodedUserSecret.indexOf(':'); + if (split < 0) { // The decoded user secret is not of the right format httpResp.sendError(HttpServletResponse.SC_UNAUTHORIZED); return; } - String user = splits[0]; - char[] password = splits[1].toCharArray(); + String user = decodedUserSecret.substring(0, split); + char[] password = decodedUserSecret.substring(split + 1).toCharArray(); // If any authentication error occurs we send a 401 response immediately and do not proceed further down the filter chain. // If the authentication result is null and skipOnFailure property is false, we send a 401 response and do not proceed diff --git a/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authentication/BasicHTTPAuthenticatorTest.java b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authentication/BasicHTTPAuthenticatorTest.java index 84bfdcf56b1c..bf0cf1778a14 100644 --- a/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authentication/BasicHTTPAuthenticatorTest.java +++ b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authentication/BasicHTTPAuthenticatorTest.java @@ -112,55 +112,21 @@ public void testGoodPassword() throws IOException, ServletException } @Test - public void testGoodPasswordWithValidator() throws IOException, ServletException + public void testGoodNonEmptyPasswordWithValidator() throws IOException, ServletException { - CredentialsValidator validator = EasyMock.createMock(CredentialsValidator.class); - BasicHTTPAuthenticator authenticatorWithValidator = new BasicHTTPAuthenticator( - CACHE_MANAGER_PROVIDER, - "basic", - "basic", - null, - null, - false, - null, null, - false, - validator - ); - - String header = StringUtils.utf8Base64("userA:helloworld"); - header = StringUtils.format("Basic %s", header); - - EasyMock - .expect( - validator.validateCredentials(EasyMock.eq("basic"), EasyMock.eq("basic"), EasyMock.eq("userA"), EasyMock.aryEq("helloworld".toCharArray())) - ) - .andReturn( - new AuthenticationResult("userA", "basic", "basic", null) - ) - .times(1); - EasyMock.replay(validator); - - HttpServletRequest req = EasyMock.createMock(HttpServletRequest.class); - EasyMock.expect(req.getHeader("Authorization")).andReturn(header); - req.setAttribute( - AuthConfig.DRUID_AUTHENTICATION_RESULT, - new AuthenticationResult("userA", "basic", "basic", null) - ); - EasyMock.expectLastCall().times(1); - EasyMock.replay(req); - - HttpServletResponse resp = EasyMock.createMock(HttpServletResponse.class); - EasyMock.replay(resp); - - FilterChain filterChain = EasyMock.createMock(FilterChain.class); - filterChain.doFilter(req, resp); - EasyMock.expectLastCall().times(1); - EasyMock.replay(filterChain); + testGoodPasswordWithValidator("userA", "helloworld"); + } - Filter authenticatorFilter = authenticatorWithValidator.getFilter(); - authenticatorFilter.doFilter(req, resp, filterChain); + @Test + public void testGoodEmptyPasswordWithValidator() throws IOException, ServletException + { + testGoodPasswordWithValidator("userA", ""); + } - EasyMock.verify(req, resp, validator, filterChain); + @Test + public void testGoodColonInPasswordWithValidator() throws IOException, ServletException + { + testGoodPasswordWithValidator("userA", "hello:hello"); } @Test @@ -396,4 +362,55 @@ public void testMissingHeader() throws IOException, ServletException EasyMock.verify(req, resp, filterChain); } + + private void testGoodPasswordWithValidator(String username, String password) throws IOException, ServletException + { + CredentialsValidator validator = EasyMock.createMock(CredentialsValidator.class); + BasicHTTPAuthenticator authenticatorWithValidator = new BasicHTTPAuthenticator( + CACHE_MANAGER_PROVIDER, + "basic", + "basic", + null, + null, + false, + null, null, + false, + validator + ); + + String header = StringUtils.utf8Base64(username + ":" + password); + header = StringUtils.format("Basic %s", header); + + EasyMock + .expect( + validator.validateCredentials(EasyMock.eq("basic"), EasyMock.eq("basic"), EasyMock.eq(username), EasyMock.aryEq(password.toCharArray())) + ) + .andReturn( + new AuthenticationResult(username, "basic", "basic", null) + ) + .times(1); + EasyMock.replay(validator); + + HttpServletRequest req = EasyMock.createMock(HttpServletRequest.class); + EasyMock.expect(req.getHeader("Authorization")).andReturn(header); + req.setAttribute( + AuthConfig.DRUID_AUTHENTICATION_RESULT, + new AuthenticationResult(username, "basic", "basic", null) + ); + EasyMock.expectLastCall().times(1); + EasyMock.replay(req); + + HttpServletResponse resp = EasyMock.createMock(HttpServletResponse.class); + EasyMock.replay(resp); + + FilterChain filterChain = EasyMock.createMock(FilterChain.class); + filterChain.doFilter(req, resp); + EasyMock.expectLastCall().times(1); + EasyMock.replay(filterChain); + + Filter authenticatorFilter = authenticatorWithValidator.getFilter(); + authenticatorFilter.doFilter(req, resp, filterChain); + + EasyMock.verify(req, resp, validator, filterChain); + } }