From 786c959e9e3ea9a78f59917e3bab59d4141222af Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Tue, 3 Sep 2024 09:05:29 -0700 Subject: [PATCH] MSQ: Add limitHint to global-sort shuffles. (#16911) * MSQ: Add limitHint to global-sort shuffles. This allows pushing down limits into the SuperSorter. * Test fixes. * Add limitSpec to ScanQueryKit. Fix SuperSorter tracking. --- codestyle/spotbugs-exclude.xml | 1 + .../apache/druid/msq/exec/RunWorkOrder.java | 4 +- .../kernel/GlobalSortMaxCountShuffleSpec.java | 23 ++++- .../druid/msq/kernel/HashShuffleSpec.java | 1 - .../kernel/LimitHintJsonIncludeFilter.java | 38 ++++++++ .../apache/druid/msq/kernel/ShuffleSpec.java | 15 +++ .../msq/querykit/ShuffleSpecFactories.java | 23 +++-- .../msq/querykit/groupby/GroupByQueryKit.java | 24 ++++- .../druid/msq/querykit/scan/ScanQueryKit.java | 24 +++-- .../apache/druid/msq/exec/MSQSelectTest.java | 4 +- .../indexing/report/MSQTaskReportTest.java | 4 +- .../druid/msq/kernel/QueryDefinitionTest.java | 3 +- .../druid/msq/kernel/StageDefinitionTest.java | 6 +- .../MockQueryDefinitionBuilder.java | 3 +- .../frame/processor/FrameChannelMerger.java | 10 +- .../druid/frame/processor/SuperSorter.java | 91 ++++++++++++++----- .../frame/processor/SuperSorterTest.java | 82 +++++++++++++---- 17 files changed, 280 insertions(+), 76 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/LimitHintJsonIncludeFilter.java diff --git a/codestyle/spotbugs-exclude.xml b/codestyle/spotbugs-exclude.xml index 19da270192db..764ae7c73513 100644 --- a/codestyle/spotbugs-exclude.xml +++ b/codestyle/spotbugs-exclude.xml @@ -46,6 +46,7 @@ + diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java index e48f1ef098a6..a4d6a2180bde 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java @@ -738,7 +738,7 @@ public FrameProcessor decorate(FrameProcessor processor) makeSuperSorterIntermediateOutputChannelFactory(sorterTmpDir), memoryParameters.getSuperSorterMaxActiveProcessors(), memoryParameters.getSuperSorterMaxChannelsPerProcessor(), - -1, + stageDefinition.getShuffleSpec().limitHint(), cancellationId, counterTracker.sortProgress(), removeNullBytes @@ -871,7 +871,7 @@ public FrameProcessor decorate(FrameProcessor processor) makeSuperSorterIntermediateOutputChannelFactory(sorterTmpDir), 1, 2, - -1, + ShuffleSpec.UNLIMITED, cancellationId, // Tracker is not actually tracked, since it doesn't quite fit into the way we report counters. diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java index e38fc778bb8a..35576474ff9d 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java @@ -43,17 +43,20 @@ public class GlobalSortMaxCountShuffleSpec implements GlobalSortShuffleSpec private final ClusterBy clusterBy; private final int maxPartitions; private final boolean aggregate; + private final long limitHint; @JsonCreator public GlobalSortMaxCountShuffleSpec( @JsonProperty("clusterBy") final ClusterBy clusterBy, @JsonProperty("partitions") final int maxPartitions, - @JsonProperty("aggregate") final boolean aggregate + @JsonProperty("aggregate") final boolean aggregate, + @JsonProperty("limitHint") final Long limitHint ) { this.clusterBy = Preconditions.checkNotNull(clusterBy, "clusterBy"); this.maxPartitions = maxPartitions; this.aggregate = aggregate; + this.limitHint = limitHint == null ? UNLIMITED : limitHint; if (maxPartitions < 1) { throw new IAE("Partition count must be at least 1"); @@ -133,6 +136,14 @@ public int getMaxPartitions() return maxPartitions; } + @Override + @JsonInclude(value = JsonInclude.Include.CUSTOM, valueFilter = LimitHintJsonIncludeFilter.class) + @JsonProperty + public long limitHint() + { + return limitHint; + } + @Override public boolean equals(Object o) { @@ -145,22 +156,24 @@ public boolean equals(Object o) GlobalSortMaxCountShuffleSpec that = (GlobalSortMaxCountShuffleSpec) o; return maxPartitions == that.maxPartitions && aggregate == that.aggregate - && Objects.equals(clusterBy, that.clusterBy); + && Objects.equals(clusterBy, that.clusterBy) + && Objects.equals(limitHint, that.limitHint); } @Override public int hashCode() { - return Objects.hash(clusterBy, maxPartitions, aggregate); + return Objects.hash(clusterBy, maxPartitions, aggregate, limitHint); } @Override public String toString() { - return "MaxCountShuffleSpec{" + + return "GlobalSortMaxCountShuffleSpec{" + "clusterBy=" + clusterBy + - ", partitions=" + maxPartitions + + ", maxPartitions=" + maxPartitions + ", aggregate=" + aggregate + + ", limitHint=" + limitHint + '}'; } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java index fbc39fc672c3..69e66fffe263 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java @@ -65,5 +65,4 @@ public int partitionCount() { return numPartitions; } - } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/LimitHintJsonIncludeFilter.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/LimitHintJsonIncludeFilter.java new file mode 100644 index 000000000000..55572798e279 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/LimitHintJsonIncludeFilter.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.kernel; + +import com.fasterxml.jackson.annotation.JsonInclude; + +/** + * {@link JsonInclude} filter for {@link ShuffleSpec#limitHint()}. + * + * This API works by "creative" use of equals. It requires warnings to be suppressed + * and also requires spotbugs exclusions (see spotbugs-exclude.xml). + */ +@SuppressWarnings({"EqualsAndHashcode", "EqualsHashCode"}) +public class LimitHintJsonIncludeFilter +{ + @Override + public boolean equals(Object obj) + { + return obj instanceof Long && (Long) obj == ShuffleSpec.UNLIMITED; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java index 4b7971a7f783..97f3e6db5473 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java @@ -37,6 +37,8 @@ }) public interface ShuffleSpec { + long UNLIMITED = -1; + /** * The nature of this shuffle: hash vs. range based partitioning; whether the data are sorted or not. * @@ -68,4 +70,17 @@ public interface ShuffleSpec * @throws IllegalStateException if kind is {@link ShuffleKind#GLOBAL_SORT} with more than one target partition */ int partitionCount(); + + /** + * Limit that can be applied during shuffling. This is provided to enable performance optimizations. + * + * Implementations may apply this limit to each partition individually, or may apply it to the entire resultset + * (across all partitions). Either approach is valid, so downstream logic must handle either one. + * + * Implementations may also ignore this hint completely, or may apply a limit that is somewhat higher than this hint. + */ + default long limitHint() + { + return UNLIMITED; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java index d28439c0f8e0..8f5770378c55 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java @@ -22,6 +22,7 @@ import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec; import org.apache.druid.msq.kernel.GlobalSortTargetSizeShuffleSpec; import org.apache.druid.msq.kernel.MixShuffleSpec; +import org.apache.druid.msq.kernel.ShuffleSpec; /** * Static factory methods for common implementations of {@link ShuffleSpecFactory}. @@ -37,10 +38,21 @@ private ShuffleSpecFactories() * Factory that produces a single output partition, which may or may not be sorted. */ public static ShuffleSpecFactory singlePartition() + { + return singlePartitionWithLimit(ShuffleSpec.UNLIMITED); + } + + /** + * Factory that produces a single output partition, which may or may not be sorted. + * + * @param limitHint limit that can be applied during shuffling. May not actually be applied; this is just an + * optional optimization. See {@link ShuffleSpec#limitHint()}. + */ + public static ShuffleSpecFactory singlePartitionWithLimit(final long limitHint) { return (clusterBy, aggregate) -> { if (clusterBy.sortable() && !clusterBy.isEmpty()) { - return new GlobalSortMaxCountShuffleSpec(clusterBy, 1, aggregate); + return new GlobalSortMaxCountShuffleSpec(clusterBy, 1, aggregate, limitHint); } else { return MixShuffleSpec.instance(); } @@ -52,7 +64,8 @@ public static ShuffleSpecFactory singlePartition() */ public static ShuffleSpecFactory globalSortWithMaxPartitionCount(final int partitions) { - return (clusterBy, aggregate) -> new GlobalSortMaxCountShuffleSpec(clusterBy, partitions, aggregate); + return (clusterBy, aggregate) -> + new GlobalSortMaxCountShuffleSpec(clusterBy, partitions, aggregate, ShuffleSpec.UNLIMITED); } /** @@ -61,10 +74,6 @@ public static ShuffleSpecFactory globalSortWithMaxPartitionCount(final int parti public static ShuffleSpecFactory getGlobalSortWithTargetSize(int targetSize) { return (clusterBy, aggregate) -> - new GlobalSortTargetSizeShuffleSpec( - clusterBy, - targetSize, - aggregate - ); + new GlobalSortTargetSizeShuffleSpec(clusterBy, targetSize, aggregate); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java index 2bf77fd8d0cf..7e4ebf5e7fab 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java @@ -112,13 +112,25 @@ public QueryDefinition makeQueryDefinition( final ShuffleSpecFactory shuffleSpecFactoryPostAggregation; boolean partitionBoost; + // limitHint to use for the shuffle after the post-aggregation stage. + // Don't apply limitHint pre-aggregation, because results from pre-aggregation may not be fully grouped. + final long postAggregationLimitHint; + + if (doLimitOrOffset) { + final DefaultLimitSpec limitSpec = (DefaultLimitSpec) queryToRun.getLimitSpec(); + postAggregationLimitHint = + limitSpec.isLimited() ? limitSpec.getOffset() + limitSpec.getLimit() : ShuffleSpec.UNLIMITED; + } else { + postAggregationLimitHint = ShuffleSpec.UNLIMITED; + } + if (intermediateClusterBy.isEmpty() && resultClusterByWithoutPartitionBoost.isEmpty()) { // Ignore shuffleSpecFactory, since we know only a single partition will come out, and we can save some effort. // This condition will be triggered when we don't have a grouping dimension, no partitioning granularity // (PARTITIONED BY ALL) and no ordering/clustering dimensions // For example: INSERT INTO foo SELECT COUNT(*) FROM bar PARTITIONED BY ALL shuffleSpecFactoryPreAggregation = ShuffleSpecFactories.singlePartition(); - shuffleSpecFactoryPostAggregation = ShuffleSpecFactories.singlePartition(); + shuffleSpecFactoryPostAggregation = ShuffleSpecFactories.singlePartitionWithLimit(postAggregationLimitHint); partitionBoost = false; } else if (doOrderBy) { // There can be a situation where intermediateClusterBy is empty, while the resultClusterBy is non-empty @@ -130,9 +142,13 @@ public QueryDefinition makeQueryDefinition( shuffleSpecFactoryPreAggregation = intermediateClusterBy.isEmpty() ? ShuffleSpecFactories.singlePartition() : ShuffleSpecFactories.globalSortWithMaxPartitionCount(maxWorkerCount); - shuffleSpecFactoryPostAggregation = doLimitOrOffset - ? ShuffleSpecFactories.singlePartition() - : resultShuffleSpecFactory; + + if (doLimitOrOffset) { + shuffleSpecFactoryPostAggregation = ShuffleSpecFactories.singlePartitionWithLimit(postAggregationLimitHint); + } else { + shuffleSpecFactoryPostAggregation = resultShuffleSpecFactory; + } + partitionBoost = true; } else { shuffleSpecFactoryPreAggregation = doLimitOrOffset diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java index 5bbf9c9cbb05..f4f50106e813 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java @@ -143,22 +143,32 @@ public QueryDefinition makeQueryDefinition( ); ShuffleSpec scanShuffleSpec; - if (!hasLimitOrOffset) { - // If there is no limit spec, apply the final shuffling here itself. This will ensure partition sizes etc are respected. - scanShuffleSpec = finalShuffleSpec; - } else { + if (hasLimitOrOffset) { // If there is a limit spec, check if there are any non-boost columns to sort in. - boolean requiresSort = clusterByColumns.stream() - .anyMatch(keyColumn -> !QueryKitUtils.PARTITION_BOOST_COLUMN.equals(keyColumn.columnName())); + boolean requiresSort = + clusterByColumns.stream() + .anyMatch(keyColumn -> !QueryKitUtils.PARTITION_BOOST_COLUMN.equals(keyColumn.columnName())); if (requiresSort) { // If yes, do a sort into a single partition. - scanShuffleSpec = ShuffleSpecFactories.singlePartition().build(clusterBy, false); + final long limitHint; + + if (queryToRun.isLimited() + && queryToRun.getScanRowsOffset() + queryToRun.getScanRowsLimit() > 0 /* overflow check */) { + limitHint = queryToRun.getScanRowsOffset() + queryToRun.getScanRowsLimit(); + } else { + limitHint = ShuffleSpec.UNLIMITED; + } + + scanShuffleSpec = ShuffleSpecFactories.singlePartitionWithLimit(limitHint).build(clusterBy, false); } else { // If the only clusterBy column is the boost column, we just use a mix shuffle to avoid unused shuffling. // Note that we still need the boost column to be present in the row signature, since the limit stage would // need it to be populated to do its own shuffling later. scanShuffleSpec = MixShuffleSpec.instance(); } + } else { + // If there is no limit spec, apply the final shuffling here itself. This will ensure partition sizes etc are respected. + scanShuffleSpec = finalShuffleSpec; } queryDefBuilder.add( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index b80b3844b502..305cdebf691e 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -1847,7 +1847,7 @@ public void testGroupByWithLimitAndOrdering(String contextName, Map runIncrementally(final IntSet readableInputs) throws return ReturnOrAwait.awaitAll(awaitSet); } + // Check finished() after populateCurrentFramesAndTournamentTree(). if (finished()) { - // Done! return ReturnOrAwait.returnObject(rowsOutput); } // Generate one output frame and stop for now. outputChannel.write(nextFrame()); - return ReturnOrAwait.runAgain(); + + // Check finished() after nextFrame(). + if (finished()) { + return ReturnOrAwait.returnObject(rowsOutput); + } else { + return ReturnOrAwait.runAgain(); + } } private FrameWithPartition nextFrame() diff --git a/processing/src/main/java/org/apache/druid/frame/processor/SuperSorter.java b/processing/src/main/java/org/apache/druid/frame/processor/SuperSorter.java index d2b6934d2926..e30f2e77b02b 100644 --- a/processing/src/main/java/org/apache/druid/frame/processor/SuperSorter.java +++ b/processing/src/main/java/org/apache/druid/frame/processor/SuperSorter.java @@ -123,6 +123,7 @@ public class SuperSorter { private static final Logger log = new Logger(SuperSorter.class); + public static final long UNLIMITED = -1; public static final int UNKNOWN_LEVEL = -1; public static final long UNKNOWN_TOTAL = -1; @@ -136,10 +137,8 @@ public class SuperSorter private final OutputChannelFactory intermediateOutputChannelFactory; private final int maxChannelsPerMerger; private final int maxActiveProcessors; - private final long rowLimit; private final String cancellationId; private final boolean removeNullBytes; - private final Object runWorkersLock = new Object(); @GuardedBy("runWorkersLock") @@ -184,6 +183,9 @@ public class SuperSorter @GuardedBy("runWorkersLock") SuperSorterProgressTracker superSorterProgressTracker; + @GuardedBy("runWorkersLock") + private long rowLimit; + /** * See {@link #setNoWorkRunnable}. */ @@ -212,9 +214,9 @@ public class SuperSorter * @param maxChannelsPerMerger maximum number of channels to merge at once, for regular mergers * (does not apply to direct mergers; see * {@link #getMaxInputBufferFramesForDirectMerging()}) - * @param rowLimit limit to apply during sorting. The limit is merely advisory: the actual number - * of rows returned may be larger than the limit. The limit is applied across - * all partitions, not to each partition individually. + * @param rowLimit limit to apply during sorting. The limit is applied across all partitions, + * not to each partition individually. Use {@link #UNLIMITED} if there is + * no limit. * @param cancellationId cancellation id to use when running processors in the provided * {@link FrameProcessorExecutor}. * @param superSorterProgressTracker progress tracker @@ -262,6 +264,10 @@ public SuperSorter( if (maxChannelsPerMerger < 2) { throw new IAE("maxChannelsPerMerger[%d] < 2", maxChannelsPerMerger); } + + if (rowLimit != UNLIMITED && rowLimit <= 0) { + throw new IAE("rowLimit[%d] must be positive", rowLimit); + } } /** @@ -385,30 +391,42 @@ private void runWorkersIfPossible() @GuardedBy("runWorkersLock") private void setAllDoneIfPossible() { - if (totalInputFrames == 0 && outputPartitionsFuture.isDone()) { - // No input data -- generate empty output channels. - final ClusterByPartitions partitions = getOutputPartitions(); - final List channels = new ArrayList<>(partitions.size()); + try { + if (totalInputFrames == 0 && outputPartitionsFuture.isDone()) { + // No input data -- generate empty output channels. + final ClusterByPartitions partitions = getOutputPartitions(); + final List channels = new ArrayList<>(partitions.size()); - for (int partitionNum = 0; partitionNum < partitions.size(); partitionNum++) { - channels.add(outputChannelFactory.openNilChannel(partitionNum)); - } + for (int partitionNum = 0; partitionNum < partitions.size(); partitionNum++) { + channels.add(outputChannelFactory.openNilChannel(partitionNum)); + } + + // OK to use wrap, not wrapReadOnly, because nil channels are already read-only. + allDone.set(OutputChannels.wrap(channels)); + } else if (rowLimit == 0 && activeProcessors == 0) { + // We had a row limit, and got it all the way down to zero. + // Generate empty output channels for any partitions that we haven't written yet. + for (int partitionNum = 0; partitionNum < outputChannels.size(); partitionNum++) { + if (outputChannels.get(partitionNum) == null) { + outputChannels.set(partitionNum, outputChannelFactory.openNilChannel(partitionNum)); + superSorterProgressTracker.addMergedBatchesForLevel(totalMergingLevels - 1, 1); + } + } - // OK to use wrap, not wrapReadOnly, because nil channels are already read-only. - allDone.set(OutputChannels.wrap(channels)); - } else if (totalMergingLevels != UNKNOWN_LEVEL - && outputsReadyByLevel.containsKey(totalMergingLevels - 1) - && (outputsReadyByLevel.get(totalMergingLevels - 1).size() == - getTotalMergersInLevel(totalMergingLevels - 1))) { - // We're done!! - try { // OK to use wrap, not wrapReadOnly, because all channels in this list are already read-only. allDone.set(OutputChannels.wrap(outputChannels)); - } - catch (Throwable e) { - allDone.setException(e); + } else if (totalMergingLevels != UNKNOWN_LEVEL + && outputsReadyByLevel.containsKey(totalMergingLevels - 1) + && (outputsReadyByLevel.get(totalMergingLevels - 1).size() == + getTotalMergersInLevel(totalMergingLevels - 1))) { + // We're done!! + // OK to use wrap, not wrapReadOnly, because all channels in this list are already read-only. + allDone.set(OutputChannels.wrap(outputChannels)); } } + catch (Throwable e) { + allDone.setException(e); + } } @GuardedBy("runWorkersLock") @@ -463,6 +481,11 @@ && ultimateMergersRunSoFar < getTotalMergersInLevel(0))) { return false; } + if (isLimited() && (rowLimit == 0 || activeProcessors > 0)) { + // Run final-layer mergers one at a time, to ensure limit is applied across the entire dataset. + return false; + } + final List in = new ArrayList<>(); for (final Frame frame : inputBuffer) { @@ -617,6 +640,11 @@ && ultimateMergersRunSoFar < getOutputPartitions().size())) { return false; } + if (isLimited() && (rowLimit == 0 || activeProcessors > 0)) { + // Run final-layer mergers one at a time, to ensure limit is applied across the entire dataset. + return false; + } + final int inLevel = totalMergingLevels - 2; final int outLevel = inLevel + 1; final LongSortedSet inputsReady = outputsReadyByLevel.get(inLevel); @@ -719,11 +747,20 @@ private void runMerger( rowLimit ); - runWorker(worker, ignored1 -> { + runWorker(worker, outputRows -> { synchronized (runWorkersLock) { outputsReadyByLevel.computeIfAbsent(level, ignored2 -> new LongRBTreeSet()) .add(rank); superSorterProgressTracker.addMergedBatchesForLevel(level, 1); + + if (isLimited() && totalMergingLevels != UNKNOWN_LEVEL && level == totalMergingLevels - 1) { + rowLimit -= outputRows; + + if (rowLimit < 0) { + throw DruidException.defensive("rowLimit[%d] below zero after outputRows[%d]", rowLimit, outputRows); + } + } + for (PartitionedReadableFrameChannel partitionedReadableFrameChannel : partitionedReadableChannelsToClose) { try { partitionedReadableFrameChannel.close(); @@ -984,6 +1021,12 @@ private int getMaxInputBufferFramesForDirectMerging() return maxChannelsPerMerger * maxActiveProcessors; } + @GuardedBy("runWorkersLock") + private boolean isLimited() + { + return rowLimit != UNLIMITED; + } + /** * Returns a string encapsulating the current state of this object. */ diff --git a/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java b/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java index 36644f6d7715..ab5b1c8197fb 100644 --- a/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java +++ b/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; +import com.google.common.primitives.Ints; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; @@ -58,6 +59,8 @@ import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.testing.InitializedNullHandlingTest; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -131,7 +134,7 @@ public void testSingleEmptyInputChannel_fileStorage() throws Exception new FileOutputChannelFactory(tempFolder, FRAME_SIZE, null), 2, 2, - -1, + SuperSorter.UNLIMITED, null, superSorterProgressTracker, false @@ -181,6 +184,40 @@ public void testSingleEmptyInputChannel_immediately_fileStorage() throws Excepti Assert.assertEquals(1.0, superSorterProgressTracker.snapshot().getProgressDigest(), 0.0f); channel.close(); } + + @Test + public void testLimitHint() throws Exception + { + final BlockingQueueFrameChannel inputChannel = BlockingQueueFrameChannel.minimal(); + inputChannel.writable().close(); + + final SuperSorterProgressTracker superSorterProgressTracker = new SuperSorterProgressTracker(); + + final File tempFolder = temporaryFolder.newFolder(); + final SuperSorter superSorter = new SuperSorter( + Collections.singletonList(inputChannel.readable()), + FrameReader.create(RowSignature.empty()), + Collections.emptyList(), + Futures.immediateFuture(ClusterByPartitions.oneUniversalPartition()), + exec, + new FileOutputChannelFactory(tempFolder, FRAME_SIZE, null), + new FileOutputChannelFactory(tempFolder, FRAME_SIZE, null), + 2, + 2, + 3, + null, + superSorterProgressTracker, + false + ); + + final OutputChannels channels = superSorter.run().get(); + Assert.assertEquals(1, channels.getAllChannels().size()); + + final ReadableFrameChannel channel = Iterables.getOnlyElement(channels.getAllChannels()).getReadableChannel(); + Assert.assertTrue(channel.isFinished()); + Assert.assertEquals(1.0, superSorterProgressTracker.snapshot().getProgressDigest(), 0.0f); + channel.close(); + } } /** @@ -201,6 +238,7 @@ public static class ParameterizedCasesTest extends InitializedNullHandlingTest private final int numThreads; private final boolean isComposedStorage; private final boolean partitionsDeferred; + private final long limitHint; private StorageAdapter adapter; private RowSignature signature; @@ -216,7 +254,8 @@ public ParameterizedCasesTest( int maxChannelsPerProcessor, int numThreads, boolean isComposedStorage, - boolean partitionsDeferred + boolean partitionsDeferred, + long limitHint ) { this.maxRowsPerFrame = maxRowsPerFrame; @@ -227,6 +266,7 @@ public ParameterizedCasesTest( this.numThreads = numThreads; this.isComposedStorage = isComposedStorage; this.partitionsDeferred = partitionsDeferred; + this.limitHint = limitHint; } @Parameterized.Parameters( @@ -237,7 +277,8 @@ public ParameterizedCasesTest( + "maxChannelsPerProcessor= {4}, " + "numThreads = {5}, " + "isComposedStorage = {6}, " - + "partitionsDeferred = {7}" + + "partitionsDeferred = {7}, " + + "limitHint = {8}" ) public static Iterable constructorFeeder() { @@ -251,18 +292,21 @@ public static Iterable constructorFeeder() for (int numThreads : new int[]{1, 3}) { for (boolean isComposedStorage : new boolean[]{true, false}) { for (boolean partitionsDeferred : new boolean[]{true, false}) { - constructors.add( - new Object[]{ - maxRowsPerFrame, - maxBytesPerFrame, - numChannels, - maxActiveProcessors, - maxChannelsPerProcessor, - numThreads, - isComposedStorage, - partitionsDeferred - } - ); + for (long limitHint : new long[]{SuperSorter.UNLIMITED, 3, 1_000}) { + constructors.add( + new Object[]{ + maxRowsPerFrame, + maxBytesPerFrame, + numChannels, + maxActiveProcessors, + maxChannelsPerProcessor, + numThreads, + isComposedStorage, + partitionsDeferred, + limitHint + } + ); + } } } } @@ -352,7 +396,7 @@ private void verifySuperSorter( outputChannelFactory, maxActiveProcessors, maxChannelsPerProcessor, - -1, + limitHint, null, superSorterProgressTracker, false @@ -415,6 +459,10 @@ private void verifySuperSorter( ); } + if (limitHint != SuperSorter.UNLIMITED) { + MatcherAssert.assertThat(readRows.size(), Matchers.greaterThanOrEqualTo(Ints.checkedCast(limitHint))); + } + final Sequence> expectedRows = Sequences.sort( FrameTestUtil.readRowsFromAdapter(adapter, signature, true), Comparator.comparing( @@ -429,7 +477,7 @@ private void verifySuperSorter( }, keyComparator ) - ); + ).limit(limitHint == SuperSorter.UNLIMITED ? Long.MAX_VALUE : readRows.size()); FrameTestUtil.assertRowsEqual(expectedRows, Sequences.simple(readRows)); }