diff --git a/processing/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java b/processing/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java index 517235a99f989..ca2708700f083 100644 --- a/processing/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java +++ b/processing/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java @@ -19,10 +19,9 @@ package org.apache.druid.java.util.common.guava; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import com.google.common.collect.Ordering; -import org.apache.druid.java.util.common.RE; +import com.google.common.util.concurrent.AbstractFuture; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.QueryTimeoutException; @@ -63,6 +62,7 @@ public class ParallelMergeCombiningSequence extends YieldingSequenceBase { private static final Logger LOG = new Logger(ParallelMergeCombiningSequence.class); + private static final long BLOCK_TIMEOUT = TimeUnit.NANOSECONDS.convert(500, TimeUnit.MILLISECONDS); // these values were chosen carefully via feedback from benchmarks, // see PR https://github.com/apache/druid/pull/8578 for details @@ -84,7 +84,7 @@ public class ParallelMergeCombiningSequence extends YieldingSequenceBase private final long targetTimeNanos; private final Consumer metricsReporter; - private final CancellationGizmo cancellationGizmo; + private final CancellationFuture cancellationFuture; public ParallelMergeCombiningSequence( ForkJoinPool workerPool, @@ -114,14 +114,24 @@ public ParallelMergeCombiningSequence( this.targetTimeNanos = TimeUnit.NANOSECONDS.convert(targetTimeMillis, TimeUnit.MILLISECONDS); this.queueSize = (1 << 15) / batchSize; // each queue can by default hold ~32k rows this.metricsReporter = reporter; - this.cancellationGizmo = new CancellationGizmo(); + this.cancellationFuture = new CancellationFuture(new CancellationGizmo()); } @Override public Yielder toYielder(OutType initValue, YieldingAccumulator accumulator) { if (inputSequences.isEmpty()) { - return Sequences.empty().toYielder(initValue, accumulator); + return Sequences.wrap( + Sequences.empty(), + new SequenceWrapper() + { + @Override + public void after(boolean isDone, Throwable thrown) + { + cancellationFuture.set(true); + } + } + ).toYielder(initValue, accumulator); } // we make final output queue larger than the merging queues so if downstream readers are slower to read there is // less chance of blocking the merge @@ -144,27 +154,43 @@ public Yielder toYielder(OutType initValue, YieldingAccumulat hasTimeout, timeoutAtNanos, metricsAccumulator, - cancellationGizmo + cancellationFuture.cancellationGizmo ); workerPool.execute(mergeCombineAction); - Sequence finalOutSequence = makeOutputSequenceForQueue( - outputQueue, - hasTimeout, - timeoutAtNanos, - cancellationGizmo - ).withBaggage(() -> { - if (metricsReporter != null) { - metricsAccumulator.setTotalWallTime(System.nanoTime() - startTimeNanos); - metricsReporter.accept(metricsAccumulator.build()); - } - }); + + final Sequence finalOutSequence = Sequences.wrap( + makeOutputSequenceForQueue( + outputQueue, + hasTimeout, + timeoutAtNanos, + cancellationFuture.cancellationGizmo + ), + new SequenceWrapper() + { + @Override + public void after(boolean isDone, Throwable thrown) + { + if (isDone) { + cancellationFuture.set(true); + } else { + cancellationFuture.cancel(true); + } + if (metricsReporter != null) { + metricsAccumulator.setTotalWallTime(System.nanoTime() - startTimeNanos); + metricsReporter.accept(metricsAccumulator.build()); + } + } + } + ); return finalOutSequence.toYielder(initValue, accumulator); } - @VisibleForTesting - public CancellationGizmo getCancellationGizmo() + /** + * + */ + public CancellationFuture getCancellationFuture() { - return cancellationGizmo; + return cancellationFuture; } /** @@ -181,8 +207,6 @@ static Sequence makeOutputSequenceForQueue( return new BaseSequence<>( new BaseSequence.IteratorMaker>() { - private boolean shouldCancelOnCleanup = true; - @Override public Iterator make() { @@ -195,7 +219,7 @@ public boolean hasNext() { final long thisTimeoutNanos = timeoutAtNanos - System.nanoTime(); if (hasTimeout && thisTimeoutNanos < 0) { - throw new QueryTimeoutException(); + throw cancellationGizmo.cancelAndThrow(new QueryTimeoutException()); } if (currentBatch != null && !currentBatch.isTerminalResult() && !currentBatch.isDrained()) { @@ -210,33 +234,32 @@ public boolean hasNext() } } if (currentBatch == null) { - throw new QueryTimeoutException(); + throw cancellationGizmo.cancelAndThrow(new QueryTimeoutException()); } - if (cancellationGizmo.isCancelled()) { + if (cancellationGizmo.isCanceled()) { throw cancellationGizmo.getRuntimeException(); } if (currentBatch.isTerminalResult()) { - shouldCancelOnCleanup = false; return false; } return true; } catch (InterruptedException e) { - throw new RE(e); + throw cancellationGizmo.cancelAndThrow(e); } } @Override public T next() { - if (cancellationGizmo.isCancelled()) { + if (cancellationGizmo.isCanceled()) { throw cancellationGizmo.getRuntimeException(); } if (currentBatch == null || currentBatch.isDrained() || currentBatch.isTerminalResult()) { - throw new NoSuchElementException(); + throw cancellationGizmo.cancelAndThrow(new NoSuchElementException()); } return currentBatch.next(); } @@ -246,9 +269,7 @@ public T next() @Override public void cleanup(Iterator iterFromMake) { - if (shouldCancelOnCleanup) { - cancellationGizmo.cancel(new RuntimeException("Already closed")); - } + // nothing to cleanup } } ); @@ -338,7 +359,7 @@ protected void compute() parallelTaskCount ); - QueuePusher> resultsPusher = new QueuePusher<>(out, hasTimeout, timeoutAt); + QueuePusher resultsPusher = new QueuePusher<>(out, cancellationGizmo, hasTimeout, timeoutAt); for (Sequence s : sequences) { sequenceCursors.add(new YielderBatchedResultsCursor<>(new SequenceBatcher<>(s, batchSize), orderingFn)); @@ -367,10 +388,10 @@ protected void compute() catch (Throwable t) { closeAllCursors(sequenceCursors); cancellationGizmo.cancel(t); - // Should be the following, but can' change due to lack of - // unit tests. - // out.offer((ParallelMergeCombiningSequence.ResultBatch) ResultBatch.TERMINAL); - out.offer(ResultBatch.TERMINAL); + // offer terminal result if queue is not full in case out is empty to allow downstream threads waiting on + // stuff to be present to stop blocking immediately. However, if the queue is full, it doesn't matter if we + // write anything because the cancellation signal has been set, which will also terminate processing. + out.offer(ResultBatch.terminal()); } } @@ -387,7 +408,7 @@ private void spawnParallelTasks(int parallelMergeTasks) for (List> partition : partitions) { BlockingQueue> outputQueue = new ArrayBlockingQueue<>(queueSize); intermediaryOutputs.add(outputQueue); - QueuePusher> pusher = new QueuePusher<>(outputQueue, hasTimeout, timeoutAt); + QueuePusher pusher = new QueuePusher<>(outputQueue, cancellationGizmo, hasTimeout, timeoutAt); List> partitionCursors = new ArrayList<>(sequences.size()); for (Sequence s : partition) { @@ -415,11 +436,11 @@ private void spawnParallelTasks(int parallelMergeTasks) getPool().execute(task); } - QueuePusher> outputPusher = new QueuePusher<>(out, hasTimeout, timeoutAt); + QueuePusher outputPusher = new QueuePusher<>(out, cancellationGizmo, hasTimeout, timeoutAt); List> intermediaryOutputsCursors = new ArrayList<>(intermediaryOutputs.size()); for (BlockingQueue> queue : intermediaryOutputs) { intermediaryOutputsCursors.add( - new BlockingQueueuBatchedResultsCursor<>(queue, orderingFn, hasTimeout, timeoutAt) + new BlockingQueueuBatchedResultsCursor<>(queue, cancellationGizmo, orderingFn, hasTimeout, timeoutAt) ); } MergeCombineActionMetricsAccumulator finalMergeMetrics = new MergeCombineActionMetricsAccumulator(); @@ -513,7 +534,7 @@ private static class MergeCombineAction extends RecursiveAction private final PriorityQueue> pQueue; private final Ordering orderingFn; private final BinaryOperator combineFn; - private final QueuePusher> outputQueue; + private final QueuePusher outputQueue; private final T initialValue; private final int yieldAfter; private final int batchSize; @@ -523,7 +544,7 @@ private static class MergeCombineAction extends RecursiveAction private MergeCombineAction( PriorityQueue> pQueue, - QueuePusher> outputQueue, + QueuePusher outputQueue, Ordering orderingFn, BinaryOperator combineFn, T initialValue, @@ -550,6 +571,10 @@ private MergeCombineAction( @Override protected void compute() { + if (cancellationGizmo.isCanceled()) { + cleanup(); + return; + } try { long start = System.nanoTime(); long startCpuNanos = JvmUtils.safeGetThreadCpuTime(); @@ -608,7 +633,7 @@ protected void compute() metricsAccumulator.incrementCpuTimeNanos(elapsedCpuNanos); metricsAccumulator.incrementTaskCount(); - if (!pQueue.isEmpty() && !cancellationGizmo.isCancelled()) { + if (!pQueue.isEmpty() && !cancellationGizmo.isCanceled()) { // if there is still work to be done, execute a new task with the current accumulated value to continue // combining where we left off if (!outputBatch.isDrained()) { @@ -650,29 +675,36 @@ protected void compute() metricsAccumulator, cancellationGizmo )); - } else if (cancellationGizmo.isCancelled()) { + } else if (cancellationGizmo.isCanceled()) { // if we got the cancellation signal, go ahead and write terminal value into output queue to help gracefully // allow downstream stuff to stop - LOG.debug("cancelled after %s tasks", metricsAccumulator.getTaskCount()); + LOG.debug("canceled after %s tasks", metricsAccumulator.getTaskCount()); // make sure to close underlying cursors - closeAllCursors(pQueue); - outputQueue.offer(ResultBatch.TERMINAL); + cleanup(); } else { // if priority queue is empty, push the final accumulated value into the output batch and push it out outputBatch.add(currentCombinedValue); metricsAccumulator.incrementOutputRows(batchCounter + 1L); outputQueue.offer(outputBatch); // ... and the terminal value to indicate the blocking queue holding the values is complete - outputQueue.offer(ResultBatch.TERMINAL); + outputQueue.offer(ResultBatch.terminal()); LOG.debug("merge combine complete after %s tasks", metricsAccumulator.getTaskCount()); } } catch (Throwable t) { - closeAllCursors(pQueue); cancellationGizmo.cancel(t); - outputQueue.offer(ResultBatch.TERMINAL); + cleanup(); } } + + private void cleanup() + { + closeAllCursors(pQueue); + // offer terminal result if queue is not full in case out is empty to allow downstream threads waiting on + // stuff to be present to stop blocking immediately. However, if the queue is full, it doesn't matter if we + // write anything because the cancellation signal has been set, which will also terminate processing. + outputQueue.offer(ResultBatch.terminal()); + } } @@ -696,7 +728,7 @@ private static class PrepareMergeCombineInputsAction extends RecursiveAction private final List> partition; private final Ordering orderingFn; private final BinaryOperator combineFn; - private final QueuePusher> outputQueue; + private final QueuePusher outputQueue; private final int yieldAfter; private final int batchSize; private final long targetTimeNanos; @@ -707,7 +739,7 @@ private static class PrepareMergeCombineInputsAction extends RecursiveAction private PrepareMergeCombineInputsAction( List> partition, - QueuePusher> outputQueue, + QueuePusher outputQueue, Ordering orderingFn, BinaryOperator combineFn, int yieldAfter, @@ -744,7 +776,7 @@ protected void compute() cursor.close(); } } - if (cursors.size() > 0) { + if (!cancellationGizmo.isCanceled() && !cursors.isEmpty()) { getPool().execute(new MergeCombineAction( cursors, outputQueue, @@ -758,14 +790,17 @@ protected void compute() cancellationGizmo )); } else { - outputQueue.offer(ResultBatch.TERMINAL); + outputQueue.offer(ResultBatch.terminal()); } metricsAccumulator.setPartitionInitializedTime(System.nanoTime() - startTime); } catch (Throwable t) { closeAllCursors(partition); cancellationGizmo.cancel(t); - outputQueue.offer(ResultBatch.TERMINAL); + // offer terminal result if queue is not full in case out is empty to allow downstream threads waiting on + // stuff to be present to stop blocking immediately. However, if the queue is full, it doesn't matter if we + // write anything because the cancellation signal has been set, which will also terminate processing. + outputQueue.tryOfferTerminal(); } } } @@ -779,12 +814,14 @@ static class QueuePusher implements ForkJoinPool.ManagedBlocker { final boolean hasTimeout; final long timeoutAtNanos; - final BlockingQueue queue; - volatile E item = null; + final BlockingQueue> queue; + final CancellationGizmo gizmo; + volatile ResultBatch item = null; - QueuePusher(BlockingQueue q, boolean hasTimeout, long timeoutAtNanos) + QueuePusher(BlockingQueue> q, CancellationGizmo gizmo, boolean hasTimeout, long timeoutAtNanos) { this.queue = q; + this.gizmo = gizmo; this.hasTimeout = hasTimeout; this.timeoutAtNanos = timeoutAtNanos; } @@ -795,14 +832,16 @@ public boolean block() throws InterruptedException boolean success = false; if (item != null) { if (hasTimeout) { - final long thisTimeoutNanos = timeoutAtNanos - System.nanoTime(); - if (thisTimeoutNanos < 0) { + final long remainingNanos = timeoutAtNanos - System.nanoTime(); + if (remainingNanos < 0) { item = null; - throw new QueryTimeoutException("QueuePusher timed out offering data"); + throw gizmo.cancelAndThrow(new QueryTimeoutException()); } - success = queue.offer(item, thisTimeoutNanos, TimeUnit.NANOSECONDS); + final long blockTimeoutNanos = Math.min(remainingNanos, BLOCK_TIMEOUT); + success = queue.offer(item, blockTimeoutNanos, TimeUnit.NANOSECONDS); } else { - success = queue.offer(item); + queue.put(item); + success = true; } if (success) { item = null; @@ -817,7 +856,7 @@ public boolean isReleasable() return item == null; } - public void offer(E item) + public void offer(ResultBatch item) { try { this.item = item; @@ -828,6 +867,11 @@ public void offer(E item) throw new RuntimeException("Failed to offer result to output queue", e); } } + + public void tryOfferTerminal() + { + this.queue.offer(ResultBatch.terminal()); + } } /** @@ -837,8 +881,10 @@ public void offer(E item) */ static class ResultBatch { - @SuppressWarnings("rawtypes") - static final ResultBatch TERMINAL = new ResultBatch(); + static ResultBatch terminal() + { + return new ResultBatch<>(); + } @Nullable private final Queue values; @@ -855,19 +901,16 @@ private ResultBatch() public void add(E in) { - assert values != null; values.offer(in); } public E get() { - assert values != null; return values.peek(); } public E next() { - assert values != null; return values.poll(); } @@ -925,6 +968,7 @@ static class SequenceBatcher implements ForkJoinPool.ManagedBlocker Yielder> getBatchYielder() { try { + batchYielder = null; ForkJoinPool.managedBlock(this); return batchYielder; } @@ -1033,8 +1077,8 @@ static class YielderBatchedResultsCursor extends BatchedResultsCursor @Override public void initialize() { - yielder = batcher.getBatchYielder(); - resultBatch = yielder.get(); + yielder = null; + nextBatch(); } @Override @@ -1059,6 +1103,10 @@ public boolean isDone() @Override public boolean block() { + if (yielder == null) { + yielder = batcher.getBatchYielder(); + resultBatch = yielder.get(); + } if (yielder.isDone()) { return true; } @@ -1073,7 +1121,7 @@ public boolean block() @Override public boolean isReleasable() { - return yielder.isDone() || (resultBatch != null && !resultBatch.isDrained()); + return (yielder != null && yielder.isDone()) || (resultBatch != null && !resultBatch.isDrained()); } @Override @@ -1092,11 +1140,13 @@ public void close() throws IOException static class BlockingQueueuBatchedResultsCursor extends BatchedResultsCursor { final BlockingQueue> queue; + final CancellationGizmo gizmo; final boolean hasTimeout; final long timeoutAtNanos; BlockingQueueuBatchedResultsCursor( BlockingQueue> blockingQueue, + CancellationGizmo cancellationGizmo, Ordering ordering, boolean hasTimeout, long timeoutAtNanos @@ -1104,6 +1154,7 @@ static class BlockingQueueuBatchedResultsCursor extends BatchedResultsCursor< { super(ordering); this.queue = blockingQueue; + this.gizmo = cancellationGizmo; this.hasTimeout = hasTimeout; this.timeoutAtNanos = timeoutAtNanos; } @@ -1142,17 +1193,18 @@ public boolean block() throws InterruptedException { if (resultBatch == null || resultBatch.isDrained()) { if (hasTimeout) { - final long thisTimeoutNanos = timeoutAtNanos - System.nanoTime(); - if (thisTimeoutNanos < 0) { - resultBatch = ResultBatch.TERMINAL; - throw new QueryTimeoutException("BlockingQueue cursor timed out waiting for data"); + final long remainingNanos = timeoutAtNanos - System.nanoTime(); + if (remainingNanos < 0) { + resultBatch = ResultBatch.terminal(); + throw gizmo.cancelAndThrow(new QueryTimeoutException()); } - resultBatch = queue.poll(thisTimeoutNanos, TimeUnit.NANOSECONDS); + final long blockTimeoutNanos = Math.min(remainingNanos, BLOCK_TIMEOUT); + resultBatch = queue.poll(blockTimeoutNanos, TimeUnit.NANOSECONDS); } else { resultBatch = queue.take(); } } - return resultBatch != null; + return resultBatch != null && !resultBatch.isDrained(); } @Override @@ -1164,35 +1216,91 @@ public boolean isReleasable() } // if we can get a result immediately without blocking, also no need to block resultBatch = queue.poll(); - return resultBatch != null; + return resultBatch != null && !resultBatch.isDrained(); } } /** - * Token to allow any {@link RecursiveAction} signal the others and the output sequence that something bad happened - * and processing should cancel, such as a timeout or connection loss. + * Token used to stop internal parallel processing across all tasks in the merge pool. Allows any + * {@link RecursiveAction} signal the others and the output sequence that something bad happened and + * processing should cancel, such as a timeout, error, or connection loss. */ - static class CancellationGizmo + public static class CancellationGizmo { private final AtomicReference throwable = new AtomicReference<>(null); + RuntimeException cancelAndThrow(Throwable t) + { + throwable.compareAndSet(null, t); + return wrapRuntimeException(t); + } + void cancel(Throwable t) { throwable.compareAndSet(null, t); } - boolean isCancelled() + boolean isCanceled() { return throwable.get() != null; } RuntimeException getRuntimeException() { - Throwable ex = throwable.get(); - if (ex instanceof RuntimeException) { - return (RuntimeException) ex; + return wrapRuntimeException(throwable.get()); + } + + private static RuntimeException wrapRuntimeException(Throwable t) + { + if (t instanceof RuntimeException) { + return (RuntimeException) t; } - return new RE(ex); + return new RuntimeException(t); + } + } + + /** + * {@link com.google.common.util.concurrent.ListenableFuture} that allows {@link ParallelMergeCombiningSequence} to be + * registered with {@link org.apache.druid.query.QueryWatcher#registerQueryFuture} to participate in query + * cancellation or anything else that has a need to watch the activity on the merge pool. Wraps a + * {@link CancellationGizmo} to allow for external threads to signal cancellation of parallel processing on the pool + * by triggering {@link CancellationGizmo#cancel(Throwable)} whenever {@link #cancel(boolean)} is called. + * + * This is not used internally by workers on the pool in favor of using the much simpler {@link CancellationGizmo} + * directly instead. + */ + public static class CancellationFuture extends AbstractFuture + { + private final CancellationGizmo cancellationGizmo; + + public CancellationFuture(CancellationGizmo cancellationGizmo) + { + this.cancellationGizmo = cancellationGizmo; + } + + public CancellationGizmo getCancellationGizmo() + { + return cancellationGizmo; + } + + @Override + public boolean set(Boolean value) + { + return super.set(value); + } + + @Override + public boolean setException(Throwable throwable) + { + cancellationGizmo.cancel(throwable); + return super.setException(throwable); + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) + { + cancellationGizmo.cancel(new RuntimeException("Sequence canceled")); + return super.cancel(mayInterruptIfRunning); } } @@ -1308,8 +1416,8 @@ public long getSlowestPartitionInitializedTime() */ static class MergeCombineMetricsAccumulator { - List partitionMetrics; - MergeCombineActionMetricsAccumulator mergeMetrics; + List partitionMetrics = Collections.emptyList(); + MergeCombineActionMetricsAccumulator mergeMetrics = new MergeCombineActionMetricsAccumulator(); private long totalWallTime; @@ -1343,8 +1451,8 @@ MergeCombineMetrics build() // partition long totalPoolTasks = 1 + 1 + partitionMetrics.size(); - long fastestPartInitialized = partitionMetrics.size() > 0 ? Long.MAX_VALUE : mergeMetrics.getPartitionInitializedtime(); - long slowestPartInitialied = partitionMetrics.size() > 0 ? Long.MIN_VALUE : mergeMetrics.getPartitionInitializedtime(); + long fastestPartInitialized = !partitionMetrics.isEmpty() ? Long.MAX_VALUE : mergeMetrics.getPartitionInitializedtime(); + long slowestPartInitialied = !partitionMetrics.isEmpty() ? Long.MIN_VALUE : mergeMetrics.getPartitionInitializedtime(); // accumulate input row count, cpu time, and total number of tasks from each partition for (MergeCombineActionMetricsAccumulator partition : partitionMetrics) { diff --git a/processing/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java b/processing/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java index ca34c364dca84..5b76afb902294 100644 --- a/processing/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java +++ b/processing/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java @@ -143,7 +143,7 @@ public void testOrderedResultBatchFromSequenceBacktoYielderOnSequence() throws I if (!currentBatch.isDrained()) { outputQueue.offer(currentBatch); } - outputQueue.offer(ParallelMergeCombiningSequence.ResultBatch.TERMINAL); + outputQueue.offer(ParallelMergeCombiningSequence.ResultBatch.terminal()); rawYielder.close(); cursor.close(); @@ -211,16 +211,18 @@ public void testOrderedResultBatchFromSequenceToBlockingQueueCursor() throws IOE if (!currentBatch.isDrained()) { outputQueue.offer(currentBatch); } - outputQueue.offer(ParallelMergeCombiningSequence.ResultBatch.TERMINAL); + outputQueue.offer(ParallelMergeCombiningSequence.ResultBatch.terminal()); rawYielder.close(); cursor.close(); rawYielder = Yielders.each(rawSequence); + ParallelMergeCombiningSequence.CancellationGizmo gizmo = new ParallelMergeCombiningSequence.CancellationGizmo(); ParallelMergeCombiningSequence.BlockingQueueuBatchedResultsCursor queueCursor = new ParallelMergeCombiningSequence.BlockingQueueuBatchedResultsCursor<>( outputQueue, + gizmo, INT_PAIR_ORDERING, false, -1L @@ -551,14 +553,14 @@ public void testTimeoutExceptionDueToStalledInput() } @Test - public void testTimeoutExceptionDueToStalledReader() + public void testTimeoutExceptionDueToSlowReader() { - final int someSize = 2048; + final int someSize = 50_000; List> input = new ArrayList<>(); - input.add(nonBlockingSequence(someSize)); - input.add(nonBlockingSequence(someSize)); - input.add(nonBlockingSequence(someSize)); - input.add(nonBlockingSequence(someSize)); + input.add(nonBlockingSequence(someSize, true)); + input.add(nonBlockingSequence(someSize, true)); + input.add(nonBlockingSequence(someSize, true)); + input.add(nonBlockingSequence(someSize, true)); Throwable t = Assert.assertThrows(QueryTimeoutException.class, () -> assertException(input, 8, 64, 1000, 1500)); Assert.assertEquals("Query did not complete within configured timeout period. " + @@ -567,6 +569,110 @@ public void testTimeoutExceptionDueToStalledReader() Assert.assertTrue(pool.isQuiescent()); } + @Test + public void testTimeoutExceptionDueToStoppedReader() throws InterruptedException + { + final int someSize = 150_000; + List reporters = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + List> input = new ArrayList<>(); + input.add(nonBlockingSequence(someSize, true)); + input.add(nonBlockingSequence(someSize, true)); + input.add(nonBlockingSequence(someSize, true)); + input.add(nonBlockingSequence(someSize, true)); + + TestingReporter reporter = new TestingReporter(); + final ParallelMergeCombiningSequence parallelMergeCombineSequence = new ParallelMergeCombiningSequence<>( + pool, + input, + INT_PAIR_ORDERING, + INT_PAIR_MERGE_FN, + true, + 1000, + 0, + TEST_POOL_SIZE, + 512, + 128, + ParallelMergeCombiningSequence.DEFAULT_TASK_TARGET_RUN_TIME_MILLIS, + reporter + ); + Yielder parallelMergeCombineYielder = Yielders.each(parallelMergeCombineSequence); + reporter.future = parallelMergeCombineSequence.getCancellationFuture(); + reporter.yielder = parallelMergeCombineYielder; + reporter.yielder = parallelMergeCombineYielder.next(null); + Assert.assertFalse(parallelMergeCombineYielder.isDone()); + reporters.add(reporter); + } + + // sleep until timeout + Thread.sleep(1000); + Assert.assertTrue(pool.awaitQuiescence(10, TimeUnit.SECONDS)); + Assert.assertTrue(pool.isQuiescent()); + Assert.assertFalse(pool.hasQueuedSubmissions()); + for (TestingReporter reporter : reporters) { + Assert.assertThrows(QueryTimeoutException.class, () -> reporter.yielder.next(null)); + Assert.assertTrue(reporter.future.isCancelled()); + Assert.assertTrue(reporter.future.getCancellationGizmo().isCanceled()); + } + Assert.assertTrue(pool.awaitQuiescence(10, TimeUnit.SECONDS)); + Assert.assertTrue(pool.isQuiescent()); + } + + @Test + public void testManyBigSequencesAllAtOnce() throws IOException + { + final int someSize = 50_000; + List reporters = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + List> input = new ArrayList<>(); + input.add(nonBlockingSequence(someSize, true)); + input.add(nonBlockingSequence(someSize, true)); + input.add(nonBlockingSequence(someSize, true)); + input.add(nonBlockingSequence(someSize, true)); + + TestingReporter reporter = new TestingReporter(); + final ParallelMergeCombiningSequence parallelMergeCombineSequence = new ParallelMergeCombiningSequence<>( + pool, + input, + INT_PAIR_ORDERING, + INT_PAIR_MERGE_FN, + true, + 30 * 1000, + 0, + TEST_POOL_SIZE, + 512, + 128, + ParallelMergeCombiningSequence.DEFAULT_TASK_TARGET_RUN_TIME_MILLIS, + reporter + ); + Yielder parallelMergeCombineYielder = Yielders.each(parallelMergeCombineSequence); + reporter.future = parallelMergeCombineSequence.getCancellationFuture(); + reporter.yielder = parallelMergeCombineYielder; + parallelMergeCombineYielder.next(null); + Assert.assertFalse(parallelMergeCombineYielder.isDone()); + reporters.add(reporter); + } + + for (TestingReporter testingReporter : reporters) { + Yielder parallelMergeCombineYielder = testingReporter.yielder; + while (!parallelMergeCombineYielder.isDone()) { + parallelMergeCombineYielder = parallelMergeCombineYielder.next(parallelMergeCombineYielder.get()); + } + Assert.assertTrue(parallelMergeCombineYielder.isDone()); + parallelMergeCombineYielder.close(); + Assert.assertTrue(testingReporter.future.isDone()); + } + + Assert.assertTrue(pool.awaitQuiescence(10, TimeUnit.SECONDS)); + Assert.assertTrue(pool.isQuiescent()); + Assert.assertEquals(0, pool.getRunningThreadCount()); + Assert.assertFalse(pool.hasQueuedSubmissions()); + Assert.assertEquals(0, pool.getActiveThreadCount()); + for (TestingReporter reporter : reporters) { + Assert.assertTrue(reporter.done); + } + } + @Test public void testGracefulCloseOfYielderCancelsPool() throws IOException { @@ -666,7 +772,9 @@ private void assertResultWithCustomPool( parallelMergeCombineYielder.close(); // cancellation trigger should not be set if sequence was fully yielded and close is called // (though shouldn't actually matter even if it was...) - Assert.assertFalse(parallelMergeCombineSequence.getCancellationGizmo().isCancelled()); + Assert.assertFalse(parallelMergeCombineSequence.getCancellationFuture().isCancelled()); + Assert.assertTrue(parallelMergeCombineSequence.getCancellationFuture().isDone()); + Assert.assertFalse(parallelMergeCombineSequence.getCancellationFuture().getCancellationGizmo().isCanceled()); } private void assertResult( @@ -713,13 +821,15 @@ private void assertResult( Assert.assertTrue(combiningYielder.isDone()); Assert.assertTrue(parallelMergeCombineYielder.isDone()); - Assert.assertTrue(pool.awaitQuiescence(1, TimeUnit.SECONDS)); + Assert.assertTrue(pool.awaitQuiescence(5, TimeUnit.SECONDS)); Assert.assertTrue(pool.isQuiescent()); combiningYielder.close(); parallelMergeCombineYielder.close(); // cancellation trigger should not be set if sequence was fully yielded and close is called // (though shouldn't actually matter even if it was...) - Assert.assertFalse(parallelMergeCombineSequence.getCancellationGizmo().isCancelled()); + Assert.assertFalse(parallelMergeCombineSequence.getCancellationFuture().isCancelled()); + Assert.assertFalse(parallelMergeCombineSequence.getCancellationFuture().getCancellationGizmo().isCanceled()); + Assert.assertTrue(parallelMergeCombineSequence.getCancellationFuture().isDone()); } private void assertResultWithEarlyClose( @@ -773,20 +883,21 @@ private void assertResultWithEarlyClose( } } // trying to next the yielder creates sadness for you - final String expectedExceptionMsg = "Already closed"; + final String expectedExceptionMsg = "Sequence canceled"; Assert.assertEquals(combiningYielder.get(), parallelMergeCombineYielder.get()); final Yielder finalYielder = parallelMergeCombineYielder; Throwable t = Assert.assertThrows(RuntimeException.class, () -> finalYielder.next(finalYielder.get())); Assert.assertEquals(expectedExceptionMsg, t.getMessage()); // cancellation gizmo of sequence should be cancelled, and also should contain our expected message - Assert.assertTrue(parallelMergeCombineSequence.getCancellationGizmo().isCancelled()); + Assert.assertTrue(parallelMergeCombineSequence.getCancellationFuture().getCancellationGizmo().isCanceled()); Assert.assertEquals( expectedExceptionMsg, - parallelMergeCombineSequence.getCancellationGizmo().getRuntimeException().getMessage() + parallelMergeCombineSequence.getCancellationFuture().getCancellationGizmo().getRuntimeException().getMessage() ); + Assert.assertTrue(parallelMergeCombineSequence.getCancellationFuture().isCancelled()); - Assert.assertTrue(pool.awaitQuiescence(1, TimeUnit.SECONDS)); + Assert.assertTrue(pool.awaitQuiescence(10, TimeUnit.SECONDS)); Assert.assertTrue(pool.isQuiescent()); Assert.assertFalse(combiningYielder.isDone()); @@ -1082,4 +1193,19 @@ private static IntPair makeIntPair(int mergeKey) { return new IntPair(mergeKey, ThreadLocalRandom.current().nextInt(1, 100)); } + + static class TestingReporter implements Consumer + { + ParallelMergeCombiningSequence.CancellationFuture future; + Yielder yielder; + volatile ParallelMergeCombiningSequence.MergeCombineMetrics metrics; + volatile boolean done = false; + + @Override + public void accept(ParallelMergeCombiningSequence.MergeCombineMetrics mergeCombineMetrics) + { + metrics = mergeCombineMetrics; + done = true; + } + } } diff --git a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java index 5fa34d6699d84..e4027bcd3574c 100644 --- a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java +++ b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java @@ -384,7 +384,7 @@ private Sequence merge(List> sequencesByInterval) BinaryOperator mergeFn = toolChest.createMergeFn(query); final QueryContext queryContext = query.context(); if (parallelMergeConfig.useParallelMergePool() && queryContext.getEnableParallelMerges() && mergeFn != null) { - return new ParallelMergeCombiningSequence<>( + final ParallelMergeCombiningSequence parallelSequence = new ParallelMergeCombiningSequence<>( pool, sequencesByInterval, query.getResultOrdering(), @@ -414,6 +414,8 @@ private Sequence merge(List> sequencesByInterval) } } ); + scheduler.registerQueryFuture(query, parallelSequence.getCancellationFuture()); + return parallelSequence; } else { return Sequences .simple(sequencesByInterval)