From faacd2f480751292e23fb2c5621db73c3419730d Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Wed, 24 Jul 2024 03:23:14 -0700 Subject: [PATCH 01/13] MSQ worker: Support in-memory shuffles. This patch is a follow-up to #16168, adding worker-side support for in-memory shuffles. Changes include: 1) Worker-side code now respects the same context parameter "maxConcurrentStages" that was added to the controller in #16168. The parameter remains undocumented for now, to give us a chance to more fully develop and test this functionality. 1) WorkerImpl is broken up into WorkerImpl, RunWorkOrder, and RunWorkOrderListener to improve readability. 2) WorkerImpl has a new StageOutputHolder + StageOutputReader concept, which abstract over memory-based or file-based stage results. 3) RunWorkOrder is updated to create in-memory stage output channels when instructed to. 4) ControllerResource is updated to add /doneReadingInput/, so the controller can tell when workers that sort, but do not gather statistics, are done reading their inputs. 5) WorkerMemoryParameters is updated to consider maxConcurrentStages. Additionally, WorkerChatHandler is split into WorkerResource, so as to match ControllerChatHandler and ControllerResource. --- .../druid/msq/exec/ControllerClient.java | 23 +- .../exec/ListeningOutputChannelFactory.java | 74 + .../apache/druid/msq/exec/RunWorkOrder.java | 1045 +++++++++ .../druid/msq/exec/RunWorkOrderListener.java | 57 + .../org/apache/druid/msq/exec/Worker.java | 59 +- .../apache/druid/msq/exec/WorkerContext.java | 46 +- .../org/apache/druid/msq/exec/WorkerImpl.java | 2066 ++++++----------- .../msq/exec/WorkerMemoryParameters.java | 68 +- .../msq/exec/WorkerStorageParameters.java | 6 +- .../msq/indexing/IndexerFrameContext.java | 36 +- .../IndexerResourcePermissionMapper.java | 6 + .../msq/indexing/IndexerWorkerContext.java | 110 +- .../druid/msq/indexing/MSQWorkerTask.java | 19 +- .../client/IndexerControllerClient.java | 2 +- .../indexing/client/WorkerChatHandler.java | 313 +-- .../apache/druid/msq/input/InputSlices.java | 22 +- .../external/ExternalInputSliceReader.java | 13 +- .../apache/druid/msq/kernel/FrameContext.java | 20 +- .../msq/kernel/worker/WorkerStageKernel.java | 47 +- .../msq/kernel/worker/WorkerStagePhase.java | 27 +- .../druid/msq/rpc/BaseWorkerClientImpl.java | 4 +- .../druid/msq/rpc/ControllerResource.java | 24 +- .../druid/msq/rpc/MSQResourceUtils.java | 16 + .../msq/rpc/ResourcePermissionMapper.java | 6 +- .../apache/druid/msq/rpc/WorkerResource.java | 392 ++++ .../input/MetaInputChannelFactory.java | 115 + .../WorkerOrLocalInputChannelFactory.java | 70 + .../shuffle/output/ByteChunksInputStream.java | 103 + .../output/ChannelStageOutputReader.java | 237 ++ .../shuffle/output/FileStageOutputReader.java | 77 + .../output/FutureReadableFrameChannel.java | 125 + .../shuffle/output/NilStageOutputReader.java | 77 + .../msq/shuffle/output/StageOutputHolder.java | 118 + .../msq/shuffle/output/StageOutputReader.java | 55 + .../msq/exec/MSQDrillWindowQueryTest.java | 11 +- .../apache/druid/msq/exec/WorkerImplTest.java | 54 - .../msq/exec/WorkerMemoryParametersTest.java | 110 +- .../indexing/IndexerWorkerContextTest.java | 9 + .../msq/indexing/WorkerChatHandlerTest.java | 59 +- .../output/ByteChunksInputStreamTest.java | 149 ++ .../msq/test/CalciteArraysQueryMSQTest.java | 10 +- .../test/CalciteNestedDataQueryMSQTest.java | 10 +- .../test/CalciteSelectJoinQueryMSQTest.java | 10 +- .../msq/test/CalciteSelectQueryMSQTest.java | 10 +- .../msq/test/CalciteUnionQueryMSQTest.java | 10 +- .../apache/druid/msq/test/MSQTestBase.java | 24 +- .../msq/test/MSQTestControllerClient.java | 2 +- .../msq/test/MSQTestControllerContext.java | 19 +- .../druid/msq/test/MSQTestWorkerClient.java | 12 +- .../druid/msq/test/MSQTestWorkerContext.java | 183 +- .../channel/ReadableFileFrameChannel.java | 8 + 51 files changed, 4080 insertions(+), 2088 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ListeningOutputChannelFactory.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrderListener.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/MetaInputChannelFactory.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/WorkerOrLocalInputChannelFactory.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java delete mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStreamTest.java diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java index 405ff4fb9026..cbc3544c93ae 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java @@ -26,17 +26,21 @@ import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; import javax.annotation.Nullable; +import java.io.Closeable; import java.io.IOException; import java.util.List; /** - * Client for the multi-stage query controller. Used by a Worker task. + * Client for the multi-stage query controller. Used by a {@link Worker}. Each instance is specific to a single query, + * meaning it communicates with a single controller. */ -public interface ControllerClient extends AutoCloseable +public interface ControllerClient extends Closeable { /** - * Client side method to update the controller with partial key statistics information for a particular stage and worker. - * Controller's implementation collates all the information for a stage to fetch key statistics from workers. + * Client side method to update the controller with partial key statistics information for a particular stage + * and worker. The controller collates all the information for a stage to fetch key statistics from workers. + * + * Only used when {@link StageDefinition#mustGatherResultKeyStatistics()}. */ void postPartialKeyStatistics( StageId stageId, @@ -77,20 +81,21 @@ void postResultsComplete( /** * Client side method to inform the controller that the error has occured in the given worker. + * + * @param queryId query ID, if this error is associated with a specific query + * @param errorWrapper error details */ void postWorkerError( - String workerId, + @Nullable String queryId, MSQErrorReport errorWrapper ) throws IOException; /** * Client side method to inform the controller about the warnings generated by the given worker. */ - void postWorkerWarning( - List MSQErrorReports - ) throws IOException; + void postWorkerWarning(List MSQErrorReports) throws IOException; - List getTaskList() throws IOException; + List getWorkerIds() throws IOException; /** * Close this client. Idempotent. diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ListeningOutputChannelFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ListeningOutputChannelFactory.java new file mode 100644 index 000000000000..ebaad0763872 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ListeningOutputChannelFactory.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import org.apache.druid.frame.processor.OutputChannel; +import org.apache.druid.frame.processor.OutputChannelFactory; +import org.apache.druid.frame.processor.PartitionedOutputChannel; + +import java.io.IOException; + +/** + * Decorator for {@link OutputChannelFactory} that notifies a {@link Listener} whenever a channel is opened. + */ +public class ListeningOutputChannelFactory implements OutputChannelFactory +{ + private final OutputChannelFactory delegate; + private final Listener listener; + + public ListeningOutputChannelFactory(final OutputChannelFactory delegate, final Listener listener) + { + this.delegate = delegate; + this.listener = listener; + } + + @Override + public OutputChannel openChannel(final int partitionNumber) throws IOException + { + return notifyListener(delegate.openChannel(partitionNumber)); + } + + + @Override + public OutputChannel openNilChannel(final int partitionNumber) + { + return notifyListener(delegate.openNilChannel(partitionNumber)); + } + + @Override + public PartitionedOutputChannel openPartitionedChannel( + final String name, + final boolean deleteAfterRead + ) + { + throw new UnsupportedOperationException("Listening to partitioned channels is not supported"); + } + + private OutputChannel notifyListener(OutputChannel channel) + { + listener.channelOpened(channel); + return channel; + } + + public interface Listener + { + void channelOpened(OutputChannel channel); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java new file mode 100644 index 000000000000..45689652a646 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java @@ -0,0 +1,1045 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.SettableFuture; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.frame.allocation.ArenaMemoryAllocator; +import org.apache.druid.frame.allocation.ArenaMemoryAllocatorFactory; +import org.apache.druid.frame.channel.BlockingQueueFrameChannel; +import org.apache.druid.frame.channel.ByteTracker; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.key.ClusterByPartitions; +import org.apache.druid.frame.processor.BlockingQueueOutputChannelFactory; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.frame.processor.ComposingOutputChannelFactory; +import org.apache.druid.frame.processor.FileOutputChannelFactory; +import org.apache.druid.frame.processor.FrameChannelHashPartitioner; +import org.apache.druid.frame.processor.FrameChannelMixer; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessorExecutor; +import org.apache.druid.frame.processor.OutputChannel; +import org.apache.druid.frame.processor.OutputChannelFactory; +import org.apache.druid.frame.processor.OutputChannels; +import org.apache.druid.frame.processor.PartitionedOutputChannel; +import org.apache.druid.frame.processor.SuperSorter; +import org.apache.druid.frame.processor.SuperSorterProgressTracker; +import org.apache.druid.frame.processor.manager.ProcessorManager; +import org.apache.druid.frame.processor.manager.ProcessorManagers; +import org.apache.druid.frame.util.DurableStorageUtils; +import org.apache.druid.frame.write.FrameWriters; +import org.apache.druid.java.util.common.FileUtils; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.UOE; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.msq.counters.CounterNames; +import org.apache.druid.msq.counters.CounterTracker; +import org.apache.druid.msq.indexing.CountingOutputChannelFactory; +import org.apache.druid.msq.indexing.InputChannelFactory; +import org.apache.druid.msq.indexing.InputChannelsImpl; +import org.apache.druid.msq.indexing.processor.KeyStatisticsCollectionProcessor; +import org.apache.druid.msq.input.InputSlice; +import org.apache.druid.msq.input.InputSliceReader; +import org.apache.druid.msq.input.InputSlices; +import org.apache.druid.msq.input.MapInputSliceReader; +import org.apache.druid.msq.input.NilInputSlice; +import org.apache.druid.msq.input.NilInputSliceReader; +import org.apache.druid.msq.input.external.ExternalInputSlice; +import org.apache.druid.msq.input.external.ExternalInputSliceReader; +import org.apache.druid.msq.input.inline.InlineInputSlice; +import org.apache.druid.msq.input.inline.InlineInputSliceReader; +import org.apache.druid.msq.input.lookup.LookupInputSlice; +import org.apache.druid.msq.input.lookup.LookupInputSliceReader; +import org.apache.druid.msq.input.stage.InputChannels; +import org.apache.druid.msq.input.stage.StageInputSlice; +import org.apache.druid.msq.input.stage.StageInputSliceReader; +import org.apache.druid.msq.input.table.SegmentsInputSlice; +import org.apache.druid.msq.input.table.SegmentsInputSliceReader; +import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.FrameProcessorFactory; +import org.apache.druid.msq.kernel.ProcessorsAndChannels; +import org.apache.druid.msq.kernel.ShuffleSpec; +import org.apache.druid.msq.kernel.StageDefinition; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.shuffle.output.DurableStorageOutputChannelFactory; +import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; +import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; +import org.apache.druid.utils.CloseableUtils; + +import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; + +/** + * Main worker logic for executing a {@link WorkOrder} in a {@link FrameProcessorExecutor}. + */ +public class RunWorkOrder +{ + private final WorkOrder workOrder; + private final InputChannelFactory inputChannelFactory; + private final CounterTracker counterTracker; + private final FrameProcessorExecutor exec; + private final String cancellationId; + private final int parallelism; + private final WorkerContext workerContext; + private final FrameContext frameContext; + private final RunWorkOrderListener listener; + private final boolean reindex; + private final boolean removeNullBytes; + private final ByteTracker intermediateSuperSorterLocalStorageTracker; + private final AtomicBoolean started = new AtomicBoolean(); + + private InputSliceReader inputSliceReader; + private OutputChannelFactory workOutputChannelFactory; + private OutputChannelFactory shuffleOutputChannelFactory; + private ResultAndChannels workResultAndOutputChannels; + private SettableFuture stagePartitionBoundariesFuture; + private ListenableFuture stageOutputChannelsFuture; + + public RunWorkOrder( + final WorkOrder workOrder, + final InputChannelFactory inputChannelFactory, + final CounterTracker counterTracker, + final FrameProcessorExecutor exec, + final String cancellationId, + final WorkerContext workerContext, + final FrameContext frameContext, + final RunWorkOrderListener listener, + final boolean reindex, + final boolean removeNullBytes + ) + { + this.workOrder = workOrder; + this.inputChannelFactory = inputChannelFactory; + this.counterTracker = counterTracker; + this.exec = exec; + this.cancellationId = cancellationId; + this.parallelism = workerContext.threadCount(); + this.workerContext = workerContext; + this.frameContext = frameContext; + this.listener = listener; + this.reindex = reindex; + this.removeNullBytes = removeNullBytes; + this.intermediateSuperSorterLocalStorageTracker = + new ByteTracker( + frameContext.storageParameters().isIntermediateStorageLimitConfigured() + ? frameContext.storageParameters().getIntermediateSuperSorterStorageMaxLocalBytes() + : Long.MAX_VALUE + ); + } + + /** + * Start execution of the provided {@link WorkOrder} in the provided {@link FrameProcessorExecutor}. + * + * Execution proceeds asynchronously after this method returns. The {@link RunWorkOrderListener} passed to the + * constructor of this instance can be used to track progress. + */ + public void start() throws IOException + { + if (started.getAndSet(true)) { + throw new ISE("Already started"); + } + + final StageDefinition stageDef = workOrder.getStageDefinition(); + + try { + makeInputSliceReader(); + makeWorkOutputChannelFactory(); + makeShuffleOutputChannelFactory(); + makeAndRunWorkProcessors(); + + if (stageDef.doesShuffle()) { + makeAndRunShuffleProcessors(); + } else { + // No shuffling: work output _is_ stage output. Retain read-only versions to reduce memory footprint. + stageOutputChannelsFuture = + Futures.immediateFuture(workResultAndOutputChannels.getOutputChannels().readOnly()); + } + + setUpCompletionCallbacks(); + } + catch (Throwable t) { + // If start() has problems, cancel anything that was already kicked off, and close the FrameContext. + try { + exec.cancel(cancellationId); + } + catch (Throwable t2) { + t.addSuppressed(t2); + } + + CloseableUtils.closeAndSuppressExceptions(frameContext, t::addSuppressed); + throw t; + } + } + + /** + * Settable {@link ClusterByPartitions} future for global sort. Necessary because we don't know ahead of time + * what the boundaries will be. The controller decides based on statistics from all workers. Once the controller + * decides, its decision is written to this future, which allows sorting on workers to proceed. + */ + @Nullable + public SettableFuture getStagePartitionBoundariesFuture() + { + return stagePartitionBoundariesFuture; + } + + private void makeInputSliceReader() + { + if (inputSliceReader != null) { + throw new ISE("inputSliceReader already created"); + } + + final String queryId = workOrder.getQueryDefinition().getQueryId(); + + final InputChannels inputChannels = + new InputChannelsImpl( + workOrder.getQueryDefinition(), + InputSlices.allReadablePartitions(workOrder.getInputs()), + inputChannelFactory, + () -> ArenaMemoryAllocator.createOnHeap(frameContext.memoryParameters().getStandardFrameSize()), + exec, + cancellationId, + removeNullBytes + ); + + inputSliceReader = new MapInputSliceReader( + ImmutableMap., InputSliceReader>builder() + .put(NilInputSlice.class, NilInputSliceReader.INSTANCE) + .put(StageInputSlice.class, new StageInputSliceReader(queryId, inputChannels)) + .put(ExternalInputSlice.class, new ExternalInputSliceReader(frameContext.tempDir("external"))) + .put(InlineInputSlice.class, new InlineInputSliceReader(frameContext.segmentWrangler())) + .put(LookupInputSlice.class, new LookupInputSliceReader(frameContext.segmentWrangler())) + .put(SegmentsInputSlice.class, new SegmentsInputSliceReader(frameContext, reindex)) + .build() + ); + } + + private void makeWorkOutputChannelFactory() + { + if (workOutputChannelFactory != null) { + throw new ISE("processorOutputChannelFactory already created"); + } + + final OutputChannelFactory baseOutputChannelFactory; + + if (workOrder.getStageDefinition().doesShuffle()) { + // Writing to a consumer in the same JVM (which will be set up later on in this method). Use the large frame + // size if we're writing to a SuperSorter, since we'll generate fewer temp files if we use larger frames. + // Otherwise, use the standard frame size. + final int frameSize; + + if (workOrder.getStageDefinition().getShuffleSpec().kind().isSort()) { + frameSize = frameContext.memoryParameters().getLargeFrameSize(); + } else { + frameSize = frameContext.memoryParameters().getStandardFrameSize(); + } + + baseOutputChannelFactory = new BlockingQueueOutputChannelFactory(frameSize); + } else { + // Writing stage output. + baseOutputChannelFactory = makeStageOutputChannelFactory(); + } + + workOutputChannelFactory = new CountingOutputChannelFactory( + baseOutputChannelFactory, + counterTracker.channel(CounterNames.outputChannel()) + ); + } + + private void makeShuffleOutputChannelFactory() + { + shuffleOutputChannelFactory = + new CountingOutputChannelFactory( + makeStageOutputChannelFactory(), + counterTracker.channel(CounterNames.shuffleChannel()) + ); + } + + /** + * Use {@link FrameProcessorFactory#makeProcessors} to create {@link ProcessorsAndChannels}. Executes the + * processors using {@link #exec} and sets the output channels in {@link #workResultAndOutputChannels}. + * + * @param type of {@link StageDefinition#getProcessorFactory()} + * @param return type of {@link FrameProcessor} created by the manager + * @param result type of {@link ProcessorManager#result()} + * @param type of {@link WorkOrder#getExtraInfo()} + */ + private , ProcessorReturnType, ManagerReturnType, ExtraInfoType> void makeAndRunWorkProcessors() + throws IOException + { + if (workResultAndOutputChannels != null) { + throw new ISE("workResultAndOutputChannels already set"); + } + + @SuppressWarnings("unchecked") + final FactoryType processorFactory = (FactoryType) workOrder.getStageDefinition().getProcessorFactory(); + + @SuppressWarnings("unchecked") + final ProcessorsAndChannels processors = + processorFactory.makeProcessors( + workOrder.getStageDefinition(), + workOrder.getWorkerNumber(), + workOrder.getInputs(), + inputSliceReader, + (ExtraInfoType) workOrder.getExtraInfo(), + workOutputChannelFactory, + frameContext, + parallelism, + counterTracker, + listener::onWarning, + removeNullBytes + ); + + final ProcessorManager processorManager = processors.getProcessorManager(); + + final int maxOutstandingProcessors; + + if (processors.getOutputChannels().getAllChannels().isEmpty()) { + // No output channels: run up to "parallelism" processors at once. + maxOutstandingProcessors = Math.max(1, parallelism); + } else { + // If there are output channels, that acts as a ceiling on the number of processors that can run at once. + maxOutstandingProcessors = + Math.max(1, Math.min(parallelism, processors.getOutputChannels().getAllChannels().size())); + } + + final ListenableFuture workResultFuture = exec.runAllFully( + processorManager, + maxOutstandingProcessors, + frameContext.processorBouncer(), + cancellationId + ); + + workResultAndOutputChannels = new ResultAndChannels<>(workResultFuture, processors.getOutputChannels()); + } + + private void makeAndRunShuffleProcessors() + { + if (stageOutputChannelsFuture != null) { + throw new ISE("stageOutputChannelsFuture already set"); + } + + final ShuffleSpec shuffleSpec = workOrder.getStageDefinition().getShuffleSpec(); + + final ShufflePipelineBuilder shufflePipeline = new ShufflePipelineBuilder( + workOrder, + counterTracker, + exec, + cancellationId, + frameContext + ); + + shufflePipeline.initialize(workResultAndOutputChannels); + shufflePipeline.gatherResultKeyStatisticsAndReportDoneReadingInputIfNeeded(); + + switch (shuffleSpec.kind()) { + case MIX: + shufflePipeline.mix(shuffleOutputChannelFactory); + break; + + case HASH: + shufflePipeline.hashPartition(shuffleOutputChannelFactory); + break; + + case HASH_LOCAL_SORT: + final OutputChannelFactory hashOutputChannelFactory; + + if (shuffleSpec.partitionCount() == 1) { + // Single partition; no need to write temporary files. + hashOutputChannelFactory = + new BlockingQueueOutputChannelFactory(frameContext.memoryParameters().getStandardFrameSize()); + } else { + // Multi-partition; write temporary files and then sort each one file-by-file. + hashOutputChannelFactory = + new FileOutputChannelFactory( + frameContext.tempDir("hash-parts"), + frameContext.memoryParameters().getStandardFrameSize(), + null + ); + } + + shufflePipeline.hashPartition(hashOutputChannelFactory); + shufflePipeline.localSort(shuffleOutputChannelFactory); + break; + + case GLOBAL_SORT: + shufflePipeline.globalSort(shuffleOutputChannelFactory, makeGlobalSortPartitionBoundariesFuture()); + break; + + default: + throw new UOE("Cannot handle shuffle kind [%s]", shuffleSpec.kind()); + } + + stageOutputChannelsFuture = shufflePipeline.build(); + } + + private ListenableFuture makeGlobalSortPartitionBoundariesFuture() + { + if (workOrder.getStageDefinition().mustGatherResultKeyStatistics()) { + if (stagePartitionBoundariesFuture != null) { + throw new ISE("Cannot call 'makeGlobalSortPartitionBoundariesFuture' twice"); + } + + return (stagePartitionBoundariesFuture = SettableFuture.create()); + } else { + // Result key stats aren't needed, so the partition boundaries are knowable ahead of time. Compute them now. + final ClusterByPartitions boundaries = + workOrder.getStageDefinition() + .generatePartitionBoundariesForShuffle(null) + .valueOrThrow(); + + return Futures.immediateFuture(boundaries); + } + } + + private void setUpCompletionCallbacks() + { + Futures.addCallback( + Futures.allAsList( + Arrays.asList( + workResultAndOutputChannels.getResultFuture(), + stageOutputChannelsFuture + ) + ), + new FutureCallback>() + { + @Override + public void onSuccess(final List workerResultAndOutputChannelsResolved) + { + final Object resultObject = workerResultAndOutputChannelsResolved.get(0); + final OutputChannels outputChannels = (OutputChannels) workerResultAndOutputChannelsResolved.get(1); + + if (workOrder.getOutputChannelMode() != OutputChannelMode.MEMORY) { + // In non-MEMORY output channel modes, call onOutputChannelAvailable when all work is done. + // (In MEMORY mode, we would have called onOutputChannelAvailable when the channels were created.) + for (final OutputChannel channel : outputChannels.getAllChannels()) { + listener.onOutputChannelAvailable(channel); + } + } + + if (workOrder.getOutputChannelMode().isDurable()) { + // In DURABLE_STORAGE output channel mode, write a success file once all work is done. + writeDurableStorageSuccessFile(); + } + + listener.onSuccess(resultObject); + } + + @Override + public void onFailure(final Throwable t) + { + listener.onFailure(t); + } + }, + Execs.directExecutor() + ); + } + + /** + * Write {@link DurableStorageUtils#SUCCESS_MARKER_FILENAME} for a particular stage, if durable storage is enabled. + */ + private void writeDurableStorageSuccessFile() + { + final DurableStorageOutputChannelFactory durableStorageOutputChannelFactory = + makeDurableStorageOutputChannelFactory( + frameContext.tempDir("durable"), + frameContext.memoryParameters().getStandardFrameSize(), + workOrder.getOutputChannelMode() == OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS + ); + + try { + durableStorageOutputChannelFactory.createSuccessFile(workerContext.workerId()); + } + catch (IOException e) { + throw new ISE( + e, + "Unable to create success file at location[%s]", + DurableStorageUtils.SUCCESS_MARKER_FILENAME, + durableStorageOutputChannelFactory.getSuccessFilePath() + ); + } + } + + private OutputChannelFactory makeStageOutputChannelFactory() + { + // Use the standard frame size, since we assume this size when computing how much is needed to merge output + // files from different workers. + final int frameSize = frameContext.memoryParameters().getStandardFrameSize(); + final OutputChannelMode outputChannelMode = workOrder.getOutputChannelMode(); + + switch (outputChannelMode) { + case MEMORY: + // Use ListeningOutputChannelFactory to capture output channels as they are created, rather than when + // work is complete. + return new ListeningOutputChannelFactory( + new BlockingQueueOutputChannelFactory(frameSize), + listener::onOutputChannelAvailable + ); + + case LOCAL_STORAGE: + final File fileChannelDirectory = + frameContext.tempDir(StringUtils.format("output_stage_%06d", workOrder.getStageNumber())); + return new FileOutputChannelFactory(fileChannelDirectory, frameSize, null); + + case DURABLE_STORAGE_INTERMEDIATE: + case DURABLE_STORAGE_QUERY_RESULTS: + return makeDurableStorageOutputChannelFactory( + frameContext.tempDir("durable"), + frameSize, + outputChannelMode == OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS + ); + + default: + throw DruidException.defensive("No handling for outputChannelMode[%s]", outputChannelMode); + } + } + + private OutputChannelFactory makeSuperSorterIntermediateOutputChannelFactory(final File tmpDir) + { + final int frameSize = frameContext.memoryParameters().getLargeFrameSize(); + final File fileChannelDirectory = + new File(tmpDir, StringUtils.format("intermediate_output_stage_%06d", workOrder.getStageNumber())); + final FileOutputChannelFactory fileOutputChannelFactory = + new FileOutputChannelFactory(fileChannelDirectory, frameSize, intermediateSuperSorterLocalStorageTracker); + + if (workOrder.getOutputChannelMode().isDurable() + && frameContext.storageParameters().isIntermediateStorageLimitConfigured()) { + final boolean isQueryResults = + workOrder.getOutputChannelMode() == OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS; + return new ComposingOutputChannelFactory( + ImmutableList.of( + fileOutputChannelFactory, + makeDurableStorageOutputChannelFactory(tmpDir, frameSize, isQueryResults) + ), + frameSize + ); + } else { + return fileOutputChannelFactory; + } + } + + private DurableStorageOutputChannelFactory makeDurableStorageOutputChannelFactory( + final File tmpDir, + final int frameSize, + final boolean isQueryResults + ) + { + return DurableStorageOutputChannelFactory.createStandardImplementation( + workOrder.getQueryDefinition().getQueryId(), + workOrder.getWorkerNumber(), + workOrder.getStageNumber(), + workerContext.workerId(), + frameSize, + MSQTasks.makeStorageConnector(workerContext.injector()), + tmpDir, + isQueryResults + ); + } + + /** + * Helper for {@link RunWorkOrder#makeAndRunShuffleProcessors()}. Builds a {@link FrameProcessor} pipeline to + * handle the shuffle. + */ + private class ShufflePipelineBuilder + { + private final WorkOrder workOrder; + private final CounterTracker counterTracker; + private final FrameProcessorExecutor exec; + private final String cancellationId; + private final FrameContext frameContext; + + // Current state of the pipeline. It's a future to allow pipeline construction to be deferred if necessary. + private ListenableFuture> pipelineFuture; + + public ShufflePipelineBuilder( + final WorkOrder workOrder, + final CounterTracker counterTracker, + final FrameProcessorExecutor exec, + final String cancellationId, + final FrameContext frameContext + ) + { + this.workOrder = workOrder; + this.counterTracker = counterTracker; + this.exec = exec; + this.cancellationId = cancellationId; + this.frameContext = frameContext; + } + + /** + * Start the pipeline with the outputs of the main processor. + */ + public void initialize(final ResultAndChannels resultAndChannels) + { + if (pipelineFuture != null) { + throw new ISE("already initialized"); + } + + pipelineFuture = Futures.immediateFuture(resultAndChannels); + } + + /** + * Add {@link FrameChannelMixer}, which mixes all current outputs into a single channel from the provided factory. + */ + public void mix(final OutputChannelFactory outputChannelFactory) + { + // No sorting or statistics gathering, just combining all outputs into one big partition. Use a mixer to get + // everything into one file. Note: even if there is only one output channel, we'll run it through the mixer + // anyway, to ensure the data gets written to a file. (httpGetChannelData requires files.) + + push( + resultAndChannels -> { + final OutputChannel outputChannel = outputChannelFactory.openChannel(0); + + final FrameChannelMixer mixer = + new FrameChannelMixer( + resultAndChannels.getOutputChannels().getAllReadableChannels(), + outputChannel.getWritableChannel() + ); + + return new ResultAndChannels<>( + exec.runFully(mixer, cancellationId), + OutputChannels.wrap(Collections.singletonList(outputChannel.readOnly())) + ); + } + ); + } + + /** + * Add {@link KeyStatisticsCollectionProcessor} if {@link StageDefinition#mustGatherResultKeyStatistics()}. + * + * Calls {@link RunWorkOrderListener#onDoneReadingInput(ClusterByStatisticsSnapshot)} when statistics are gathered. + * If statistics were not needed, calls the listener immediately. + */ + public void gatherResultKeyStatisticsAndReportDoneReadingInputIfNeeded() + { + push( + resultAndChannels -> { + final StageDefinition stageDefinition = workOrder.getStageDefinition(); + final OutputChannels channels = resultAndChannels.getOutputChannels(); + + if (channels.getAllChannels().isEmpty()) { + // No data coming out of this stage. Report empty statistics, if the kernel is expecting statistics. + if (stageDefinition.mustGatherResultKeyStatistics()) { + listener.onDoneReadingInput(ClusterByStatisticsSnapshot.empty()); + } else { + listener.onDoneReadingInput(null); + } + + // Generate one empty channel so the next part of the pipeline has something to do. + final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal(); + channel.writable().close(); + + final OutputChannel outputChannel = OutputChannel.readOnly( + channel.readable(), + FrameWithPartition.NO_PARTITION + ); + + return new ResultAndChannels<>( + Futures.immediateFuture(null), + OutputChannels.wrap(Collections.singletonList(outputChannel)) + ); + } else if (stageDefinition.mustGatherResultKeyStatistics()) { + return gatherResultKeyStatistics(channels); + } else { + // Report "done reading input" when the input future resolves. + // No need to add any processors to the pipeline. + resultAndChannels.resultFuture.addListener( + () -> listener.onDoneReadingInput(null), + Execs.directExecutor() + ); + return resultAndChannels; + } + } + ); + } + + /** + * Add a {@link SuperSorter} using {@link StageDefinition#getSortKey()} and partition boundaries + * from {@code partitionBoundariesFuture}. + */ + public void globalSort( + final OutputChannelFactory outputChannelFactory, + final ListenableFuture partitionBoundariesFuture + ) + { + pushAsync( + resultAndChannels -> { + final StageDefinition stageDefinition = workOrder.getStageDefinition(); + + final File sorterTmpDir = frameContext.tempDir("super-sort"); + FileUtils.mkdirp(sorterTmpDir); + if (!sorterTmpDir.isDirectory()) { + throw new IOException("Cannot create directory: " + sorterTmpDir); + } + + final WorkerMemoryParameters memoryParameters = frameContext.memoryParameters(); + final SuperSorter sorter = new SuperSorter( + resultAndChannels.getOutputChannels().getAllReadableChannels(), + stageDefinition.getFrameReader(), + stageDefinition.getSortKey(), + partitionBoundariesFuture, + exec, + outputChannelFactory, + makeSuperSorterIntermediateOutputChannelFactory(sorterTmpDir), + memoryParameters.getSuperSorterMaxActiveProcessors(), + memoryParameters.getSuperSorterMaxChannelsPerProcessor(), + -1, + cancellationId, + counterTracker.sortProgress(), + removeNullBytes + ); + + return FutureUtils.transform( + sorter.run(), + sortedChannels -> new ResultAndChannels<>(Futures.immediateFuture(null), sortedChannels) + ); + } + ); + } + + /** + * Add a {@link FrameChannelHashPartitioner} using {@link StageDefinition#getSortKey()}. + */ + public void hashPartition(final OutputChannelFactory outputChannelFactory) + { + pushAsync( + resultAndChannels -> { + final ShuffleSpec shuffleSpec = workOrder.getStageDefinition().getShuffleSpec(); + final int partitions = shuffleSpec.partitionCount(); + + final List outputChannels = new ArrayList<>(); + + for (int i = 0; i < partitions; i++) { + outputChannels.add(outputChannelFactory.openChannel(i)); + } + + final FrameChannelHashPartitioner partitioner = new FrameChannelHashPartitioner( + resultAndChannels.getOutputChannels().getAllReadableChannels(), + outputChannels.stream().map(OutputChannel::getWritableChannel).collect(Collectors.toList()), + workOrder.getStageDefinition().getFrameReader(), + workOrder.getStageDefinition().getClusterBy().getColumns().size(), + FrameWriters.makeRowBasedFrameWriterFactory( + new ArenaMemoryAllocatorFactory(frameContext.memoryParameters().getStandardFrameSize()), + workOrder.getStageDefinition().getSignature(), + workOrder.getStageDefinition().getSortKey(), + removeNullBytes + ) + ); + + final ListenableFuture partitionerFuture = exec.runFully(partitioner, cancellationId); + + final ResultAndChannels retVal = + new ResultAndChannels<>(partitionerFuture, OutputChannels.wrap(outputChannels)); + + if (retVal.getOutputChannels().areReadableChannelsReady()) { + return Futures.immediateFuture(retVal); + } else { + return FutureUtils.transform(partitionerFuture, ignored -> retVal); + } + } + ); + } + + /** + * Add a sequence of {@link SuperSorter}, operating on each current output channel in order, one at a time. + */ + public void localSort(final OutputChannelFactory outputChannelFactory) + { + pushAsync( + resultAndChannels -> { + final StageDefinition stageDefinition = workOrder.getStageDefinition(); + final OutputChannels channels = resultAndChannels.getOutputChannels(); + final List> sortedChannelFutures = new ArrayList<>(); + + ListenableFuture nextFuture = Futures.immediateFuture(null); + + for (final OutputChannel channel : channels.getAllChannels()) { + final File sorterTmpDir = frameContext.tempDir( + StringUtils.format("hash-parts-super-sort-%06d", channel.getPartitionNumber()) + ); + + FileUtils.mkdirp(sorterTmpDir); + + // SuperSorter will try to write to output partition zero; we remap it to the correct partition number. + final OutputChannelFactory partitionOverrideOutputChannelFactory = new OutputChannelFactory() + { + @Override + public OutputChannel openChannel(int expectedZero) throws IOException + { + if (expectedZero != 0) { + throw new ISE("Unexpected part [%s]", expectedZero); + } + + return outputChannelFactory.openChannel(channel.getPartitionNumber()); + } + + @Override + public PartitionedOutputChannel openPartitionedChannel(String name, boolean deleteAfterRead) + { + throw new UnsupportedOperationException(); + } + + @Override + public OutputChannel openNilChannel(int expectedZero) + { + if (expectedZero != 0) { + throw new ISE("Unexpected part [%s]", expectedZero); + } + + return outputChannelFactory.openNilChannel(channel.getPartitionNumber()); + } + }; + + // Chain futures so we only sort one partition at a time. + nextFuture = Futures.transformAsync( + nextFuture, + ignored -> { + final SuperSorter sorter = new SuperSorter( + Collections.singletonList(channel.getReadableChannel()), + stageDefinition.getFrameReader(), + stageDefinition.getSortKey(), + Futures.immediateFuture(ClusterByPartitions.oneUniversalPartition()), + exec, + partitionOverrideOutputChannelFactory, + makeSuperSorterIntermediateOutputChannelFactory(sorterTmpDir), + 1, + 2, + -1, + cancellationId, + + // Tracker is not actually tracked, since it doesn't quite fit into the way we report counters. + // There's a single SuperSorterProgressTrackerCounter per worker, but workers that do local + // sorting have a SuperSorter per partition. + new SuperSorterProgressTracker(), + removeNullBytes + ); + + return FutureUtils.transform(sorter.run(), r -> Iterables.getOnlyElement(r.getAllChannels())); + }, + MoreExecutors.directExecutor() + ); + + sortedChannelFutures.add(nextFuture); + } + + return FutureUtils.transform( + Futures.allAsList(sortedChannelFutures), + sortedChannels -> new ResultAndChannels<>( + Futures.immediateFuture(null), + OutputChannels.wrap(sortedChannels) + ) + ); + } + ); + } + + /** + * Return the (future) output channels for this pipeline. + */ + public ListenableFuture build() + { + if (pipelineFuture == null) { + throw new ISE("Not initialized"); + } + + return Futures.transformAsync( + pipelineFuture, + resultAndChannels -> + Futures.transform( + resultAndChannels.getResultFuture(), + (Function) input -> { + sanityCheckOutputChannels(resultAndChannels.getOutputChannels()); + return resultAndChannels.getOutputChannels(); + }, + Execs.directExecutor() + ), + Execs.directExecutor() + ); + } + + /** + * Adds {@link KeyStatisticsCollectionProcessor}. Called by {@link #gatherResultKeyStatisticsAndReportDoneReadingInputIfNeeded()}. + */ + private ResultAndChannels gatherResultKeyStatistics(final OutputChannels channels) + { + final StageDefinition stageDefinition = workOrder.getStageDefinition(); + final List retVal = new ArrayList<>(); + final List processors = new ArrayList<>(); + + for (final OutputChannel outputChannel : channels.getAllChannels()) { + final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal(); + retVal.add(OutputChannel.readOnly(channel.readable(), outputChannel.getPartitionNumber())); + + processors.add( + new KeyStatisticsCollectionProcessor( + outputChannel.getReadableChannel(), + channel.writable(), + stageDefinition.getFrameReader(), + stageDefinition.getClusterBy(), + stageDefinition.createResultKeyStatisticsCollector( + frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes() + ) + ) + ); + } + + final ListenableFuture clusterByStatisticsCollectorFuture = + exec.runAllFully( + ProcessorManagers.of(processors) + .withAccumulation( + stageDefinition.createResultKeyStatisticsCollector( + frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes() + ), + ClusterByStatisticsCollector::addAll + ), + // Run all processors simultaneously. They are lightweight and this keeps things moving. + processors.size(), + Bouncer.unlimited(), + cancellationId + ); + + Futures.addCallback( + clusterByStatisticsCollectorFuture, + new FutureCallback() + { + @Override + public void onSuccess(final ClusterByStatisticsCollector result) + { + listener.onDoneReadingInput(result.snapshot()); + } + + @Override + public void onFailure(Throwable t) + { + listener.onFailure( + new ISE(t, "Failed to gather clusterBy statistics for stage[%s]", stageDefinition.getId()) + ); + } + }, + Execs.directExecutor() + ); + + return new ResultAndChannels<>( + clusterByStatisticsCollectorFuture, + OutputChannels.wrap(retVal) + ); + } + + /** + * Update the {@link #pipelineFuture}. + */ + private void push(final ExceptionalFunction, ResultAndChannels> fn) + { + pushAsync( + channels -> + Futures.immediateFuture(fn.apply(channels)) + ); + } + + /** + * Update the {@link #pipelineFuture} asynchronously. + */ + private void pushAsync(final ExceptionalFunction, ListenableFuture>> fn) + { + if (pipelineFuture == null) { + throw new ISE("Not initialized"); + } + + pipelineFuture = FutureUtils.transform( + Futures.transformAsync( + pipelineFuture, + fn::apply, + Execs.directExecutor() + ), + resultAndChannels -> new ResultAndChannels<>( + resultAndChannels.getResultFuture(), + resultAndChannels.getOutputChannels().readOnly() + ) + ); + } + + /** + * Verifies there is exactly one channel per partition. + */ + private void sanityCheckOutputChannels(final OutputChannels outputChannels) + { + for (int partitionNumber : outputChannels.getPartitionNumbers()) { + final List outputChannelsForPartition = + outputChannels.getChannelsForPartition(partitionNumber); + + Preconditions.checkState(partitionNumber >= 0, "Expected partitionNumber >= 0, but got [%s]", partitionNumber); + Preconditions.checkState( + outputChannelsForPartition.size() == 1, + "Expected one channel for partition [%s], but got [%s]", + partitionNumber, + outputChannelsForPartition.size() + ); + } + } + } + + private static class ResultAndChannels + { + private final ListenableFuture resultFuture; + private final OutputChannels outputChannels; + + public ResultAndChannels( + ListenableFuture resultFuture, + OutputChannels outputChannels + ) + { + this.resultFuture = resultFuture; + this.outputChannels = outputChannels; + } + + public ListenableFuture getResultFuture() + { + return resultFuture; + } + + public OutputChannels getOutputChannels() + { + return outputChannels; + } + } + + private interface ExceptionalFunction + { + R apply(T t) throws Exception; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrderListener.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrderListener.java new file mode 100644 index 000000000000..19c3c6570fe9 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrderListener.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import org.apache.druid.frame.processor.OutputChannel; +import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; + +import javax.annotation.Nullable; + +/** + * Listener for various things that may happen during execution of {@link RunWorkOrder#start()}. Listener methods are + * fired in processing threads, so they must be thread-safe, and it is important that they run quickly. + */ +public interface RunWorkOrderListener +{ + /** + * Called when done reading input. If key statistics were gathered, they are provided. + */ + void onDoneReadingInput(@Nullable ClusterByStatisticsSnapshot snapshot); + + /** + * Called when an output channel becomes available for reading by downstream stages. + */ + void onOutputChannelAvailable(OutputChannel outputChannel); + + /** + * Called when the work order has succeeded. + */ + void onSuccess(Object resultObject); + + /** + * Called when a non-fatal exception is encountered. Work continues after this listener fires. + */ + void onWarning(Throwable t); + + /** + * Called when the work order has failed. + */ + void onFailure(Throwable t); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java index cc5f0fae1732..b068796cec70 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java @@ -19,40 +19,45 @@ package org.apache.druid.msq.exec; +import com.google.common.util.concurrent.ListenableFuture; import org.apache.druid.frame.key.ClusterByPartitions; -import org.apache.druid.indexer.TaskStatus; import org.apache.druid.msq.counters.CounterSnapshotsTree; -import org.apache.druid.msq.indexing.MSQWorkerTask; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; -import javax.annotation.Nullable; import java.io.IOException; import java.io.InputStream; +/** + * Interface for a multi-stage query (MSQ) worker. Workers are long-lived and are able to run multiple {@link WorkOrder} + * prior to exiting. + * + * @see WorkerImpl the production implementation + */ public interface Worker { /** - * Unique ID for this worker. + * Identifier for this worker. Same as {@link WorkerContext#workerId()}. */ String id(); /** - * The task which this worker runs. + * Runs the worker in the current thread. Surrounding classes provide the execution thread. */ - MSQWorkerTask task(); + void run(); /** - * Runs the worker in the current thread. Surrounding classes provide - * the execution thread. + * Terminate the worker upon a cancellation request. Causes a concurrently-running {@link #run()} method in + * a separate thread to cancel all outstanding work and exit. Does not block. Use {@link #awaitStop()} if you + * would like to wait for {@link #run()} to finish. */ - TaskStatus run() throws Exception; + void stop(); /** - * Terminate the worker upon a cancellation request. + * Wait for {@link #run()} to finish. */ - void stopGracefully(); + void awaitStop(); /** * Report that the controller has failed. The worker must cease work immediately. Cleanup then exit. @@ -63,20 +68,20 @@ public interface Worker // Controller-to-worker, and worker-to-worker messages /** - * Called when the worker chat handler receives a request for a work order. Accepts the work order and schedules it for - * execution + * Called when the worker receives a new work order. Accepts the work order and schedules it for execution. */ void postWorkOrder(WorkOrder workOrder); /** * Returns the statistics snapshot for the given stageId. This is called from {@link WorkerSketchFetcher} under - * PARALLEL OR AUTO modes. + * {@link ClusterStatisticsMergeMode#PARALLEL} OR {@link ClusterStatisticsMergeMode#AUTO} modes. */ ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId); /** * Returns the statistics snapshot for the given stageId which contains only the sketch for the specified timeChunk. - * This is called from {@link WorkerSketchFetcher} under SEQUENTIAL OR AUTO modes. + * This is called from {@link WorkerSketchFetcher} under {@link ClusterStatisticsMergeMode#SEQUENTIAL} or + * {@link ClusterStatisticsMergeMode#AUTO} modes. */ ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk); @@ -84,26 +89,28 @@ public interface Worker * Called when the worker chat handler recieves the result partition boundaries for a particular stageNumber * and queryId */ - boolean postResultPartitionBoundaries( - ClusterByPartitions stagePartitionBoundaries, - String queryId, - int stageNumber - ); + boolean postResultPartitionBoundaries(StageId stageId, ClusterByPartitions stagePartitionBoundaries); /** * Returns an InputStream of the worker output for a particular queryId, stageNumber and partitionNumber. * Offset indicates the number of bytes to skip the channel data, and is used to prevent re-reading the same data - * during retry in case of a connection error + * during retry in case of a connection error. + * + * The returned future resolves when at least one byte of data is available, or when the channel is finished. + * If the channel is finished, an empty {@link InputStream} is returned. + * + * With {@link OutputChannelMode#MEMORY}, once this method is called with a certain offset, workers are free to + * delete data prior to that offset. (It will not be re-requested.) * - * Returns a null if the workerOutput for a particular queryId, stageNumber, and partitionNumber is not found. + * Returns future that resolves to null if worker output for a particular queryId, stageNumber, and + * partitionNumber is not found. * * @throws IOException when the worker output is found but there is an error while reading it. */ - @Nullable - InputStream readChannel(String queryId, int stageNumber, int partitionNumber, long offset) throws IOException; + ListenableFuture readChannel(StageId stageId, int partitionNumber, long offset) throws IOException; /** - * Returns the snapshot of the worker counters + * Returns a snapshot of counters. */ CounterSnapshotsTree getCounters(); @@ -115,7 +122,7 @@ boolean postResultPartitionBoundaries( void postCleanupStage(StageId stageId); /** - * Called when the work required for the query has been finished + * Called when the worker is no longer needed, and should shut down. */ void postFinish(); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java index f5e86039c23f..666115d774cf 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java @@ -21,11 +21,12 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Injector; -import org.apache.druid.frame.processor.Bouncer; -import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.msq.indexing.MSQWorkerTask; import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.FrameProcessorFactory; import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.server.DruidNode; import java.io.File; @@ -33,10 +34,21 @@ /** * Context used by multi-stage query workers. * - * Useful because it allows test fixtures to provide their own implementations. + * Each context is scoped to a {@link Worker} and is shared across all {@link WorkOrder} run by that worker. */ public interface WorkerContext { + /** + * Query ID for this context. + */ + String queryId(); + + /** + * Identifier for this worker that enables the controller, and other workers, to find it. For tasks this is the + * task ID from {@link MSQWorkerTask#getId()}. For persistent servers, this is the server URI. + */ + String workerId(); + ObjectMapper jsonMapper(); // Using an Injector directly because tasks do not have a way to provide their own Guice modules. @@ -49,9 +61,15 @@ public interface WorkerContext void registerWorker(Worker worker, Closer closer); /** - * Creates and fetches the controller client for the provided controller ID. + * Maximum number of {@link WorkOrder} that a {@link Worker} with this context will be asked to execute + * simultaneously. + */ + int maxConcurrentStages(); + + /** + * Creates a controller client. */ - ControllerClient makeControllerClient(String controllerId); + ControllerClient makeControllerClient(); /** * Creates and fetches a {@link WorkerClient}. It is independent of the workerId because the workerId is passed @@ -60,24 +78,24 @@ public interface WorkerContext WorkerClient makeWorkerClient(); /** - * Fetch a directory for temporary outputs + * Directory for temporary outputs. */ File tempDir(); - FrameContext frameContext(QueryDefinition queryDef, int stageNumber); + /** + * Create a context with useful objects required by {@link FrameProcessorFactory#makeProcessors}. + */ + FrameContext frameContext(QueryDefinition queryDef, int stageNumber, OutputChannelMode outputChannelMode); + /** + * Number of available processing threads. + */ int threadCount(); /** - * Fetch node info about self + * Fetch node info about self. */ DruidNode selfNode(); - Bouncer processorBouncer(); DataServerQueryHandlerFactory dataServerQueryHandlerFactory(); - - default File tempDir(int stageNumber, String id) - { - return new File(StringUtils.format("%s/stage_%02d/%s", tempDir(), stageNumber, id)); - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java index 61939d823731..91e003361628 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java @@ -19,111 +19,58 @@ package org.apache.druid.msq.exec; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Function; -import com.google.common.base.Preconditions; import com.google.common.base.Suppliers; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; -import com.google.common.util.concurrent.AsyncFunction; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; -import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; -import it.unimi.dsi.fastutil.bytes.ByteArrays; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.ints.IntObjectPair; import org.apache.druid.common.guava.FutureUtils; -import org.apache.druid.frame.allocation.ArenaMemoryAllocator; -import org.apache.druid.frame.allocation.ArenaMemoryAllocatorFactory; -import org.apache.druid.frame.channel.BlockingQueueFrameChannel; -import org.apache.druid.frame.channel.ByteTracker; -import org.apache.druid.frame.channel.FrameWithPartition; -import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.channel.ReadableFrameChannel; -import org.apache.druid.frame.channel.ReadableNilFrameChannel; -import org.apache.druid.frame.file.FrameFile; -import org.apache.druid.frame.file.FrameFileWriter; import org.apache.druid.frame.key.ClusterByPartitions; -import org.apache.druid.frame.processor.BlockingQueueOutputChannelFactory; -import org.apache.druid.frame.processor.Bouncer; -import org.apache.druid.frame.processor.ComposingOutputChannelFactory; -import org.apache.druid.frame.processor.FileOutputChannelFactory; -import org.apache.druid.frame.processor.FrameChannelHashPartitioner; -import org.apache.druid.frame.processor.FrameChannelMixer; -import org.apache.druid.frame.processor.FrameProcessor; import org.apache.druid.frame.processor.FrameProcessorExecutor; import org.apache.druid.frame.processor.OutputChannel; -import org.apache.druid.frame.processor.OutputChannelFactory; -import org.apache.druid.frame.processor.OutputChannels; -import org.apache.druid.frame.processor.PartitionedOutputChannel; -import org.apache.druid.frame.processor.SuperSorter; -import org.apache.druid.frame.processor.SuperSorterProgressTracker; -import org.apache.druid.frame.processor.manager.ProcessorManager; -import org.apache.druid.frame.processor.manager.ProcessorManagers; import org.apache.druid.frame.util.DurableStorageUtils; -import org.apache.druid.frame.write.FrameWriters; -import org.apache.druid.indexer.TaskStatus; import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.java.util.common.ISE; -import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.RE; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.java.util.common.UOE; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.msq.counters.CounterNames; import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.counters.CounterTracker; -import org.apache.druid.msq.indexing.CountingOutputChannelFactory; import org.apache.druid.msq.indexing.InputChannelFactory; -import org.apache.druid.msq.indexing.InputChannelsImpl; import org.apache.druid.msq.indexing.MSQWorkerTask; import org.apache.druid.msq.indexing.destination.MSQSelectDestination; import org.apache.druid.msq.indexing.error.CanceledFault; import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault; import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.apache.druid.msq.indexing.error.MSQException; -import org.apache.druid.msq.indexing.error.MSQFaultUtils; import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher; import org.apache.druid.msq.indexing.error.MSQWarningReportPublisher; import org.apache.druid.msq.indexing.error.MSQWarningReportSimplePublisher; import org.apache.druid.msq.indexing.error.MSQWarnings; -import org.apache.druid.msq.indexing.processor.KeyStatisticsCollectionProcessor; -import org.apache.druid.msq.input.InputSlice; -import org.apache.druid.msq.input.InputSliceReader; import org.apache.druid.msq.input.InputSlices; -import org.apache.druid.msq.input.MapInputSliceReader; -import org.apache.druid.msq.input.NilInputSlice; -import org.apache.druid.msq.input.NilInputSliceReader; -import org.apache.druid.msq.input.external.ExternalInputSlice; -import org.apache.druid.msq.input.external.ExternalInputSliceReader; -import org.apache.druid.msq.input.inline.InlineInputSlice; -import org.apache.druid.msq.input.inline.InlineInputSliceReader; -import org.apache.druid.msq.input.lookup.LookupInputSlice; -import org.apache.druid.msq.input.lookup.LookupInputSliceReader; -import org.apache.druid.msq.input.stage.InputChannels; import org.apache.druid.msq.input.stage.ReadablePartition; -import org.apache.druid.msq.input.stage.StageInputSlice; -import org.apache.druid.msq.input.stage.StageInputSliceReader; -import org.apache.druid.msq.input.table.SegmentsInputSlice; -import org.apache.druid.msq.input.table.SegmentsInputSliceReader; import org.apache.druid.msq.kernel.FrameContext; -import org.apache.druid.msq.kernel.FrameProcessorFactory; -import org.apache.druid.msq.kernel.ProcessorsAndChannels; -import org.apache.druid.msq.kernel.ShuffleSpec; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; -import org.apache.druid.msq.kernel.StagePartition; import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelUtils; import org.apache.druid.msq.kernel.worker.WorkerStageKernel; import org.apache.druid.msq.kernel.worker.WorkerStagePhase; import org.apache.druid.msq.shuffle.input.DurableStorageInputChannelFactory; +import org.apache.druid.msq.shuffle.input.MetaInputChannelFactory; import org.apache.druid.msq.shuffle.input.WorkerInputChannelFactory; -import org.apache.druid.msq.shuffle.output.DurableStorageOutputChannelFactory; -import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; +import org.apache.druid.msq.shuffle.input.WorkerOrLocalInputChannelFactory; +import org.apache.druid.msq.shuffle.output.StageOutputHolder; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; import org.apache.druid.msq.util.DecoratedExecutorService; @@ -132,23 +79,14 @@ import org.apache.druid.query.PrioritizedRunnable; import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryProcessingPool; -import org.apache.druid.rpc.ServiceClosedException; import org.apache.druid.server.DruidNode; import javax.annotation.Nullable; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.File; +import java.io.Closeable; import java.io.IOException; import java.io.InputStream; -import java.io.RandomAccessFile; -import java.nio.channels.Channels; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; import java.util.Comparator; -import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -156,7 +94,6 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; @@ -166,105 +103,84 @@ /** * Interface for a worker of a multi-stage query. + * + * Not scoped to any particular query. There is one of these per {@link MSQWorkerTask}, and one per server for + * long-lived workers. */ public class WorkerImpl implements Worker { private static final Logger log = new Logger(WorkerImpl.class); + /** + * Task object, if this {@link WorkerImpl} was launched from a task. Ideally, this would not be needed, and we + * would be able to get everything we need from {@link WorkerContext}. + */ + @Nullable private final MSQWorkerTask task; private final WorkerContext context; private final DruidNode selfDruidNode; - private final Bouncer processorBouncer; - private final BlockingQueue> kernelManipulationQueue = new LinkedBlockingDeque<>(); - private final ConcurrentHashMap> stageOutputs = new ConcurrentHashMap<>(); - private final ConcurrentHashMap stageCounters = new ConcurrentHashMap<>(); - private final ConcurrentHashMap stageKernelMap = new ConcurrentHashMap<>(); - private final ByteTracker intermediateSuperSorterLocalStorageTracker; - private final boolean durableStageStorageEnabled; - private final WorkerStorageParameters workerStorageParameters; - private final boolean isRemoveNullBytes; + private final BlockingQueue> kernelManipulationQueue = new LinkedBlockingDeque<>(); + private final ConcurrentHashMap> stageOutputs = new ConcurrentHashMap<>(); /** - * Only set for select jobs. + * Pair of {workerNumber, stageId} -> counters. */ - @Nullable - private final MSQSelectDestination selectDestination; + private final ConcurrentHashMap, CounterTracker> stageCounters = new ConcurrentHashMap<>(); + + /** + * Future that resolves when {@link #run()} completes. + */ + private final SettableFuture runFuture = SettableFuture.create(); /** - * Set once in {@link #runTask} and never reassigned. + * Set once in {@link #run} and never reassigned. This is in a field so {@link #doCancel()} can close it. */ private volatile ControllerClient controllerClient; /** - * Set once in {@link #runTask} and never reassigned. Used by processing threads so we can contact other workers + * Set once in {@link #runInternal} and never reassigned. Used by processing threads so we can contact other workers * during a shuffle. */ private volatile WorkerClient workerClient; /** - * Set to false by {@link #controllerFailed()} as a way of enticing the {@link #runTask} method to exit promptly. + * Set to false by {@link #controllerFailed()} as a way of enticing the {@link #runInternal} method to exit promptly. */ private volatile boolean controllerAlive = true; - public WorkerImpl(MSQWorkerTask task, WorkerContext context) - { - this( - task, - context, - WorkerStorageParameters.createProductionInstance( - context.injector(), - MultiStageQueryContext.isDurableStorageEnabled(QueryContext.of(task.getContext())) - // If Durable Storage is enabled, then super sorter intermediate storage can be enabled. - ) - ); - } - - @VisibleForTesting - public WorkerImpl(MSQWorkerTask task, WorkerContext context, WorkerStorageParameters workerStorageParameters) + public WorkerImpl(@Nullable final MSQWorkerTask task, final WorkerContext context) { this.task = task; this.context = context; this.selfDruidNode = context.selfNode(); - this.processorBouncer = context.processorBouncer(); - QueryContext queryContext = QueryContext.of(task.getContext()); - this.durableStageStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(queryContext); - this.selectDestination = MultiStageQueryContext.getSelectDestinationOrNull(queryContext); - this.isRemoveNullBytes = MultiStageQueryContext.removeNullBytes(queryContext); - this.workerStorageParameters = workerStorageParameters; - - long maxBytes = workerStorageParameters.isIntermediateStorageLimitConfigured() - ? workerStorageParameters.getIntermediateSuperSorterStorageMaxLocalBytes() - : Long.MAX_VALUE; - this.intermediateSuperSorterLocalStorageTracker = new ByteTracker(maxBytes); } @Override public String id() { - return task.getId(); - } - - @Override - public MSQWorkerTask task() - { - return task; + return context.workerId(); } @Override - public TaskStatus run() throws Exception + public void run() { try (final Closer closer = Closer.create()) { + final KernelHolders kernelHolders = KernelHolders.create(context, closer); + controllerClient = kernelHolders.getControllerClient(); + + Throwable t = null; Optional maybeErrorReport; try { - maybeErrorReport = runTask(closer); + maybeErrorReport = runInternal(kernelHolders, closer); } catch (Throwable e) { + t = e; maybeErrorReport = Optional.of( MSQErrorReport.fromException( - id(), - MSQTasks.getHostFromSelfNode(selfDruidNode), + context.workerId(), + MSQTasks.getHostFromSelfNode(context.selfNode()), null, e ) @@ -273,203 +189,112 @@ public TaskStatus run() throws Exception if (maybeErrorReport.isPresent()) { final MSQErrorReport errorReport = maybeErrorReport.get(); - final String errorLogMessage = MSQTasks.errorReportToLogMessage(errorReport); - log.warn(errorLogMessage); + final String logMessage = MSQTasks.errorReportToLogMessage(errorReport); + log.warn("%s", logMessage); - closer.register(() -> { - if (controllerAlive && controllerClient != null && selfDruidNode != null) { - controllerClient.postWorkerError(id(), errorReport); - } - }); + if (controllerAlive) { + controllerClient.postWorkerError(context.queryId(), errorReport); + } - return TaskStatus.failure(id(), MSQFaultUtils.generateMessageWithErrorCode(errorReport.getFault())); - } else { - return TaskStatus.success(id()); + if (t != null) { + Throwables.throwIfInstanceOf(t, MSQException.class); + throw new MSQException(t, maybeErrorReport.get().getFault()); + } else { + throw new MSQException(maybeErrorReport.get().getFault()); + } } } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { + runFuture.set(null); + } } /** * Runs worker logic. Returns an empty Optional on success. On failure, returns an error report for errors that * happened in other threads; throws exceptions for errors that happened in the main worker loop. */ - public Optional runTask(final Closer closer) throws Exception + private Optional runInternal(final KernelHolders kernelHolders, final Closer workerCloser) + throws Exception { - this.controllerClient = context.makeControllerClient(task.getControllerTaskId()); - closer.register(controllerClient::close); - closer.register(context.dataServerQueryHandlerFactory()); - context.registerWorker(this, closer); // Uses controllerClient, so must be called after that is initialized - - this.workerClient = new ExceptionWrappingWorkerClient(context.makeWorkerClient()); - closer.register(workerClient::close); - - final KernelHolder kernelHolder = new KernelHolder(); - final String cancellationId = id(); - + context.registerWorker(this, workerCloser); + workerCloser.register(context.dataServerQueryHandlerFactory()); + this.workerClient = workerCloser.register(new ExceptionWrappingWorkerClient(context.makeWorkerClient())); final FrameProcessorExecutor workerExec = new FrameProcessorExecutor(makeProcessingPool()); - // Delete all the stage outputs - closer.register(() -> { - for (final StageId stageId : stageOutputs.keySet()) { - cleanStageOutput(stageId, false); - } - }); - - // Close stage output processors and running futures (if present) - closer.register(() -> { - try { - workerExec.cancel(cancellationId); - } - catch (InterruptedException e) { - // Strange that cancellation would itself be interrupted. Throw an exception, since this is unexpected. - throw new RuntimeException(e); - } - }); + final long maxAllowedParseExceptions; - long maxAllowedParseExceptions = Long.parseLong(task.getContext().getOrDefault( - MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, - Long.MAX_VALUE - ).toString()); + if (task != null) { + maxAllowedParseExceptions = + Long.parseLong(task.getContext() + .getOrDefault(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, Long.MAX_VALUE) + .toString()); + } else { + maxAllowedParseExceptions = 0; + } - long maxVerboseParseExceptions; + final long maxVerboseParseExceptions; if (maxAllowedParseExceptions == -1L) { maxVerboseParseExceptions = Limits.MAX_VERBOSE_PARSE_EXCEPTIONS; } else { maxVerboseParseExceptions = Math.min(maxAllowedParseExceptions, Limits.MAX_VERBOSE_PARSE_EXCEPTIONS); } - Set criticalWarningCodes; + final Set criticalWarningCodes; if (maxAllowedParseExceptions == 0) { criticalWarningCodes = ImmutableSet.of(CannotParseExternalDataFault.CODE); } else { criticalWarningCodes = ImmutableSet.of(); } - final MSQWarningReportPublisher msqWarningReportPublisher = new MSQWarningReportLimiterPublisher( - new MSQWarningReportSimplePublisher( - id(), - controllerClient, - id(), - MSQTasks.getHostFromSelfNode(selfDruidNode) - ), - Limits.MAX_VERBOSE_WARNINGS, - ImmutableMap.of(CannotParseExternalDataFault.CODE, maxVerboseParseExceptions), - criticalWarningCodes, - controllerClient, - id(), - MSQTasks.getHostFromSelfNode(selfDruidNode) - ); - - closer.register(msqWarningReportPublisher); + // Delay removal of kernels so we don't interfere with iteration of kernelHolders.getAllKernelHolders(). + final Set kernelsToRemove = new HashSet<>(); - final Map> partitionBoundariesFutureMap = new HashMap<>(); - - final Map stageFrameContexts = new HashMap<>(); - - while (!kernelHolder.isDone()) { + while (!kernelHolders.isDone()) { boolean didSomething = false; - for (final WorkerStageKernel kernel : kernelHolder.getStageKernelMap().values()) { + for (final KernelHolder kernelHolder : kernelHolders.getAllKernelHolders()) { + final WorkerStageKernel kernel = kernelHolder.kernel; final StageDefinition stageDefinition = kernel.getStageDefinition(); - if (kernel.getPhase() == WorkerStagePhase.NEW) { - - log.info("Processing work order for stage [%d]" + - (log.isDebugEnabled() - ? StringUtils.format( - " with payload [%s]", - context.jsonMapper().writeValueAsString(kernel.getWorkOrder()) - ) : ""), stageDefinition.getId().getStageNumber()); - - // Create separate inputChannelFactory per stage, because the list of tasks can grow between stages, and - // so we need to avoid the memoization in baseInputChannelFactory. - final InputChannelFactory inputChannelFactory = makeBaseInputChannelFactory(closer); - - // Compute memory parameters for all stages, even ones that haven't been assigned yet, so we can fail-fast - // if some won't work. (We expect that all stages will get assigned to the same pool of workers.) - for (final StageDefinition stageDef : kernel.getWorkOrder().getQueryDefinition().getStageDefinitions()) { - stageFrameContexts.computeIfAbsent( - stageDef.getId(), - stageId -> context.frameContext( - kernel.getWorkOrder().getQueryDefinition(), - stageId.getStageNumber() - ) - ); - } - - // Start working on this stage immediately. - kernel.startReading(); - - final RunWorkOrder runWorkOrder = new RunWorkOrder( - kernel, - inputChannelFactory, - stageCounters.computeIfAbsent(stageDefinition.getId(), ignored -> new CounterTracker()), + // Workers run all work orders they get. There is not (currently) any limit on the number of concurrent work + // orders; we rely on the controller to avoid overloading workers. + if (kernel.getPhase() == WorkerStagePhase.NEW + && kernelHolders.runningKernelCount() < context.maxConcurrentStages()) { + handleNewWorkOrder( + kernelHolder, + controllerClient, workerExec, - cancellationId, - context.threadCount(), - stageFrameContexts.get(stageDefinition.getId()), - msqWarningReportPublisher + criticalWarningCodes, + maxVerboseParseExceptions ); - - runWorkOrder.start(); - - final SettableFuture partitionBoundariesFuture = - runWorkOrder.getStagePartitionBoundariesFuture(); - - if (partitionBoundariesFuture != null) { - if (partitionBoundariesFutureMap.put(stageDefinition.getId(), partitionBoundariesFuture) != null) { - throw new ISE("Work order collision for stage [%s]", stageDefinition.getId()); - } - } - + logKernelStatus(kernelHolders.getAllKernels()); didSomething = true; - logKernelStatus(kernelHolder.getStageKernelMap().values()); } - if (kernel.getPhase() == WorkerStagePhase.READING_INPUT && kernel.hasResultKeyStatisticsSnapshot()) { - if (controllerAlive) { - PartialKeyStatisticsInformation partialKeyStatisticsInformation = - kernel.getResultKeyStatisticsSnapshot() - .partialKeyStatistics(); - - controllerClient.postPartialKeyStatistics( - stageDefinition.getId(), - kernel.getWorkOrder().getWorkerNumber(), - partialKeyStatisticsInformation - ); - } - kernel.startPreshuffleWaitingForResultPartitionBoundaries(); - + if (kernel.getPhase() == WorkerStagePhase.READING_INPUT + && handleReadingInput(kernelHolder, controllerClient)) { didSomething = true; - logKernelStatus(kernelHolder.getStageKernelMap().values()); + logKernelStatus(kernelHolders.getAllKernels()); } - logKernelStatus(kernelHolder.getStageKernelMap().values()); if (kernel.getPhase() == WorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES - && kernel.hasResultPartitionBoundaries()) { - partitionBoundariesFutureMap.get(stageDefinition.getId()).set(kernel.getResultPartitionBoundaries()); - kernel.startPreshuffleWritingOutput(); - + && handleWaitingForResultPartitionBoundaries(kernelHolder)) { didSomething = true; - logKernelStatus(kernelHolder.getStageKernelMap().values()); + logKernelStatus(kernelHolders.getAllKernels()); } - if (kernel.getPhase() == WorkerStagePhase.RESULTS_READY - && kernel.addPostedResultsComplete(Pair.of( - stageDefinition.getId(), - kernel.getWorkOrder().getWorkerNumber() - ))) { - if (controllerAlive) { - controllerClient.postResultsComplete( - stageDefinition.getId(), - kernel.getWorkOrder().getWorkerNumber(), - kernel.getResultObject() - ); - } + if (kernel.getPhase() == WorkerStagePhase.RESULTS_COMPLETE + && handleResultsReady(kernelHolder, controllerClient)) { + didSomething = true; + logKernelStatus(kernelHolders.getAllKernels()); } if (kernel.getPhase() == WorkerStagePhase.FAILED) { - // Better than throwing an exception, because we can include the stage number. + // Return an error report when a work order fails. This is better than throwing an exception, because we can + // include the stage number. return Optional.of( MSQErrorReport.fromException( id(), @@ -479,17 +304,37 @@ public Optional runTask(final Closer closer) throws Exception ) ); } + + if (kernel.getPhase().isTerminal()) { + handleTerminated(kernelHolder); + kernelsToRemove.add(stageDefinition.getId()); + } + } + + for (final StageId stageId : kernelsToRemove) { + kernelHolders.removeKernel(stageId); } - if (!didSomething && !kernelHolder.isDone()) { - Consumer nextCommand; + kernelsToRemove.clear(); + + if (!didSomething && !kernelHolders.isDone()) { + Consumer nextCommand; + // Run the next command, waiting for it if necessary. Post counters to the controller every 5 seconds + // while waiting. do { - postCountersToController(); + postCountersToController(kernelHolders.getControllerClient()); } while ((nextCommand = kernelManipulationQueue.poll(5, TimeUnit.SECONDS)) == null); - nextCommand.accept(kernelHolder); - logKernelStatus(kernelHolder.getStageKernelMap().values()); + nextCommand.accept(kernelHolders); + + // Run all pending commands after that one. Helps avoid deep queues. + // After draining the command queue, move on to the next iteration of the worker loop. + while ((nextCommand = kernelManipulationQueue.poll()) != null) { + nextCommand.accept(kernelHolders); + } + + logKernelStatus(kernelHolders.getAllKernels()); } } @@ -497,123 +342,282 @@ public Optional runTask(final Closer closer) throws Exception return Optional.empty(); } + /** + * Handle a kernel in state {@link WorkerStagePhase#NEW}. The kernel is transitioned to + * {@link WorkerStagePhase#READING_INPUT} and a {@link RunWorkOrder} instance is created to start executing work. + */ + private void handleNewWorkOrder( + final KernelHolder kernelHolder, + final ControllerClient controllerClient, + final FrameProcessorExecutor workerExec, + final Set criticalWarningCodes, + final long maxVerboseParseExceptions + ) throws IOException + { + final WorkerStageKernel kernel = kernelHolder.kernel; + final WorkOrder workOrder = kernel.getWorkOrder(); + final StageDefinition stageDefinition = workOrder.getStageDefinition(); + final String cancellationId = cancellationIdFor(stageDefinition.getId()); + + log.info( + "Processing work order for stage[%s]%s", + stageDefinition.getId(), + (log.isDebugEnabled() + ? StringUtils.format(", payload[%s]", context.jsonMapper().writeValueAsString(workOrder)) : "") + ); + + final FrameContext frameContext = kernelHolder.processorCloser.register( + context.frameContext( + workOrder.getQueryDefinition(), + stageDefinition.getStageNumber(), + workOrder.getOutputChannelMode() + ) + ); + kernelHolder.processorCloser.register(() -> { + try { + workerExec.cancel(cancellationId); + } + catch (InterruptedException e) { + // Strange that cancellation would itself be interrupted. Log and suppress. + log.warn(e, "Cancellation interrupted for stage[%s]", stageDefinition.getId()); + Thread.currentThread().interrupt(); + } + }); + + // Set up cleanup functions for this work order. + kernelHolder.resultsCloser.register(() -> FileUtils.deleteDirectory(frameContext.tempDir())); + kernelHolder.resultsCloser.register(() -> removeStageOutputChannels(stageDefinition.getId())); + + // Create separate inputChannelFactory per stage, because the list of tasks can grow between stages, and + // so we need to avoid the memoization of controllerClient.getWorkerIds() in baseInputChannelFactory. + final InputChannelFactory inputChannelFactory = + makeBaseInputChannelFactory(workOrder, controllerClient, kernelHolder.processorCloser); + + // Start working on this stage immediately. + kernel.startReading(); + + final QueryContext queryContext = task != null ? QueryContext.of(task.getContext()) : QueryContext.empty(); + final RunWorkOrder runWorkOrder = new RunWorkOrder( + workOrder, + inputChannelFactory, + stageCounters.computeIfAbsent( + IntObjectPair.of(workOrder.getWorkerNumber(), stageDefinition.getId()), + ignored -> new CounterTracker() + ), + workerExec, + cancellationId, + context, + frameContext, + makeRunWorkOrderListener(workOrder, controllerClient, criticalWarningCodes, maxVerboseParseExceptions), + MultiStageQueryContext.isReindex(queryContext), + MultiStageQueryContext.removeNullBytes(queryContext) + ); + + runWorkOrder.start(); + kernelHolder.partitionBoundariesFuture = runWorkOrder.getStagePartitionBoundariesFuture(); + } + + /** + * Handle a kernel in state {@link WorkerStagePhase#READING_INPUT}. + * + * If the worker has finished generating result key statistics, they are posted to the controller and the kernel is + * transitioned to {@link WorkerStagePhase#PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES}. + * + * @return whether kernel state changed + */ + private boolean handleReadingInput( + final KernelHolder kernelHolder, + final ControllerClient controllerClient + ) throws IOException + { + final WorkerStageKernel kernel = kernelHolder.kernel; + if (kernel.hasResultKeyStatisticsSnapshot()) { + if (controllerAlive) { + PartialKeyStatisticsInformation partialKeyStatisticsInformation = + kernel.getResultKeyStatisticsSnapshot() + .partialKeyStatistics(); + + controllerClient.postPartialKeyStatistics( + kernel.getStageDefinition().getId(), + kernel.getWorkOrder().getWorkerNumber(), + partialKeyStatisticsInformation + ); + } + + kernel.startPreshuffleWaitingForResultPartitionBoundaries(); + return true; + } else if (kernel.isDoneReadingInput() + && kernel.getStageDefinition().doesSortDuringShuffle() + && !kernel.getStageDefinition().mustGatherResultKeyStatistics()) { + // Skip postDoneReadingInput when context.maxConcurrentStages() == 1, for backwards compatibility. + // See Javadoc comment on ControllerClient#postDoneReadingInput. + if (controllerAlive && context.maxConcurrentStages() > 1) { + controllerClient.postDoneReadingInput( + kernel.getStageDefinition().getId(), + kernel.getWorkOrder().getWorkerNumber() + ); + } + + kernel.startPreshuffleWritingOutput(); + return true; + } else { + return false; + } + } + + /** + * Handle a kernel in state {@link WorkerStagePhase#PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES}. + * + * If partition boundaries have become available, the {@link KernelHolder#partitionBoundariesFuture} is updated and + * the kernel is transitioned to state {@link WorkerStagePhase#PRESHUFFLE_WRITING_OUTPUT}. + * + * @return whether kernel state changed + */ + private boolean handleWaitingForResultPartitionBoundaries(final KernelHolder kernelHolder) + { + if (kernelHolder.kernel.hasResultPartitionBoundaries()) { + kernelHolder.partitionBoundariesFuture.set(kernelHolder.kernel.getResultPartitionBoundaries()); + kernelHolder.kernel.startPreshuffleWritingOutput(); + return true; + } else { + return false; + } + } + + /** + * Handle a kernel in state {@link WorkerStagePhase#RESULTS_COMPLETE}. If {@link ControllerClient#postResultsComplete} + * has not yet been posted to the controller, it is posted at this time. Otherwise nothing happens. + * + * @return whether kernel state changed + */ + private boolean handleResultsReady(final KernelHolder kernelHolder, final ControllerClient controllerClient) + throws IOException + { + final WorkerStageKernel kernel = kernelHolder.kernel; + final boolean didNotPostYet = + kernel.addPostedResultsComplete(kernel.getStageDefinition().getId(), kernel.getWorkOrder().getWorkerNumber()); + + if (controllerAlive && didNotPostYet) { + controllerClient.postResultsComplete( + kernel.getStageDefinition().getId(), + kernel.getWorkOrder().getWorkerNumber(), + kernel.getResultObject() + ); + } + + return didNotPostYet; + } + + /** + * Handle a kernel in state where {@link WorkerStagePhase#isTerminal()} is true. + */ + private void handleTerminated(final KernelHolder kernelHolder) + { + final WorkerStageKernel kernel = kernelHolder.kernel; + removeStageOutputChannels(kernel.getStageDefinition().getId()); + + if (kernelHolder.kernel.getWorkOrder().getOutputChannelMode().isDurable()) { + removeStageDurableStorageOutput(kernel.getStageDefinition().getId()); + } + } + @Override - public void stopGracefully() + public void stop() { // stopGracefully() is called when the containing process is terminated, or when the task is canceled. - log.info("Worker task[%s] canceled.", task.getId()); + log.info("Worker id[%s] canceled.", context.workerId()); doCancel(); } + @Override + public void awaitStop() + { + FutureUtils.getUnchecked(runFuture, false); + } + @Override public void controllerFailed() { - log.info("Controller task[%s] for worker task[%s] failed. Canceling.", task.getControllerTaskId(), task.getId()); + log.info( + "Controller task[%s] for worker[%s] failed. Canceling.", + task != null ? task.getControllerTaskId() : null, + id() + ); doCancel(); } @Override - public InputStream readChannel( - final String queryId, - final int stageNumber, + public ListenableFuture readChannel( + final StageId stageId, final int partitionNumber, final long offset - ) throws IOException + ) { - final StageId stageId = new StageId(queryId, stageNumber); - final StagePartition stagePartition = new StagePartition(stageId, partitionNumber); - final ConcurrentHashMap partitionOutputsForStage = stageOutputs.get(stageId); + return getOrCreateStageOutputHolder(stageId, partitionNumber).readRemotelyFrom(offset); + } - if (partitionOutputsForStage == null) { - return null; - } - final ReadableFrameChannel channel = partitionOutputsForStage.get(partitionNumber); + @Override + public void postWorkOrder(final WorkOrder workOrder) + { + log.info( + "Got work order for stage[%s], workerNumber[%s]", + workOrder.getStageDefinition().getId(), + workOrder.getWorkerNumber() + ); - if (channel == null) { - return null; + if (task != null && task.getWorkerNumber() != workOrder.getWorkerNumber()) { + throw new ISE( + "Worker number mismatch: expected workerNumber[%d], got[%d]", + task.getWorkerNumber(), + workOrder.getWorkerNumber() + ); } - if (channel instanceof ReadableNilFrameChannel) { - // Build an empty frame file. - final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - FrameFileWriter.open(Channels.newChannel(baos), null, ByteTracker.unboundedTracker()).close(); + final OutputChannelMode outputChannelMode; - final ByteArrayInputStream in = new ByteArrayInputStream(baos.toByteArray()); - - //noinspection ResultOfMethodCallIgnored: OK to ignore since "skip" always works for ByteArrayInputStream. - in.skip(offset); - - return in; - } else if (channel instanceof ReadableFileFrameChannel) { - // Close frameFile once we've returned an input stream: no need to retain a reference to the mmap after that, - // since we aren't using it. - try (final FrameFile frameFile = ((ReadableFileFrameChannel) channel).newFrameFileReference()) { - final RandomAccessFile randomAccessFile = new RandomAccessFile(frameFile.file(), "r"); - - if (offset >= randomAccessFile.length()) { - randomAccessFile.close(); - return new ByteArrayInputStream(ByteArrays.EMPTY_ARRAY); - } else { - randomAccessFile.seek(offset); - return Channels.newInputStream(randomAccessFile.getChannel()); - } - } + // This stack of conditions can be removed once we can rely on OutputChannelMode always being in the WorkOrder. + // (It will be there for newer controllers; this is a backwards-compatibility thing.) + if (workOrder.hasOutputChannelMode()) { + outputChannelMode = workOrder.getOutputChannelMode(); } else { - String errorMsg = StringUtils.format( - "Returned server error to client because channel for [%s] is not nil or file-based (class = %s)", - stagePartition, - channel.getClass().getName() + final MSQSelectDestination selectDestination = + task != null + ? MultiStageQueryContext.getSelectDestination(QueryContext.of(task.getContext())) + : MSQSelectDestination.TASKREPORT; + + outputChannelMode = ControllerQueryKernelUtils.getOutputChannelMode( + workOrder.getQueryDefinition(), + workOrder.getStageNumber(), + selectDestination, + task != null && MultiStageQueryContext.isDurableStorageEnabled(QueryContext.of(task.getContext())), + false ); - log.error(StringUtils.encodeForFormat(errorMsg)); - - throw new IOException(errorMsg); - } - } - - @Override - public void postWorkOrder(final WorkOrder workOrder) - { - log.info("Got work order for stage [%d]", workOrder.getStageNumber()); - if (task.getWorkerNumber() != workOrder.getWorkerNumber()) { - throw new ISE("Worker number mismatch: expected [%d]", task.getWorkerNumber()); } - // Do not add to queue if workerOrder already present. + final WorkOrder workOrderToUse = workOrder.withOutputChannelMode(outputChannelMode); kernelManipulationQueue.add( - kernelHolder -> - kernelHolder.getStageKernelMap().putIfAbsent( - workOrder.getStageDefinition().getId(), - WorkerStageKernel.create(workOrder) - ) + kernelHolders -> + kernelHolders.addKernel(WorkerStageKernel.create(workOrderToUse)) ); } @Override public boolean postResultPartitionBoundaries( - final ClusterByPartitions stagePartitionBoundaries, - final String queryId, - final int stageNumber + final StageId stageId, + final ClusterByPartitions stagePartitionBoundaries ) { - final StageId stageId = new StageId(queryId, stageNumber); - kernelManipulationQueue.add( - kernelHolder -> { - final WorkerStageKernel stageKernel = kernelHolder.getStageKernelMap().get(stageId); + kernelHolders -> { + final WorkerStageKernel stageKernel = kernelHolders.getKernelFor(stageId); if (stageKernel != null) { if (!stageKernel.hasResultPartitionBoundaries()) { stageKernel.setResultPartitionBoundaries(stagePartitionBoundaries); } else { // Ignore if partition boundaries are already set. - log.warn( - "Stage[%s] already has result partition boundaries set. Ignoring the latest partition boundaries recieved.", - stageId - ); + log.warn("Stage[%s] already has result partition boundaries set. Ignoring new ones.", stageId); } - } else { - // Ignore the update if we don't have a kernel for this stage. - log.warn("Ignored result partition boundaries call for unknown stage [%s]", stageId); } } ); @@ -623,167 +627,230 @@ public boolean postResultPartitionBoundaries( @Override public void postCleanupStage(final StageId stageId) { - log.info("Cleanup order for stage [%s] received", stageId); - kernelManipulationQueue.add( - holder -> { - cleanStageOutput(stageId, true); - // Mark the stage as FINISHED - WorkerStageKernel stageKernel = holder.getStageKernelMap().get(stageId); - if (stageKernel == null) { - log.warn("Stage id [%s] non existent. Unable to mark the stage kernel for it as FINISHED", stageId); - } else { - stageKernel.setStageFinished(); - } - } - ); + log.debug("Received cleanup order for stage[%s].", stageId); + kernelManipulationQueue.add(holder -> { + holder.finishProcessing(stageId); + final WorkerStageKernel kernel = holder.getKernelFor(stageId); + if (kernel != null) { + kernel.setStageFinished(); + } + }); } @Override public void postFinish() { - log.info("Finish received for task [%s]", task.getId()); - kernelManipulationQueue.add(KernelHolder::setDone); + log.debug("Received finish call."); + kernelManipulationQueue.add(KernelHolders::setDone); } @Override public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId) { - log.info("Fetching statistics for stage [%d]", stageId.getStageNumber()); - if (stageKernelMap.get(stageId) == null) { - throw new ISE("Requested statistics snapshot for non-existent stageId %s.", stageId); - } else if (stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot() == null) { - throw new ISE( - "Requested statistics snapshot is not generated yet for stageId [%s]", - stageId - ); - } else { - return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot(); - } + log.debug("Fetching statistics for stage[%s]", stageId); + final SettableFuture snapshotFuture = SettableFuture.create(); + kernelManipulationQueue.add( + holder -> { + try { + final WorkerStageKernel kernel = holder.getKernelFor(stageId); + if (kernel != null) { + final ClusterByStatisticsSnapshot snapshot = kernel.getResultKeyStatisticsSnapshot(); + if (snapshot == null) { + throw new ISE("Requested statistics snapshot is not generated yet for stage [%s]", stageId); + } + + snapshotFuture.set(snapshot); + } else { + snapshotFuture.setException(new ISE("Stage[%s] has terminated", stageId)); + } + } + catch (Throwable t) { + snapshotFuture.setException(t); + } + } + ); + return FutureUtils.getUnchecked(snapshotFuture, true); } @Override public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk) { - log.debug( - "Fetching statistics for stage [%d] with time chunk [%d] ", - stageId.getStageNumber(), - timeChunk - ); - if (stageKernelMap.get(stageId) == null) { - throw new ISE("Requested statistics snapshot for non-existent stageId [%s].", stageId); - } else if (stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot() == null) { - throw new ISE( - "Requested statistics snapshot is not generated yet for stageId [%s]", - stageId - ); - } else { - return stageKernelMap.get(stageId) - .getResultKeyStatisticsSnapshot() - .getSnapshotForTimeChunk(timeChunk); - } - + return fetchStatisticsSnapshot(stageId).getSnapshotForTimeChunk(timeChunk); } - @Override public CounterSnapshotsTree getCounters() { final CounterSnapshotsTree retVal = new CounterSnapshotsTree(); - for (final Map.Entry entry : stageCounters.entrySet()) { - retVal.put(entry.getKey().getStageNumber(), task().getWorkerNumber(), entry.getValue().snapshot()); + for (final Map.Entry, CounterTracker> entry : stageCounters.entrySet()) { + retVal.put( + entry.getKey().right().getStageNumber(), + entry.getKey().leftInt(), + entry.getValue().snapshot() + ); } return retVal; } - private InputChannelFactory makeBaseInputChannelFactory(final Closer closer) + /** + * Create a {@link RunWorkOrderListener} for {@link RunWorkOrder} that hooks back into the {@link KernelHolders} + * in the main loop. + */ + private RunWorkOrderListener makeRunWorkOrderListener( + final WorkOrder workOrder, + final ControllerClient controllerClient, + final Set criticalWarningCodes, + final long maxVerboseParseExceptions + ) { - final Supplier> workerTaskList = Suppliers.memoize( - () -> { - try { - return controllerClient.getTaskList(); - } - catch (IOException e) { - throw new RuntimeException(e); + final StageId stageId = workOrder.getStageDefinition().getId(); + final MSQWarningReportPublisher msqWarningReportPublisher = new MSQWarningReportLimiterPublisher( + new MSQWarningReportSimplePublisher( + id(), + controllerClient, + id(), + MSQTasks.getHostFromSelfNode(selfDruidNode) + ), + Limits.MAX_VERBOSE_WARNINGS, + ImmutableMap.of(CannotParseExternalDataFault.CODE, maxVerboseParseExceptions), + criticalWarningCodes, + controllerClient, + id(), + MSQTasks.getHostFromSelfNode(selfDruidNode) + ); + + return new RunWorkOrderListener() + { + @Override + public void onDoneReadingInput(@Nullable ClusterByStatisticsSnapshot snapshot) + { + kernelManipulationQueue.add( + holder -> { + final WorkerStageKernel kernel = holder.getKernelFor(stageId); + if (kernel != null) { + kernel.setResultKeyStatisticsSnapshot(snapshot); + } + } + ); + } + + @Override + public void onOutputChannelAvailable(OutputChannel channel) + { + ReadableFrameChannel readableChannel = null; + + try { + readableChannel = channel.getReadableChannel(); + getOrCreateStageOutputHolder(stageId, channel.getPartitionNumber()) + .setChannel(readableChannel); + } + catch (Exception e) { + if (readableChannel != null) { + try { + readableChannel.close(); + } + catch (Throwable e2) { + e.addSuppressed(e2); + } } + + kernelManipulationQueue.add(holder -> { + throw new RE(e, "Worker completion callback error for stage [%s]", stageId); + }); } - )::get; + } - if (durableStageStorageEnabled) { - return DurableStorageInputChannelFactory.createStandardImplementation( - task.getControllerTaskId(), - MSQTasks.makeStorageConnector(context.injector()), - closer, - false - ); - } else { - return new WorkerOrLocalInputChannelFactory(workerTaskList); - } - } + @Override + public void onSuccess(Object resultObject) + { + kernelManipulationQueue.add( + holder -> { + // Call finishProcessing prior to transitioning to RESULTS_COMPLETE, so the FrameContext is closed + // and resources are released. + holder.finishProcessing(stageId); + + final WorkerStageKernel kernel = holder.getKernelFor(stageId); + if (kernel != null) { + kernel.setResultsComplete(resultObject); + } + } + ); + } - private OutputChannelFactory makeStageOutputChannelFactory( - final FrameContext frameContext, - final int stageNumber, - boolean isFinalStage - ) - { - // Use the standard frame size, since we assume this size when computing how much is needed to merge output - // files from different workers. - final int frameSize = frameContext.memoryParameters().getStandardFrameSize(); - - if (durableStageStorageEnabled || (isFinalStage - && MSQSelectDestination.DURABLESTORAGE.equals(selectDestination))) { - return DurableStorageOutputChannelFactory.createStandardImplementation( - task.getControllerTaskId(), - task().getWorkerNumber(), - stageNumber, - task().getId(), - frameSize, - MSQTasks.makeStorageConnector(context.injector()), - context.tempDir(), - (isFinalStage && MSQSelectDestination.DURABLESTORAGE.equals(selectDestination)) - ); - } else { - final File fileChannelDirectory = - new File(context.tempDir(), StringUtils.format("output_stage_%06d", stageNumber)); + @Override + public void onWarning(Throwable t) + { + msqWarningReportPublisher.publishException(stageId.getStageNumber(), t); + } - return new FileOutputChannelFactory(fileChannelDirectory, frameSize, null); - } + @Override + public void onFailure(Throwable t) + { + kernelManipulationQueue.add( + holder -> { + final WorkerStageKernel kernel = holder.getKernelFor(stageId); + if (kernel != null) { + kernel.fail(t); + } + } + ); + } + }; } - private OutputChannelFactory makeSuperSorterIntermediateOutputChannelFactory( - final FrameContext frameContext, - final int stageNumber, - final File tmpDir + private InputChannelFactory makeBaseInputChannelFactory( + final WorkOrder workOrder, + final ControllerClient controllerClient, + final Closer closer ) { - final int frameSize = frameContext.memoryParameters().getLargeFrameSize(); - final File fileChannelDirectory = - new File(tmpDir, StringUtils.format("intermediate_output_stage_%06d", stageNumber)); - final FileOutputChannelFactory fileOutputChannelFactory = - new FileOutputChannelFactory(fileChannelDirectory, frameSize, intermediateSuperSorterLocalStorageTracker); - - if (durableStageStorageEnabled && workerStorageParameters.isIntermediateStorageLimitConfigured()) { - return new ComposingOutputChannelFactory( - ImmutableList.of( - fileOutputChannelFactory, - DurableStorageOutputChannelFactory.createStandardImplementation( - task.getControllerTaskId(), - task().getWorkerNumber(), - stageNumber, - task().getId(), - frameSize, + return MetaInputChannelFactory.create( + InputSlices.allStageSlices(workOrder.getInputs()), + workOrder.getOutputChannelMode(), + outputChannelMode -> { + switch (outputChannelMode) { + case MEMORY: + case LOCAL_STORAGE: + final Supplier> workerIds; + + if (workOrder.getWorkerIds() != null) { + workerIds = workOrder::getWorkerIds; + } else { + workerIds = Suppliers.memoize( + () -> { + try { + return controllerClient.getWorkerIds(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + ); + } + + return new WorkerOrLocalInputChannelFactory( + id(), + workerIds, + new WorkerInputChannelFactory(workerClient, workerIds), + this::getOrCreateStageOutputHolder + ); + + case DURABLE_STORAGE_INTERMEDIATE: + case DURABLE_STORAGE_QUERY_RESULTS: + return DurableStorageInputChannelFactory.createStandardImplementation( + task.getControllerTaskId(), MSQTasks.makeStorageConnector(context.injector()), - tmpDir, - false - ) - ), - frameSize - ); - } else { - return fileOutputChannelFactory; - } + closer, + outputChannelMode == OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS + ); + + default: + throw DruidException.defensive("No handling for output channel mode[%s]", outputChannelMode); + } + } + ); } /** @@ -846,69 +913,75 @@ public void run() /** * Posts all counters for this worker to the controller. */ - private void postCountersToController() throws IOException + private void postCountersToController(final ControllerClient controllerClient) throws IOException { final CounterSnapshotsTree snapshotsTree = getCounters(); if (controllerAlive && !snapshotsTree.isEmpty()) { - try { - controllerClient.postCounters(id(), snapshotsTree); - } - catch (IOException e) { - if (e.getCause() instanceof ServiceClosedException) { - // Suppress. This can happen if the controller goes away while a postCounters call is in flight. - log.debug(e, "Ignoring failure on postCounters, because controller has gone away."); - } else { - throw e; - } - } + controllerClient.postCounters(id(), snapshotsTree); } } /** - * Cleans up the stage outputs corresponding to the provided stage id. It essentially calls {@code doneReading()} on - * the readable channels corresponding to all the partitions for that stage, and removes it from the {@code stageOutputs} - * map + * Removes and closes all output channels for a stage from {@link #stageOutputs}. */ - private void cleanStageOutput(final StageId stageId, boolean removeDurableStorageFiles) + private void removeStageOutputChannels(final StageId stageId) { // This code is thread-safe because remove() on ConcurrentHashMap will remove and return the removed channel only for // one thread. For the other threads it will return null, therefore we will call doneReading for a channel only once - final ConcurrentHashMap partitionOutputsForStage = stageOutputs.remove(stageId); + final ConcurrentHashMap partitionOutputsForStage = stageOutputs.remove(stageId); // Check for null, this can be the case if this method is called simultaneously from multiple threads. if (partitionOutputsForStage == null) { return; } for (final int partition : partitionOutputsForStage.keySet()) { - final ReadableFrameChannel output = partitionOutputsForStage.remove(partition); - if (output == null) { - continue; + final StageOutputHolder output = partitionOutputsForStage.remove(partition); + if (output != null) { + output.close(); } - output.close(); } + } + + /** + * Remove outputs from durable storage for a particular stage. + */ + private void removeStageDurableStorageOutput(final StageId stageId) + { // One caveat with this approach is that in case of a worker crash, while the MM/Indexer systems will delete their // temp directories where intermediate results were stored, it won't be the case for the external storage. // Therefore, the logic for cleaning the stage output in case of a worker/machine crash has to be external. // We currently take care of this in the controller. - if (durableStageStorageEnabled && removeDurableStorageFiles) { - final String folderName = DurableStorageUtils.getTaskIdOutputsFolderName( - task.getControllerTaskId(), - stageId.getStageNumber(), - task.getWorkerNumber(), - task.getId() - ); - try { - MSQTasks.makeStorageConnector(context.injector()).deleteRecursively(folderName); - } - catch (Exception e) { - // If an error is thrown while cleaning up a file, log it and try to continue with the cleanup - log.warn(e, "Error while cleaning up folder at path " + folderName); - } + final String folderName = DurableStorageUtils.getTaskIdOutputsFolderName( + task.getControllerTaskId(), + stageId.getStageNumber(), + task.getWorkerNumber(), + context.workerId() + ); + try { + MSQTasks.makeStorageConnector(context.injector()).deleteRecursively(folderName); + } + catch (Exception e) { + // If an error is thrown while cleaning up a file, log it and try to continue with the cleanup + log.warn(e, "Error while cleaning up durable storage path[%s].", folderName); } } + private StageOutputHolder getOrCreateStageOutputHolder(final StageId stageId, final int partitionNumber) + { + return stageOutputs.computeIfAbsent(stageId, ignored1 -> new ConcurrentHashMap<>()) + .computeIfAbsent(partitionNumber, ignored -> new StageOutputHolder()); + } + /** - * Called by {@link #stopGracefully()} (task canceled, or containing process shut down) and + * Returns cancellation ID for a particular stage, to be used in {@link FrameProcessorExecutor#cancel(String)}. + */ + private static String cancellationIdFor(final StageId stageId) + { + return stageId.toString(); + } + + /** + * Called by {@link #stop()} (task canceled, or containing process shut down) and * {@link #controllerFailed()}. */ private void doCancel() @@ -935,15 +1008,15 @@ private void doCancel() /** * Log (at DEBUG level) a string explaining the status of all work assigned to this worker. */ - private static void logKernelStatus(final Collection kernels) + private static void logKernelStatus(final Iterable kernels) { if (log.isDebugEnabled()) { log.debug( "Stages: %s", - kernels.stream() - .sorted(Comparator.comparing(k -> k.getStageDefinition().getStageNumber())) - .map(WorkerImpl::makeKernelStageStatusString) - .collect(Collectors.joining("; ")) + StreamSupport.stream(kernels.spliterator(), false) + .sorted(Comparator.comparing(k -> k.getStageDefinition().getStageNumber())) + .map(WorkerImpl::makeKernelStageStatusString) + .collect(Collectors.joining("; ")) ); } } @@ -978,936 +1051,205 @@ private static String makeKernelStageStatusString(final WorkerStageKernel kernel } /** - * An {@link InputChannelFactory} that loads data locally when possible, and otherwise connects directly to other - * workers. Used when durable shuffle storage is off. + * Holds {@link WorkerStageKernel} and {@link Closer}, one per {@link WorkOrder}. Also holds {@link ControllerClient}. + * Only manipulated by the main loop. Other threads that need to manipulate kernels must do so through + * {@link #kernelManipulationQueue}. */ - private class WorkerOrLocalInputChannelFactory implements InputChannelFactory + private static class KernelHolders implements Closeable { - private final Supplier> taskList; - private final WorkerInputChannelFactory workerInputChannelFactory; + private final WorkerContext workerContext; + private final ControllerClient controllerClient; - public WorkerOrLocalInputChannelFactory(final Supplier> taskList) - { - this.workerInputChannelFactory = new WorkerInputChannelFactory(workerClient, taskList); - this.taskList = taskList; - } - - @Override - public ReadableFrameChannel openChannel(StageId stageId, int workerNumber, int partitionNumber) - { - final String taskId = taskList.get().get(workerNumber); - if (taskId.equals(id())) { - final ConcurrentMap partitionOutputsForStage = stageOutputs.get(stageId); - if (partitionOutputsForStage == null) { - throw new ISE("Unable to find outputs for stage [%s]", stageId); - } - - final ReadableFrameChannel myChannel = partitionOutputsForStage.get(partitionNumber); + /** + * Stage number -> kernel holder. + */ + private final Int2ObjectMap holderMap = new Int2ObjectOpenHashMap<>(); - if (myChannel instanceof ReadableFileFrameChannel) { - // Must duplicate the channel to avoid double-closure upon task cleanup. - final FrameFile frameFile = ((ReadableFileFrameChannel) myChannel).newFrameFileReference(); - return new ReadableFileFrameChannel(frameFile); - } else if (myChannel instanceof ReadableNilFrameChannel) { - return myChannel; - } else { - throw new ISE("Output for stage [%s] are stored in an instance of %s which is not " - + "supported", stageId, myChannel.getClass()); - } - } else { - return workerInputChannelFactory.openChannel(stageId, workerNumber, partitionNumber); - } - } - } + private boolean done = false; - /** - * Main worker logic for executing a {@link WorkOrder}. - */ - private class RunWorkOrder - { - private final WorkerStageKernel kernel; - private final InputChannelFactory inputChannelFactory; - private final CounterTracker counterTracker; - private final FrameProcessorExecutor exec; - private final String cancellationId; - private final int parallelism; - private final FrameContext frameContext; - private final MSQWarningReportPublisher warningPublisher; - - private InputSliceReader inputSliceReader; - private OutputChannelFactory workOutputChannelFactory; - private OutputChannelFactory shuffleOutputChannelFactory; - private ResultAndChannels workResultAndOutputChannels; - private SettableFuture stagePartitionBoundariesFuture; - private ListenableFuture shuffleOutputChannelsFuture; - - public RunWorkOrder( - final WorkerStageKernel kernel, - final InputChannelFactory inputChannelFactory, - final CounterTracker counterTracker, - final FrameProcessorExecutor exec, - final String cancellationId, - final int parallelism, - final FrameContext frameContext, - final MSQWarningReportPublisher warningPublisher - ) + private KernelHolders(final WorkerContext workerContext, final ControllerClient controllerClient) { - this.kernel = kernel; - this.inputChannelFactory = inputChannelFactory; - this.counterTracker = counterTracker; - this.exec = exec; - this.cancellationId = cancellationId; - this.parallelism = parallelism; - this.frameContext = frameContext; - this.warningPublisher = warningPublisher; + this.workerContext = workerContext; + this.controllerClient = controllerClient; } - private void start() throws IOException + public static KernelHolders create(final WorkerContext workerContext, final Closer closer) { - final WorkOrder workOrder = kernel.getWorkOrder(); - final StageDefinition stageDef = workOrder.getStageDefinition(); - - final boolean isFinalStage = stageDef.getStageNumber() == workOrder.getQueryDefinition() - .getFinalStageDefinition() - .getStageNumber(); - - makeInputSliceReader(); - makeWorkOutputChannelFactory(isFinalStage); - makeShuffleOutputChannelFactory(isFinalStage); - makeAndRunWorkProcessors(); - - if (stageDef.doesShuffle()) { - makeAndRunShuffleProcessors(); - } else { - // No shuffling: work output _is_ shuffle output. Retain read-only versions to reduce memory footprint. - shuffleOutputChannelsFuture = - Futures.immediateFuture(workResultAndOutputChannels.getOutputChannels().readOnly()); - } - - setUpCompletionCallbacks(isFinalStage); + return closer.register(new KernelHolders(workerContext, closer.register(workerContext.makeControllerClient()))); } /** - * Settable {@link ClusterByPartitions} future for global sort. Necessary because we don't know ahead of time - * what the boundaries will be. The controller decides based on statistics from all workers. Once the controller - * decides, its decision is written to this future, which allows sorting on workers to proceed. + * Add a {@link WorkerStageKernel} to this holder. Also creates a {@link ControllerClient} for the query ID + * if one does not yet exist. Does nothing if a kernel with the same {@link StageId} is already being tracked. */ - @Nullable - public SettableFuture getStagePartitionBoundariesFuture() - { - return stagePartitionBoundariesFuture; - } - - private void makeInputSliceReader() + public void addKernel(final WorkerStageKernel kernel) { - if (inputSliceReader != null) { - throw new ISE("inputSliceReader already created"); - } - - final WorkOrder workOrder = kernel.getWorkOrder(); - final String queryId = workOrder.getQueryDefinition().getQueryId(); - - final InputChannels inputChannels = - new InputChannelsImpl( - workOrder.getQueryDefinition(), - InputSlices.allReadablePartitions(workOrder.getInputs()), - inputChannelFactory, - () -> ArenaMemoryAllocator.createOnHeap(frameContext.memoryParameters().getStandardFrameSize()), - exec, - cancellationId, - MultiStageQueryContext.removeNullBytes(QueryContext.of(task.getContext())) - ); - - inputSliceReader = new MapInputSliceReader( - ImmutableMap., InputSliceReader>builder() - .put(NilInputSlice.class, NilInputSliceReader.INSTANCE) - .put(StageInputSlice.class, new StageInputSliceReader(queryId, inputChannels)) - .put(ExternalInputSlice.class, new ExternalInputSliceReader(frameContext.tempDir())) - .put(InlineInputSlice.class, new InlineInputSliceReader(frameContext.segmentWrangler())) - .put(LookupInputSlice.class, new LookupInputSliceReader(frameContext.segmentWrangler())) - .put( - SegmentsInputSlice.class, - new SegmentsInputSliceReader( - frameContext, - MultiStageQueryContext.isReindex(QueryContext.of(task().getContext())) - ) - ) - .build() - ); - } - - private void makeWorkOutputChannelFactory(boolean isFinalStage) - { - if (workOutputChannelFactory != null) { - throw new ISE("processorOutputChannelFactory already created"); - } - - final OutputChannelFactory baseOutputChannelFactory; - - if (kernel.getStageDefinition().doesShuffle()) { - // Writing to a consumer in the same JVM (which will be set up later on in this method). Use the large frame - // size if we're writing to a SuperSorter, since we'll generate fewer temp files if we use larger frames. - // Otherwise, use the standard frame size. - final int frameSize; - - if (kernel.getStageDefinition().getShuffleSpec().kind().isSort()) { - frameSize = frameContext.memoryParameters().getLargeFrameSize(); - } else { - frameSize = frameContext.memoryParameters().getStandardFrameSize(); - } + final StageId stageId = verifyQueryId(kernel.getWorkOrder().getStageDefinition().getId()); - baseOutputChannelFactory = new BlockingQueueOutputChannelFactory(frameSize); - } else { - // Writing stage output. - baseOutputChannelFactory = - makeStageOutputChannelFactory(frameContext, kernel.getStageDefinition().getStageNumber(), isFinalStage); + if (holderMap.putIfAbsent(stageId.getStageNumber(), new KernelHolder(kernel)) != null) { + // Already added. Do nothing. } - - workOutputChannelFactory = new CountingOutputChannelFactory( - baseOutputChannelFactory, - counterTracker.channel(CounterNames.outputChannel()) - ); - } - - private void makeShuffleOutputChannelFactory(boolean isFinalStage) - { - shuffleOutputChannelFactory = - new CountingOutputChannelFactory( - makeStageOutputChannelFactory(frameContext, kernel.getStageDefinition().getStageNumber(), isFinalStage), - counterTracker.channel(CounterNames.shuffleChannel()) - ); } /** - * Use {@link FrameProcessorFactory#makeProcessors} to create {@link ProcessorsAndChannels}. Executes the - * processors using {@link #exec} and sets the output channels in {@link #workResultAndOutputChannels}. + * Called when processing for a stage is complete. Releases processing resources associated with the stage, i.e., + * those that are part of {@link KernelHolder#processorCloser}. * - * @param type of {@link StageDefinition#getProcessorFactory()} - * @param return type of {@link FrameProcessor} created by the manager - * @param result type of {@link ProcessorManager#result()} - * @param type of {@link WorkOrder#getExtraInfo()} + * Does not release results-fetching resources, i.e., does not release {@link KernelHolder#resultsCloser}. Those + * resources are released on {@link #removeKernel(StageId)} only. */ - private , ProcessorReturnType, ManagerReturnType, ExtraInfoType> void makeAndRunWorkProcessors() - throws IOException - { - if (workResultAndOutputChannels != null) { - throw new ISE("workResultAndOutputChannels already set"); - } - - @SuppressWarnings("unchecked") - final FactoryType processorFactory = (FactoryType) kernel.getStageDefinition().getProcessorFactory(); - - @SuppressWarnings("unchecked") - final ProcessorsAndChannels processors = - processorFactory.makeProcessors( - kernel.getStageDefinition(), - kernel.getWorkOrder().getWorkerNumber(), - kernel.getWorkOrder().getInputs(), - inputSliceReader, - (ExtraInfoType) kernel.getWorkOrder().getExtraInfo(), - workOutputChannelFactory, - frameContext, - parallelism, - counterTracker, - e -> warningPublisher.publishException(kernel.getStageDefinition().getStageNumber(), e), - isRemoveNullBytes - ); - - final ProcessorManager processorManager = processors.getProcessorManager(); - - final int maxOutstandingProcessors; - - if (processors.getOutputChannels().getAllChannels().isEmpty()) { - // No output channels: run up to "parallelism" processors at once. - maxOutstandingProcessors = Math.max(1, parallelism); - } else { - // If there are output channels, that acts as a ceiling on the number of processors that can run at once. - maxOutstandingProcessors = - Math.max(1, Math.min(parallelism, processors.getOutputChannels().getAllChannels().size())); - } - - final ListenableFuture workResultFuture = exec.runAllFully( - processorManager, - maxOutstandingProcessors, - processorBouncer, - cancellationId - ); - - workResultAndOutputChannels = new ResultAndChannels<>(workResultFuture, processors.getOutputChannels()); - } - - private void makeAndRunShuffleProcessors() + public void finishProcessing(final StageId stageId) { - if (shuffleOutputChannelsFuture != null) { - throw new ISE("shuffleOutputChannelsFuture already set"); - } - - final ShuffleSpec shuffleSpec = kernel.getWorkOrder().getStageDefinition().getShuffleSpec(); - - final ShufflePipelineBuilder shufflePipeline = new ShufflePipelineBuilder( - kernel, - counterTracker, - exec, - cancellationId, - frameContext - ); - - shufflePipeline.initialize(workResultAndOutputChannels); - - switch (shuffleSpec.kind()) { - case MIX: - shufflePipeline.mix(shuffleOutputChannelFactory); - break; - - case HASH: - shufflePipeline.hashPartition(shuffleOutputChannelFactory); - break; - - case HASH_LOCAL_SORT: - final OutputChannelFactory hashOutputChannelFactory; - - if (shuffleSpec.partitionCount() == 1) { - // Single partition; no need to write temporary files. - hashOutputChannelFactory = - new BlockingQueueOutputChannelFactory(frameContext.memoryParameters().getStandardFrameSize()); - } else { - // Multi-partition; write temporary files and then sort each one file-by-file. - hashOutputChannelFactory = - new FileOutputChannelFactory( - context.tempDir(kernel.getStageDefinition().getStageNumber(), "hash-parts"), - frameContext.memoryParameters().getStandardFrameSize(), - null - ); - } - - shufflePipeline.hashPartition(hashOutputChannelFactory); - shufflePipeline.localSort(shuffleOutputChannelFactory); - break; - - case GLOBAL_SORT: - shufflePipeline.gatherResultKeyStatisticsIfNeeded(); - shufflePipeline.globalSort(shuffleOutputChannelFactory, makeGlobalSortPartitionBoundariesFuture()); - break; - - default: - throw new UOE("Cannot handle shuffle kind [%s]", shuffleSpec.kind()); - } + final KernelHolder kernel = holderMap.get(verifyQueryId(stageId).getStageNumber()); - shuffleOutputChannelsFuture = shufflePipeline.build(); - } - - private ListenableFuture makeGlobalSortPartitionBoundariesFuture() - { - if (kernel.getStageDefinition().mustGatherResultKeyStatistics()) { - if (stagePartitionBoundariesFuture != null) { - throw new ISE("Cannot call 'makeGlobalSortPartitionBoundariesFuture' twice"); + if (kernel != null) { + try { + kernel.processorCloser.close(); + } + catch (IOException e) { + throw new RuntimeException(e); } - - return (stagePartitionBoundariesFuture = SettableFuture.create()); - } else { - return Futures.immediateFuture(kernel.getResultPartitionBoundaries()); } } - private void setUpCompletionCallbacks(boolean isFinalStage) - { - final StageDefinition stageDef = kernel.getStageDefinition(); - - Futures.addCallback( - Futures.allAsList( - Arrays.asList( - workResultAndOutputChannels.getResultFuture(), - shuffleOutputChannelsFuture - ) - ), - new FutureCallback>() - { - @Override - public void onSuccess(final List workerResultAndOutputChannelsResolved) - { - final Object resultObject = workerResultAndOutputChannelsResolved.get(0); - final OutputChannels outputChannels = (OutputChannels) workerResultAndOutputChannelsResolved.get(1); - - for (OutputChannel channel : outputChannels.getAllChannels()) { - try { - stageOutputs.computeIfAbsent(stageDef.getId(), ignored1 -> new ConcurrentHashMap<>()) - .computeIfAbsent(channel.getPartitionNumber(), ignored2 -> channel.getReadableChannel()); - } - catch (Exception e) { - kernelManipulationQueue.add(holder -> { - throw new RE(e, "Worker completion callback error for stage [%s]", stageDef.getId()); - }); - - // Don't make the "setResultsComplete" call below. - return; - } - } - - // Once the outputs channels have been resolved and are ready for reading, write success file, if - // using durable storage. - writeDurableStorageSuccessFileIfNeeded(stageDef.getStageNumber(), isFinalStage); - - kernelManipulationQueue.add(holder -> holder.getStageKernelMap() - .get(stageDef.getId()) - .setResultsComplete(resultObject)); - } - - @Override - public void onFailure(final Throwable t) - { - kernelManipulationQueue.add( - kernelHolder -> - kernelHolder.getStageKernelMap().get(stageDef.getId()).fail(t) - ); - } - }, - MoreExecutors.directExecutor() - ); - } - /** - * Write {@link DurableStorageUtils#SUCCESS_MARKER_FILENAME} for a particular stage, if durable storage is enabled. + * Remove the {@link WorkerStageKernel} for a given {@link StageId} from this holder. Closes all the associated + * {@link Closeable}. Removes and closes the {@link ControllerClient} for this query ID, if there are no longer + * any active work orders for that query ID + * + * @throws IllegalStateException if there is no active kernel for this stage */ - private void writeDurableStorageSuccessFileIfNeeded(final int stageNumber, boolean isFinalStage) + public void removeKernel(final StageId stageId) { - final DurableStorageOutputChannelFactory durableStorageOutputChannelFactory; - if (durableStageStorageEnabled || (isFinalStage - && MSQSelectDestination.DURABLESTORAGE.equals(selectDestination))) { - durableStorageOutputChannelFactory = DurableStorageOutputChannelFactory.createStandardImplementation( - task.getControllerTaskId(), - task().getWorkerNumber(), - stageNumber, - task().getId(), - frameContext.memoryParameters().getStandardFrameSize(), - MSQTasks.makeStorageConnector(context.injector()), - context.tempDir(), - (isFinalStage && MSQSelectDestination.DURABLESTORAGE.equals(selectDestination)) - ); - } else { - return; + final KernelHolder removed = holderMap.remove(verifyQueryId(stageId).getStageNumber()); + + if (removed == null) { + throw new ISE("No kernel for stage[%s]", stageId); } + try { - durableStorageOutputChannelFactory.createSuccessFile(task.getId()); + removed.processorCloser.close(); + removed.resultsCloser.close(); } catch (IOException e) { - throw new ISE( - e, - "Unable to create the success file [%s] at the location [%s]", - DurableStorageUtils.SUCCESS_MARKER_FILENAME, - durableStorageOutputChannelFactory.getSuccessFilePath() - ); + throw new RuntimeException(e); } } - } - - /** - * Helper for {@link RunWorkOrder#makeAndRunShuffleProcessors()}. Builds a {@link FrameProcessor} pipeline to - * handle the shuffle. - */ - private class ShufflePipelineBuilder - { - private final WorkerStageKernel kernel; - private final CounterTracker counterTracker; - private final FrameProcessorExecutor exec; - private final String cancellationId; - private final FrameContext frameContext; - - // Current state of the pipeline. It's a future to allow pipeline construction to be deferred if necessary. - private ListenableFuture> pipelineFuture; - - public ShufflePipelineBuilder( - final WorkerStageKernel kernel, - final CounterTracker counterTracker, - final FrameProcessorExecutor exec, - final String cancellationId, - final FrameContext frameContext - ) - { - this.kernel = kernel; - this.counterTracker = counterTracker; - this.exec = exec; - this.cancellationId = cancellationId; - this.frameContext = frameContext; - } /** - * Start the pipeline with the outputs of the main processor. + * Returns all currently-active kernel holders. */ - public void initialize(final ResultAndChannels resultAndChannels) + public Iterable getAllKernelHolders() { - if (pipelineFuture != null) { - throw new ISE("already initialized"); - } - - pipelineFuture = Futures.immediateFuture(resultAndChannels); + return holderMap.values(); } /** - * Add {@link FrameChannelMixer}, which mixes all current outputs into a single channel from the provided factory. + * Returns all currently-active kernels. */ - public void mix(final OutputChannelFactory outputChannelFactory) + public Iterable getAllKernels() { - // No sorting or statistics gathering, just combining all outputs into one big partition. Use a mixer to get - // everything into one file. Note: even if there is only one output channel, we'll run it through the mixer - // anyway, to ensure the data gets written to a file. (httpGetChannelData requires files.) - - push( - resultAndChannels -> { - final OutputChannel outputChannel = outputChannelFactory.openChannel(0); - - final FrameChannelMixer mixer = - new FrameChannelMixer( - resultAndChannels.getOutputChannels().getAllReadableChannels(), - outputChannel.getWritableChannel() - ); - - return new ResultAndChannels<>( - exec.runFully(mixer, cancellationId), - OutputChannels.wrap(Collections.singletonList(outputChannel.readOnly())) - ); - } - ); + return Iterables.transform(holderMap.values(), holder -> holder.kernel); } /** - * Add {@link KeyStatisticsCollectionProcessor} if {@link StageDefinition#mustGatherResultKeyStatistics()}. + * Returns the number of kernels that are in running states, where {@link WorkerStagePhase#isRunning()}. */ - public void gatherResultKeyStatisticsIfNeeded() + public int runningKernelCount() { - push( - resultAndChannels -> { - final StageDefinition stageDefinition = kernel.getStageDefinition(); - final OutputChannels channels = resultAndChannels.getOutputChannels(); - - if (channels.getAllChannels().isEmpty()) { - // No data coming out of this processor. Report empty statistics, if the kernel is expecting statistics. - if (stageDefinition.mustGatherResultKeyStatistics()) { - kernelManipulationQueue.add( - holder -> - holder.getStageKernelMap().get(stageDefinition.getId()) - .setResultKeyStatisticsSnapshot(ClusterByStatisticsSnapshot.empty()) - ); - } - - // Generate one empty channel so the SuperSorter has something to do. - final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal(); - channel.writable().close(); - - final OutputChannel outputChannel = OutputChannel.readOnly( - channel.readable(), - FrameWithPartition.NO_PARTITION - ); + int retVal = 0; + for (final KernelHolder holder : holderMap.values()) { + if (holder.kernel.getPhase().isRunning()) { + retVal++; + } + } - return new ResultAndChannels<>( - Futures.immediateFuture(null), - OutputChannels.wrap(Collections.singletonList(outputChannel)) - ); - } else if (stageDefinition.mustGatherResultKeyStatistics()) { - return gatherResultKeyStatistics(channels); - } else { - return resultAndChannels; - } - } - ); + return retVal; } /** - * Add a {@link SuperSorter} using {@link StageDefinition#getSortKey()} and partition boundaries - * from {@code partitionBoundariesFuture}. + * Return the kernel for a particular {@link StageId}. + * + * @return kernel, or null if there is no active kernel for this stage */ - public void globalSort( - final OutputChannelFactory outputChannelFactory, - final ListenableFuture partitionBoundariesFuture - ) + @Nullable + public WorkerStageKernel getKernelFor(final StageId stageId) { - pushAsync( - resultAndChannels -> { - final StageDefinition stageDefinition = kernel.getStageDefinition(); - - final File sorterTmpDir = context.tempDir(stageDefinition.getStageNumber(), "super-sort"); - FileUtils.mkdirp(sorterTmpDir); - if (!sorterTmpDir.isDirectory()) { - throw new IOException("Cannot create directory: " + sorterTmpDir); - } - - final WorkerMemoryParameters memoryParameters = frameContext.memoryParameters(); - final SuperSorter sorter = new SuperSorter( - resultAndChannels.getOutputChannels().getAllReadableChannels(), - stageDefinition.getFrameReader(), - stageDefinition.getSortKey(), - partitionBoundariesFuture, - exec, - outputChannelFactory, - makeSuperSorterIntermediateOutputChannelFactory( - frameContext, - stageDefinition.getStageNumber(), - sorterTmpDir - ), - memoryParameters.getSuperSorterMaxActiveProcessors(), - memoryParameters.getSuperSorterMaxChannelsPerProcessor(), - -1, - cancellationId, - counterTracker.sortProgress(), - isRemoveNullBytes - ); - - return FutureUtils.transform( - sorter.run(), - sortedChannels -> new ResultAndChannels<>(Futures.immediateFuture(null), sortedChannels) - ); - } - ); + final KernelHolder holder = holderMap.get(verifyQueryId(stageId).getStageNumber()); + if (holder != null) { + return holder.kernel; + } else { + return null; + } } /** - * Add a {@link FrameChannelHashPartitioner} using {@link StageDefinition#getSortKey()}. + * Retrieves the {@link ControllerClient}, which is shared across all {@link WorkOrder} for this worker. */ - public void hashPartition(final OutputChannelFactory outputChannelFactory) + public ControllerClient getControllerClient() { - pushAsync( - resultAndChannels -> { - final ShuffleSpec shuffleSpec = kernel.getStageDefinition().getShuffleSpec(); - final int partitions = shuffleSpec.partitionCount(); - - final List outputChannels = new ArrayList<>(); - - for (int i = 0; i < partitions; i++) { - outputChannels.add(outputChannelFactory.openChannel(i)); - } - - final FrameChannelHashPartitioner partitioner = new FrameChannelHashPartitioner( - resultAndChannels.getOutputChannels().getAllReadableChannels(), - outputChannels.stream().map(OutputChannel::getWritableChannel).collect(Collectors.toList()), - kernel.getStageDefinition().getFrameReader(), - kernel.getStageDefinition().getClusterBy().getColumns().size(), - FrameWriters.makeRowBasedFrameWriterFactory( - new ArenaMemoryAllocatorFactory(frameContext.memoryParameters().getStandardFrameSize()), - kernel.getStageDefinition().getSignature(), - kernel.getStageDefinition().getSortKey(), - isRemoveNullBytes - ) - ); - - final ListenableFuture partitionerFuture = exec.runFully(partitioner, cancellationId); - - final ResultAndChannels retVal = - new ResultAndChannels<>(partitionerFuture, OutputChannels.wrap(outputChannels)); - - if (retVal.getOutputChannels().areReadableChannelsReady()) { - return Futures.immediateFuture(retVal); - } else { - return FutureUtils.transform(partitionerFuture, ignored -> retVal); - } - } - ); + return controllerClient; } /** - * Add a sequence of {@link SuperSorter}, operating on each current output channel in order, one at a time. + * Remove all {@link WorkerStageKernel} and close all {@link ControllerClient}. */ - public void localSort(final OutputChannelFactory outputChannelFactory) - { - pushAsync( - resultAndChannels -> { - final StageDefinition stageDefinition = kernel.getStageDefinition(); - final OutputChannels channels = resultAndChannels.getOutputChannels(); - final List> sortedChannelFutures = new ArrayList<>(); - - ListenableFuture nextFuture = Futures.immediateFuture(null); - - for (final OutputChannel channel : channels.getAllChannels()) { - final File sorterTmpDir = context.tempDir( - stageDefinition.getStageNumber(), - StringUtils.format("hash-parts-super-sort-%06d", channel.getPartitionNumber()) - ); - - FileUtils.mkdirp(sorterTmpDir); - - // SuperSorter will try to write to output partition zero; we remap it to the correct partition number. - final OutputChannelFactory partitionOverrideOutputChannelFactory = new OutputChannelFactory() - { - @Override - public OutputChannel openChannel(int expectedZero) throws IOException - { - if (expectedZero != 0) { - throw new ISE("Unexpected part [%s]", expectedZero); - } - - return outputChannelFactory.openChannel(channel.getPartitionNumber()); - } - - @Override - public PartitionedOutputChannel openPartitionedChannel(String name, boolean deleteAfterRead) - { - throw new UnsupportedOperationException(); - } - - @Override - public OutputChannel openNilChannel(int expectedZero) - { - if (expectedZero != 0) { - throw new ISE("Unexpected part [%s]", expectedZero); - } - - return outputChannelFactory.openNilChannel(channel.getPartitionNumber()); - } - }; - - // Chain futures so we only sort one partition at a time. - nextFuture = Futures.transformAsync( - nextFuture, - (AsyncFunction) ignored -> { - final SuperSorter sorter = new SuperSorter( - Collections.singletonList(channel.getReadableChannel()), - stageDefinition.getFrameReader(), - stageDefinition.getSortKey(), - Futures.immediateFuture(ClusterByPartitions.oneUniversalPartition()), - exec, - partitionOverrideOutputChannelFactory, - makeSuperSorterIntermediateOutputChannelFactory( - frameContext, - stageDefinition.getStageNumber(), - sorterTmpDir - ), - 1, - 2, - -1, - cancellationId, - - // Tracker is not actually tracked, since it doesn't quite fit into the way we report counters. - // There's a single SuperSorterProgressTrackerCounter per worker, but workers that do local - // sorting have a SuperSorter per partition. - new SuperSorterProgressTracker(), - isRemoveNullBytes - ); - - return FutureUtils.transform(sorter.run(), r -> Iterables.getOnlyElement(r.getAllChannels())); - }, - MoreExecutors.directExecutor() - ); - - sortedChannelFutures.add(nextFuture); - } - - return FutureUtils.transform( - Futures.allAsList(sortedChannelFutures), - sortedChannels -> new ResultAndChannels<>( - Futures.immediateFuture(null), - OutputChannels.wrap(sortedChannels) - ) - ); - } - ); - } - - /** - * Return the (future) output channels for this pipeline. - */ - public ListenableFuture build() + @Override + public void close() { - if (pipelineFuture == null) { - throw new ISE("Not initialized"); - } + for (final int stageNumber : ImmutableList.copyOf(holderMap.keySet())) { + final StageId stageId = new StageId(workerContext.queryId(), stageNumber); - return Futures.transformAsync( - pipelineFuture, - (AsyncFunction, OutputChannels>) resultAndChannels -> - Futures.transform( - resultAndChannels.getResultFuture(), - (Function) input -> { - sanityCheckOutputChannels(resultAndChannels.getOutputChannels()); - return resultAndChannels.getOutputChannels(); - }, - MoreExecutors.directExecutor() - ), - MoreExecutors.directExecutor() - ); - } - - /** - * Adds {@link KeyStatisticsCollectionProcessor}. Called by {@link #gatherResultKeyStatisticsIfNeeded()}. - */ - private ResultAndChannels gatherResultKeyStatistics(final OutputChannels channels) - { - final StageDefinition stageDefinition = kernel.getStageDefinition(); - final List retVal = new ArrayList<>(); - final List processors = new ArrayList<>(); - - for (final OutputChannel outputChannel : channels.getAllChannels()) { - final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal(); - retVal.add(OutputChannel.readOnly(channel.readable(), outputChannel.getPartitionNumber())); - - processors.add( - new KeyStatisticsCollectionProcessor( - outputChannel.getReadableChannel(), - channel.writable(), - stageDefinition.getFrameReader(), - stageDefinition.getClusterBy(), - stageDefinition.createResultKeyStatisticsCollector( - frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes() - ) - ) - ); + try { + removeKernel(stageId); + } + catch (Exception e) { + log.warn(e, "Failed to remove kernel for stage[%s].", stageId); + } } - - final ListenableFuture clusterByStatisticsCollectorFuture = - exec.runAllFully( - ProcessorManagers.of(processors) - .withAccumulation( - stageDefinition.createResultKeyStatisticsCollector( - frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes() - ), - ClusterByStatisticsCollector::addAll - ), - // Run all processors simultaneously. They are lightweight and this keeps things moving. - processors.size(), - Bouncer.unlimited(), - cancellationId - ); - - Futures.addCallback( - clusterByStatisticsCollectorFuture, - new FutureCallback() - { - @Override - public void onSuccess(final ClusterByStatisticsCollector result) - { - result.logSketches(); - kernelManipulationQueue.add( - holder -> - holder.getStageKernelMap().get(stageDefinition.getId()) - .setResultKeyStatisticsSnapshot(result.snapshot()) - ); - } - - @Override - public void onFailure(Throwable t) - { - kernelManipulationQueue.add( - holder -> { - log.noStackTrace() - .warn(t, "Failed to gather clusterBy statistics for stage [%s]", stageDefinition.getId()); - holder.getStageKernelMap().get(stageDefinition.getId()).fail(t); - } - ); - } - }, - MoreExecutors.directExecutor() - ); - - return new ResultAndChannels<>( - clusterByStatisticsCollectorFuture, - OutputChannels.wrap(retVal) - ); } /** - * Update the {@link #pipelineFuture}. + * Check whether {@link #setDone()} has been called. */ - private void push(final ExceptionalFunction, ResultAndChannels> fn) + public boolean isDone() { - pushAsync( - channels -> - Futures.immediateFuture(fn.apply(channels)) - ); + return done; } /** - * Update the {@link #pipelineFuture} asynchronously. + * Mark the holder as "done", signaling to the main loop that it should clean up and exit as soon as possible. */ - private void pushAsync(final ExceptionalFunction, ListenableFuture>> fn) + public void setDone() { - if (pipelineFuture == null) { - throw new ISE("Not initialized"); - } - - pipelineFuture = FutureUtils.transform( - Futures.transformAsync( - pipelineFuture, - new AsyncFunction, ResultAndChannels>() - { - @Override - public ListenableFuture> apply(ResultAndChannels t) throws Exception - { - return fn.apply(t); - } - }, - MoreExecutors.directExecutor() - ), - resultAndChannels -> new ResultAndChannels<>( - resultAndChannels.getResultFuture(), - resultAndChannels.getOutputChannels().readOnly() - ) - ); + this.done = true; } - /** - * Verifies there is exactly one channel per partition. - */ - private void sanityCheckOutputChannels(final OutputChannels outputChannels) + private StageId verifyQueryId(final StageId stageId) { - for (int partitionNumber : outputChannels.getPartitionNumbers()) { - final List outputChannelsForPartition = - outputChannels.getChannelsForPartition(partitionNumber); - - Preconditions.checkState(partitionNumber >= 0, "Expected partitionNumber >= 0, but got [%s]", partitionNumber); - Preconditions.checkState( - outputChannelsForPartition.size() == 1, - "Expected one channel for partition [%s], but got [%s]", - partitionNumber, - outputChannelsForPartition.size() - ); + if (!stageId.getQueryId().equals(workerContext.queryId())) { + throw new ISE("Unexpected queryId[%s], expected queryId[%s]", stageId.getQueryId(), workerContext.queryId()); } - } - } - - private class KernelHolder - { - private boolean done = false; - - public Map getStageKernelMap() - { - return stageKernelMap; - } - - public boolean isDone() - { - return done; - } - public void setDone() - { - this.done = true; + return stageId; } } - private static class ResultAndChannels + /** + * Holder for a single {@link WorkerStageKernel} and associated items, contained within {@link KernelHolders}. + */ + private static class KernelHolder { - private final ListenableFuture resultFuture; - private final OutputChannels outputChannels; - - public ResultAndChannels( - ListenableFuture resultFuture, - OutputChannels outputChannels - ) - { - this.resultFuture = resultFuture; - this.outputChannels = outputChannels; - } - - public ListenableFuture getResultFuture() - { - return resultFuture; - } + private final WorkerStageKernel kernel; + private final Closer processorCloser; + private final Closer resultsCloser; + private SettableFuture partitionBoundariesFuture; - public OutputChannels getOutputChannels() + public KernelHolder(WorkerStageKernel kernel) { - return outputChannels; + this.kernel = kernel; + this.processorCloser = Closer.create(); + this.resultsCloser = Closer.create(); } } - - private interface ExceptionalFunction - { - R apply(T t) throws Exception; - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java index b36b1b4155a8..8d3f15c09c79 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java @@ -168,29 +168,14 @@ public class WorkerMemoryParameters this.partitionStatisticsMaxRetainedBytes = partitionStatisticsMaxRetainedBytes; } - /** - * Create a production instance for {@link org.apache.druid.msq.indexing.MSQControllerTask}. - */ - public static WorkerMemoryParameters createProductionInstanceForController(final Injector injector) - { - long totalLookupFootprint = computeTotalLookupFootprint(injector); - return createInstance( - Runtime.getRuntime().maxMemory(), - computeNumWorkersInJvm(injector), - computeNumProcessorsInJvm(injector), - 0, - 0, - totalLookupFootprint - ); - } - /** * Create a production instance for {@link org.apache.druid.msq.indexing.MSQWorkerTask}. */ public static WorkerMemoryParameters createProductionInstanceForWorker( final Injector injector, final QueryDefinition queryDef, - final int stageNumber + final int stageNumber, + final int maxConcurrentStages ) { final StageDefinition stageDef = queryDef.getStageDefinition(stageNumber); @@ -212,6 +197,7 @@ public static WorkerMemoryParameters createProductionInstanceForWorker( Runtime.getRuntime().maxMemory(), computeNumWorkersInJvm(injector), computeNumProcessorsInJvm(injector), + maxConcurrentStages, numInputWorkers, numHashOutputPartitions, totalLookupFootprint @@ -228,6 +214,7 @@ public static WorkerMemoryParameters createProductionInstanceForWorker( * @param numWorkersInJvm number of workers that can run concurrently in this JVM. Generally equal to * the task capacity. * @param numProcessingThreadsInJvm size of the processing thread pool in the JVM. + * @param maxConcurrentStages maximum number of concurrent stages per worker. * @param numInputWorkers total number of workers across all input stages. * @param numHashOutputPartitions total number of output partitions, if using hash partitioning; zero if not using * hash partitioning. @@ -237,6 +224,7 @@ public static WorkerMemoryParameters createInstance( final long maxMemoryInJvm, final int numWorkersInJvm, final int numProcessingThreadsInJvm, + final int maxConcurrentStages, final int numInputWorkers, final int numHashOutputPartitions, final long totalLookupFootprint @@ -257,7 +245,8 @@ public static WorkerMemoryParameters createInstance( ); final long usableMemoryInJvm = computeUsableMemoryInJvm(maxMemoryInJvm, totalLookupFootprint); final long workerMemory = memoryPerWorker(usableMemoryInJvm, numWorkersInJvm); - final long bundleMemory = memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm); + final long bundleMemory = + memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm) / maxConcurrentStages; final long bundleMemoryForInputChannels = memoryNeededForInputChannels(numInputWorkers); final long bundleMemoryForHashPartitioning = memoryNeededForHashPartitioning(numHashOutputPartitions); final long bundleMemoryForProcessing = @@ -268,6 +257,7 @@ public static WorkerMemoryParameters createInstance( usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm, + maxConcurrentStages, numHashOutputPartitions ); @@ -281,7 +271,8 @@ public static WorkerMemoryParameters createInstance( estimateUsableMemory( numWorkersInJvm, numProcessingThreadsInJvm, - PROCESSING_MINIMUM_BYTES + BUFFER_BYTES_FOR_ESTIMATION + bundleMemoryForInputChannels + PROCESSING_MINIMUM_BYTES + BUFFER_BYTES_FOR_ESTIMATION + bundleMemoryForInputChannels, + maxConcurrentStages ), totalLookupFootprint), maxMemoryInJvm, usableMemoryInJvm, @@ -301,7 +292,8 @@ public static WorkerMemoryParameters createInstance( calculateSuggestedMinMemoryFromUsableMemory( estimateUsableMemory( numWorkersInJvm, - (MIN_SUPER_SORTER_FRAMES + BUFFER_BYTES_FOR_ESTIMATION) * LARGE_FRAME_SIZE + (MIN_SUPER_SORTER_FRAMES + BUFFER_BYTES_FOR_ESTIMATION) * LARGE_FRAME_SIZE, + maxConcurrentStages ), totalLookupFootprint ), @@ -338,7 +330,8 @@ public static WorkerMemoryParameters createInstance( estimateUsableMemory( numWorkersInJvm, numProcessingThreadsInJvm, - PROCESSING_MINIMUM_BYTES + BUFFER_BYTES_FOR_ESTIMATION + bundleMemoryForInputChannels + PROCESSING_MINIMUM_BYTES + BUFFER_BYTES_FOR_ESTIMATION + bundleMemoryForInputChannels, + maxConcurrentStages ), totalLookupFootprint), maxMemoryInJvm, usableMemoryInJvm, @@ -352,7 +345,9 @@ public static WorkerMemoryParameters createInstance( bundleMemoryForProcessing, superSorterMaxActiveProcessors, superSorterMaxChannelsPerProcessor, - Ints.checkedCast(workerMemory) // 100% of worker memory is devoted to partition statistics + + // 100% of worker memory is devoted to partition statistics + Ints.checkedCast(workerMemory / maxConcurrentStages) ); } @@ -459,18 +454,19 @@ static int computeMaxWorkers( final long usableMemoryInJvm, final int numWorkersInJvm, final int numProcessingThreadsInJvm, + final int maxConcurrentStages, final int numHashOutputPartitions ) { final long bundleMemory = memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm); - // Compute number of workers that gives us PROCESSING_MINIMUM_BYTES of memory per bundle, while accounting for - // memoryNeededForInputChannels + memoryNeededForHashPartitioning. + // Compute number of workers that gives us PROCESSING_MINIMUM_BYTES of memory per bundle per concurrent stage, while + // accounting for memoryNeededForInputChannels + memoryNeededForHashPartitioning. final int isHashing = numHashOutputPartitions > 0 ? 1 : 0; - return Math.max( - 0, - Ints.checkedCast((bundleMemory - PROCESSING_MINIMUM_BYTES) / ((long) STANDARD_FRAME_SIZE * (1 + isHashing)) - 1) - ); + final long bundleMemoryPerStage = bundleMemory / maxConcurrentStages; + final long maxWorkers = + (bundleMemoryPerStage - PROCESSING_MINIMUM_BYTES) / ((long) STANDARD_FRAME_SIZE * (1 + isHashing)) - 1; + return Math.max(0, Ints.checkedCast(maxWorkers)); } /** @@ -553,24 +549,28 @@ private static long memoryPerBundle( private static long estimateUsableMemory( final int numWorkersInJvm, final int numProcessingThreadsInJvm, - final long estimatedEachBundleMemory + final long estimatedEachBundleMemory, + final int maxConcurrentStages ) { final int bundleCount = numWorkersInJvm + numProcessingThreadsInJvm; - return estimateUsableMemory(numWorkersInJvm, estimatedEachBundleMemory * bundleCount); - + return estimateUsableMemory(numWorkersInJvm, estimatedEachBundleMemory * bundleCount, maxConcurrentStages); } /** * Add overheads to the estimated bundle memoery for all the workers. Checkout {@link WorkerMemoryParameters#memoryPerWorker(long, int)} * for the overhead calculation outside the processing bundles. */ - private static long estimateUsableMemory(final int numWorkersInJvm, final long estimatedTotalBundleMemory) + private static long estimateUsableMemory( + final int numWorkersInJvm, + final long estimatedTotalBundleMemory, + final int maxConcurrentStages + ) { - // Currently, we only add the partition stats overhead since it will be the single largest overhead per worker. final long estimateStatOverHeadPerWorker = PARTITION_STATS_MEMORY_MAX_BYTES; - return estimatedTotalBundleMemory + (estimateStatOverHeadPerWorker * numWorkersInJvm); + final long requiredUsableMemory = estimatedTotalBundleMemory + (estimateStatOverHeadPerWorker * numWorkersInJvm); + return requiredUsableMemory * maxConcurrentStages; } private static long memoryNeededForHashPartitioning(final int numOutputPartitions) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStorageParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStorageParameters.java index 59576ec90bfb..53e12dd2ab4d 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStorageParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStorageParameters.java @@ -70,11 +70,13 @@ private WorkerStorageParameters(final long intermediateSuperSorterStorageMaxLoca public static WorkerStorageParameters createProductionInstance( final Injector injector, - final boolean isIntermediateSuperSorterStorageEnabled + final OutputChannelMode outputChannelMode ) { long tmpStorageBytesPerTask = injector.getInstance(TaskConfig.class).getTmpStorageBytesPerTask(); - return createInstance(tmpStorageBytesPerTask, isIntermediateSuperSorterStorageEnabled); + + // If durable storage is enabled, then super sorter intermediate storage should be enabled as well. + return createInstance(tmpStorageBytesPerTask, outputChannelMode.isDurable()); } @VisibleForTesting diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java index e0de5bdc27e2..240400aa6d5e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java @@ -20,9 +20,13 @@ package org.apache.druid.msq.indexing; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.query.groupby.GroupingEngine; import org.apache.druid.segment.IndexIO; @@ -32,28 +36,35 @@ import org.apache.druid.segment.loading.DataSegmentPusher; import java.io.File; +import java.io.IOException; public class IndexerFrameContext implements FrameContext { + private final StageId stageId; private final IndexerWorkerContext context; private final IndexIO indexIO; private final DataSegmentProvider dataSegmentProvider; private final WorkerMemoryParameters memoryParameters; + private final WorkerStorageParameters storageParameters; private final DataServerQueryHandlerFactory dataServerQueryHandlerFactory; public IndexerFrameContext( + StageId stageId, IndexerWorkerContext context, IndexIO indexIO, DataSegmentProvider dataSegmentProvider, DataServerQueryHandlerFactory dataServerQueryHandlerFactory, - WorkerMemoryParameters memoryParameters + WorkerMemoryParameters memoryParameters, + WorkerStorageParameters storageParameters ) { + this.stageId = stageId; this.context = context; this.indexIO = indexIO; this.dataSegmentProvider = dataSegmentProvider; - this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; this.memoryParameters = memoryParameters; + this.storageParameters = storageParameters; + this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; } @Override @@ -90,7 +101,8 @@ public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() @Override public File tempDir() { - return context.tempDir(); + // No need to include query ID; each task handles a single query, so there is no ambiguity. + return new File(context.tempDir(), StringUtils.format("stage_%06d", stageId.getStageNumber())); } @Override @@ -128,4 +140,22 @@ public WorkerMemoryParameters memoryParameters() { return memoryParameters; } + + @Override + public Bouncer processorBouncer() + { + return context.injector().getInstance(Bouncer.class); + } + + @Override + public WorkerStorageParameters storageParameters() + { + return storageParameters; + } + + @Override + public void close() throws IOException + { + // Nothing to close. + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java index 30bc75282fa4..2dedaf204ec7 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java @@ -52,4 +52,10 @@ public List getAdminPermissions() ) ); } + + @Override + public List getQueryPermissions(String queryId) + { + return getAdminPermissions(); + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java index 1bd789df7690..63358467489b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java @@ -24,9 +24,7 @@ import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.inject.Injector; import com.google.inject.Key; -import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.guice.annotations.EscalatedGlobal; -import org.apache.druid.guice.annotations.Self; import org.apache.druid.guice.annotations.Smile; import org.apache.druid.indexing.common.SegmentCacheManagerFactory; import org.apache.druid.indexing.common.TaskToolbox; @@ -35,16 +33,21 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.exec.ControllerClient; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.exec.TaskDataSegmentProvider; import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.exec.WorkerClient; import org.apache.druid.msq.exec.WorkerContext; import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.indexing.client.IndexerControllerClient; import org.apache.druid.msq.indexing.client.IndexerWorkerClient; import org.apache.druid.msq.indexing.client.WorkerChatHandler; import org.apache.druid.msq.kernel.FrameContext; import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryToolChestWarehouse; import org.apache.druid.rpc.ServiceClientFactory; import org.apache.druid.rpc.ServiceLocations; @@ -67,37 +70,49 @@ public class IndexerWorkerContext implements WorkerContext private static final long FREQUENCY_CHECK_MILLIS = 1000; private static final long FREQUENCY_CHECK_JITTER = 30; + private final MSQWorkerTask task; private final TaskToolbox toolbox; private final Injector injector; + private final OverlordClient overlordClient; private final IndexIO indexIO; private final TaskDataSegmentProvider dataSegmentProvider; private final DataServerQueryHandlerFactory dataServerQueryHandlerFactory; private final ServiceClientFactory clientFactory; - - @GuardedBy("this") - private OverlordClient overlordClient; + private final MemoryIntrospector memoryIntrospector; + private final int maxConcurrentStages; @GuardedBy("this") private ServiceLocator controllerLocator; public IndexerWorkerContext( + final MSQWorkerTask task, final TaskToolbox toolbox, final Injector injector, + final OverlordClient overlordClient, final IndexIO indexIO, final TaskDataSegmentProvider dataSegmentProvider, - final DataServerQueryHandlerFactory dataServerQueryHandlerFactory, - final ServiceClientFactory clientFactory + final ServiceClientFactory clientFactory, + final MemoryIntrospector memoryIntrospector, + final DataServerQueryHandlerFactory dataServerQueryHandlerFactory ) { + this.task = task; this.toolbox = toolbox; this.injector = injector; + this.overlordClient = overlordClient; this.indexIO = indexIO; this.dataSegmentProvider = dataSegmentProvider; - this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; this.clientFactory = clientFactory; + this.memoryIntrospector = memoryIntrospector; + this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; + this.maxConcurrentStages = MultiStageQueryContext.getMaxConcurrentStages(QueryContext.of(task.getContext())); } - public static IndexerWorkerContext createProductionInstance(final TaskToolbox toolbox, final Injector injector) + public static IndexerWorkerContext createProductionInstance( + final MSQWorkerTask task, + final TaskToolbox toolbox, + final Injector injector + ) { final IndexIO indexIO = injector.getInstance(IndexIO.class); final SegmentCacheManager segmentCacheManager = @@ -105,28 +120,42 @@ public static IndexerWorkerContext createProductionInstance(final TaskToolbox to .manufacturate(new File(toolbox.getIndexingTmpDir(), "segment-fetch")); final ServiceClientFactory serviceClientFactory = injector.getInstance(Key.get(ServiceClientFactory.class, EscalatedGlobal.class)); + final MemoryIntrospector memoryIntrospector = injector.getInstance(MemoryIntrospector.class); + final OverlordClient overlordClient = + injector.getInstance(OverlordClient.class).withRetryPolicy(StandardRetryPolicy.unlimited()); final ObjectMapper smileMapper = injector.getInstance(Key.get(ObjectMapper.class, Smile.class)); final QueryToolChestWarehouse warehouse = injector.getInstance(QueryToolChestWarehouse.class); return new IndexerWorkerContext( + task, toolbox, injector, + overlordClient, indexIO, - new TaskDataSegmentProvider( - toolbox.getCoordinatorClient(), - segmentCacheManager, - indexIO - ), + new TaskDataSegmentProvider(toolbox.getCoordinatorClient(), segmentCacheManager, indexIO), + serviceClientFactory, + memoryIntrospector, new DataServerQueryHandlerFactory( toolbox.getCoordinatorClient(), serviceClientFactory, smileMapper, warehouse - ), - serviceClientFactory + ) ); } + @Override + public String queryId() + { + return task.getControllerTaskId(); + } + + @Override + public String workerId() + { + return task.getId(); + } + public TaskToolbox toolbox() { return toolbox; @@ -147,7 +176,8 @@ public Injector injector() @Override public void registerWorker(Worker worker, Closer closer) { - WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker); + final WorkerChatHandler chatHandler = + new WorkerChatHandler(worker, toolbox.getAuthorizerMapper(), task.getDataSource()); toolbox.getChatHandlerProvider().register(worker.id(), chatHandler, false); closer.register(() -> toolbox.getChatHandlerProvider().unregister(worker.id())); closer.register(() -> { @@ -161,7 +191,7 @@ public void registerWorker(Worker worker, Closer closer) // Register the periodic controller checker final ExecutorService periodicControllerCheckerExec = Execs.singleThreaded("controller-status-checker-%s"); closer.register(periodicControllerCheckerExec::shutdownNow); - final ServiceLocator controllerLocator = makeControllerLocator(worker.task().getControllerTaskId()); + final ServiceLocator controllerLocator = makeControllerLocator(task.getControllerTaskId()); periodicControllerCheckerExec.submit(() -> controllerCheckerRunnable(controllerLocator, worker)); } @@ -218,15 +248,21 @@ public File tempDir() } @Override - public ControllerClient makeControllerClient(String controllerId) + public int maxConcurrentStages() + { + return maxConcurrentStages; + } + + @Override + public ControllerClient makeControllerClient() { - final ServiceLocator locator = makeControllerLocator(controllerId); + final ServiceLocator locator = makeControllerLocator(task.getControllerTaskId()); return new IndexerControllerClient( clientFactory.makeClient( - controllerId, + task.getControllerTaskId(), locator, - new SpecificTaskRetryPolicy(controllerId, StandardRetryPolicy.unlimited()) + new SpecificTaskRetryPolicy(task.getControllerTaskId(), StandardRetryPolicy.unlimited()) ), jsonMapper(), locator @@ -237,37 +273,33 @@ public ControllerClient makeControllerClient(String controllerId) public WorkerClient makeWorkerClient() { // Ignore workerId parameter. The workerId is passed into each method of WorkerClient individually. - return new IndexerWorkerClient(clientFactory, makeOverlordClient(), jsonMapper()); + return new IndexerWorkerClient(clientFactory, overlordClient, jsonMapper()); } @Override - public FrameContext frameContext(QueryDefinition queryDef, int stageNumber) + public FrameContext frameContext(QueryDefinition queryDef, int stageNumber, OutputChannelMode outputChannelMode) { return new IndexerFrameContext( + queryDef.getStageDefinition(stageNumber).getId(), this, indexIO, dataSegmentProvider, dataServerQueryHandlerFactory, - WorkerMemoryParameters.createProductionInstanceForWorker(injector, queryDef, stageNumber) + WorkerMemoryParameters.createProductionInstanceForWorker(injector, queryDef, stageNumber, maxConcurrentStages), + WorkerStorageParameters.createProductionInstance(injector, outputChannelMode) ); } @Override public int threadCount() { - return processorBouncer().getMaxCount(); + return memoryIntrospector.numProcessorsInJvm(); } @Override public DruidNode selfNode() { - return injector.getInstance(Key.get(DruidNode.class, Self.class)); - } - - @Override - public Bouncer processorBouncer() - { - return injector.getInstance(Bouncer.class); + return toolbox.getDruidNode(); } @Override @@ -276,21 +308,13 @@ public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() return dataServerQueryHandlerFactory; } - private synchronized OverlordClient makeOverlordClient() - { - if (overlordClient == null) { - overlordClient = injector.getInstance(OverlordClient.class) - .withRetryPolicy(StandardRetryPolicy.unlimited()); - } - return overlordClient; - } - private synchronized ServiceLocator makeControllerLocator(final String controllerId) { if (controllerLocator == null) { - controllerLocator = new SpecificTaskServiceLocator(controllerId, makeOverlordClient()); + controllerLocator = new SpecificTaskServiceLocator(controllerId, overlordClient); } return controllerLocator; } + } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTask.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTask.java index b4d18ea390e9..31b03d63ba6e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTask.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTask.java @@ -33,10 +33,13 @@ import org.apache.druid.indexing.common.config.TaskConfig; import org.apache.druid.indexing.common.task.AbstractTask; import org.apache.druid.indexing.common.task.Tasks; +import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.exec.MSQTasks; import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.exec.WorkerContext; import org.apache.druid.msq.exec.WorkerImpl; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.MSQFaultUtils; import org.apache.druid.server.security.ResourceAction; import javax.annotation.Nonnull; @@ -48,6 +51,7 @@ public class MSQWorkerTask extends AbstractTask { public static final String TYPE = "query_worker"; + private static final Logger log = new Logger(MSQWorkerTask.class); private final String controllerTaskId; private final int workerNumber; @@ -132,18 +136,25 @@ public boolean isReady(final TaskActionClient taskActionClient) } @Override - public TaskStatus runTask(final TaskToolbox toolbox) throws Exception + public TaskStatus runTask(final TaskToolbox toolbox) { - final WorkerContext context = IndexerWorkerContext.createProductionInstance(toolbox, injector); + final WorkerContext context = IndexerWorkerContext.createProductionInstance(this, toolbox, injector); worker = new WorkerImpl(this, context); - return worker.run(); + + try { + worker.run(); + return TaskStatus.success(context.workerId()); + } + catch (MSQException e) { + return TaskStatus.failure(context.workerId(), MSQFaultUtils.generateMessageWithErrorCode(e.getFault())); + } } @Override public void stopGracefully(TaskConfig taskConfig) { if (worker != null) { - worker.stopGracefully(); + worker.stop(); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java index 81303eb43848..1e31de71a8ac 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java @@ -152,7 +152,7 @@ public void postWorkerWarning(List MSQErrorReports) throws IOExc } @Override - public List getTaskList() throws IOException + public List getWorkerIds() throws IOException { final BytesFullResponseHolder retVal = doRequest( new RequestBuilder(HttpMethod.GET, "/taskList"), diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/WorkerChatHandler.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/WorkerChatHandler.java index 70d1ab11d380..7c8b86bb9d64 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/WorkerChatHandler.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/WorkerChatHandler.java @@ -19,310 +19,25 @@ package org.apache.druid.msq.indexing.client; -import com.google.common.collect.ImmutableMap; -import it.unimi.dsi.fastutil.bytes.ByteArrays; -import org.apache.commons.lang.mutable.MutableLong; -import org.apache.druid.frame.file.FrameFileHttpResponseHandler; -import org.apache.druid.frame.key.ClusterByPartitions; -import org.apache.druid.indexing.common.TaskToolbox; -import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.exec.Worker; -import org.apache.druid.msq.indexing.MSQWorkerTask; -import org.apache.druid.msq.kernel.StageId; -import org.apache.druid.msq.kernel.WorkOrder; -import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; -import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde; +import org.apache.druid.msq.indexing.IndexerResourcePermissionMapper; +import org.apache.druid.msq.rpc.WorkerResource; import org.apache.druid.segment.realtime.ChatHandler; -import org.apache.druid.segment.realtime.ChatHandlers; -import org.apache.druid.server.security.Action; -import org.apache.druid.utils.CloseableUtils; +import org.apache.druid.segment.realtime.ChatHandlerProvider; +import org.apache.druid.server.security.AuthorizerMapper; -import javax.annotation.Nullable; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.Consumes; -import javax.ws.rs.GET; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.StreamingOutput; -import java.io.IOException; -import java.io.InputStream; - -public class WorkerChatHandler implements ChatHandler +/** + * Subclass of {@link WorkerResource} that implements {@link ChatHandler}, suitable for registration + * with a {@link ChatHandlerProvider}. + */ +public class WorkerChatHandler extends WorkerResource implements ChatHandler { - private static final Logger log = new Logger(WorkerChatHandler.class); - - /** - * Callers must be able to store an entire chunk in memory. It can't be too large. - */ - private static final long CHANNEL_DATA_CHUNK_SIZE = 1_000_000; - - private final Worker worker; - private final MSQWorkerTask task; - private final TaskToolbox toolbox; - - public WorkerChatHandler(TaskToolbox toolbox, Worker worker) - { - this.worker = worker; - this.task = worker.task(); - this.toolbox = toolbox; - } - - /** - * Returns up to {@link #CHANNEL_DATA_CHUNK_SIZE} bytes of stage output data. - *

- * See {@link org.apache.druid.msq.exec.WorkerClient#fetchChannelData} for the client-side code that calls this API. - */ - @GET - @Path("/channels/{queryId}/{stageNumber}/{partitionNumber}") - @Produces(MediaType.APPLICATION_OCTET_STREAM) - public Response httpGetChannelData( - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @PathParam("partitionNumber") final int partitionNumber, - @QueryParam("offset") final long offset, - @Context final HttpServletRequest req + public WorkerChatHandler( + final Worker worker, + final AuthorizerMapper authorizerMapper, + final String dataSource ) { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - try { - final InputStream inputStream = worker.readChannel(queryId, stageNumber, partitionNumber, offset); - if (inputStream == null) { - return Response.status(Response.Status.NOT_FOUND).build(); - } - - final Response.ResponseBuilder responseBuilder = Response.ok(); - - final byte[] readBuf = new byte[8192]; - final MutableLong bytesReadTotal = new MutableLong(0L); - final int firstRead = inputStream.read(readBuf); - - if (firstRead == -1) { - // Empty read means we're at the end of the channel. Set the last fetch header so the client knows this. - inputStream.close(); - return responseBuilder - .header( - FrameFileHttpResponseHandler.HEADER_LAST_FETCH_NAME, - FrameFileHttpResponseHandler.HEADER_LAST_FETCH_VALUE - ) - .entity(ByteArrays.EMPTY_ARRAY) - .build(); - } - - return Response.ok((StreamingOutput) output -> { - try { - int bytesReadThisCall = firstRead; - do { - final int bytesToWrite = - (int) Math.min(CHANNEL_DATA_CHUNK_SIZE - bytesReadTotal.longValue(), bytesReadThisCall); - output.write(readBuf, 0, bytesToWrite); - bytesReadTotal.add(bytesReadThisCall); - } while (bytesReadTotal.longValue() < CHANNEL_DATA_CHUNK_SIZE - && (bytesReadThisCall = inputStream.read(readBuf)) != -1); - } - catch (Throwable e) { - // Suppress the exception to ensure nothing gets written over the wire once we've sent a 200. The client - // will resume from where it left off. - log.noStackTrace().warn( - e, - "Error writing channel for query [%s] stage [%s] partition [%s] offset [%,d] to [%s]", - queryId, - stageNumber, - partitionNumber, - offset, - req.getRemoteAddr() - ); - } - finally { - CloseableUtils.closeAll(inputStream, output); - } - }).build(); - } - catch (IOException e) { - return Response.status(Response.Status.INTERNAL_SERVER_ERROR).build(); - } - } - - /** - * See {@link org.apache.druid.msq.exec.WorkerClient#postWorkOrder} for the client-side code that calls this API. - */ - @POST - @Consumes(MediaType.APPLICATION_JSON) - @Produces(MediaType.APPLICATION_JSON) - @Path("/workOrder") - public Response httpPostWorkOrder(final WorkOrder workOrder, @Context final HttpServletRequest req) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - worker.postWorkOrder(workOrder); - return Response.status(Response.Status.ACCEPTED).build(); - } - - /** - * See {@link org.apache.druid.msq.exec.WorkerClient#postResultPartitionBoundaries} for the client-side code that calls this API. - */ - @POST - @Consumes(MediaType.APPLICATION_JSON) - @Produces(MediaType.APPLICATION_JSON) - @Path("/resultPartitionBoundaries/{queryId}/{stageNumber}") - public Response httpPostResultPartitionBoundaries( - final ClusterByPartitions stagePartitionBoundaries, - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - if (worker.postResultPartitionBoundaries(stagePartitionBoundaries, queryId, stageNumber)) { - return Response.status(Response.Status.ACCEPTED).build(); - } else { - return Response.status(Response.Status.BAD_REQUEST).build(); - } - } - - @POST - @Path("/keyStatistics/{queryId}/{stageNumber}") - @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM}) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpFetchKeyStatistics( - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper()); - ClusterByStatisticsSnapshot clusterByStatisticsSnapshot; - StageId stageId = new StageId(queryId, stageNumber); - try { - clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId); - if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) { - return Response.status(Response.Status.ACCEPTED) - .type(MediaType.APPLICATION_OCTET_STREAM) - .entity((StreamingOutput) output -> ClusterByStatisticsSnapshotSerde.serialize(output, clusterByStatisticsSnapshot)) - .build(); - } else { - return Response.status(Response.Status.ACCEPTED) - .type(MediaType.APPLICATION_JSON) - .entity(clusterByStatisticsSnapshot) - .build(); - } - } - catch (Exception e) { - String errorMessage = StringUtils.format( - "Invalid request for key statistics for query[%s] and stage[%d]", - queryId, - stageNumber - ); - log.error(e, errorMessage); - return Response.status(Response.Status.BAD_REQUEST) - .entity(ImmutableMap.of("error", errorMessage)) - .build(); - } - } - - @POST - @Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}") - @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM}) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpFetchKeyStatisticsWithSnapshot( - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @PathParam("timeChunk") final long timeChunk, - @QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper()); - ClusterByStatisticsSnapshot snapshotForTimeChunk; - StageId stageId = new StageId(queryId, stageNumber); - try { - snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk); - if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) { - return Response.status(Response.Status.ACCEPTED) - .type(MediaType.APPLICATION_OCTET_STREAM) - .entity((StreamingOutput) output -> ClusterByStatisticsSnapshotSerde.serialize(output, snapshotForTimeChunk)) - .build(); - } else { - return Response.status(Response.Status.ACCEPTED) - .type(MediaType.APPLICATION_JSON) - .entity(snapshotForTimeChunk) - .build(); - } - } - catch (Exception e) { - String errorMessage = StringUtils.format( - "Invalid request for key statistics for query[%s], stage[%d] and timeChunk[%d]", - queryId, - stageNumber, - timeChunk - ); - log.error(e, errorMessage); - return Response.status(Response.Status.BAD_REQUEST) - .entity(ImmutableMap.of("error", errorMessage)) - .build(); - } - } - - /** - * See {@link org.apache.druid.msq.exec.WorkerClient#postCleanupStage} for the client-side code that calls this API. - */ - @POST - @Path("/cleanupStage/{queryId}/{stageNumber}") - public Response httpPostCleanupStage( - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - worker.postCleanupStage(new StageId(queryId, stageNumber)); - return Response.status(Response.Status.ACCEPTED).build(); - } - - /** - * See {@link org.apache.druid.msq.exec.WorkerClient#postFinish} for the client-side code that calls this API. - */ - @POST - @Path("/finish") - public Response httpPostFinish(@Context final HttpServletRequest req) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - worker.postFinish(); - return Response.status(Response.Status.ACCEPTED).build(); - } - - - /** - * See {@link org.apache.druid.msq.exec.WorkerClient#getCounters} for the client-side code that calls this API. - */ - @GET - @Produces(MediaType.APPLICATION_JSON) - @Path("/counters") - public Response httpGetCounters(@Context final HttpServletRequest req) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - return Response.status(Response.Status.OK).entity(worker.getCounters()).build(); - } - - /** - * Determines the encoding of key collectors returned by {@link #httpFetchKeyStatistics} and - * {@link #httpFetchKeyStatisticsWithSnapshot}. - */ - public enum SketchEncoding - { - /** - * The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}. - */ - OCTET_STREAM, - /** - * The key collector is encoded as json - */ - JSON + super(worker, new IndexerResourcePermissionMapper(dataSource), authorizerMapper); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java index 028f1b5bd48a..de01235447aa 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java @@ -41,6 +41,22 @@ private InputSlices() // No instantiation. } + /** + * Returns all {@link StageInputSlice} from the provided list of input slices. Ignores other types of input slices. + */ + public static List allStageSlices(final List slices) + { + final List retVal = new ArrayList<>(); + + for (final InputSlice slice : slices) { + if (slice instanceof StageInputSlice) { + retVal.add((StageInputSlice) slice); + } + } + + return retVal; + } + /** * Combines all {@link StageInputSlice#getPartitions()} from the input slices that are {@link StageInputSlice}. * Ignores other types of input slices. @@ -49,10 +65,8 @@ public static ReadablePartitions allReadablePartitions(final List sl { final List partitionsList = new ArrayList<>(); - for (final InputSlice slice : slices) { - if (slice instanceof StageInputSlice) { - partitionsList.add(((StageInputSlice) slice).getPartitions()); - } + for (final StageInputSlice slice : allStageSlices(slices)) { + partitionsList.add(slice.getPartitions()); } return ReadablePartitions.combine(partitionsList); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java index 03aa7cd0fe4f..4b68a3bf1b01 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java @@ -31,7 +31,7 @@ import org.apache.druid.data.input.impl.InlineInputSource; import org.apache.druid.data.input.impl.TimestampSpec; import org.apache.druid.java.util.common.DateTimes; -import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.msq.counters.ChannelCounters; import org.apache.druid.msq.counters.CounterNames; import org.apache.druid.msq.counters.CounterTracker; @@ -53,6 +53,7 @@ import org.apache.druid.timeline.SegmentId; import java.io.File; +import java.io.IOException; import java.util.Iterator; import java.util.List; import java.util.function.Consumer; @@ -94,7 +95,7 @@ public ReadableInputs attach( externalInputSlice.getInputSources(), externalInputSlice.getInputFormat(), externalInputSlice.getSignature(), - temporaryDirectory, + new File(temporaryDirectory, String.valueOf(inputNumber)), counters.channel(CounterNames.inputChannel(inputNumber)).setTotalFiles(slice.fileCount()), counters.warnings(), warningPublisher @@ -128,9 +129,13 @@ private static Iterator inputSourceSegmentIterator( ColumnsFilter.all() ); - if (!temporaryDirectory.exists() && !temporaryDirectory.mkdir()) { - throw new ISE("Cannot create temporary directory at [%s]", temporaryDirectory); + try { + FileUtils.mkdirp(temporaryDirectory); } + catch (IOException e) { + throw new RuntimeException(e); + } + return Iterators.transform( inputSources.iterator(), inputSource -> { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java index 7db2fa1a9dd9..da962a9d3931 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java @@ -20,8 +20,11 @@ package org.apache.druid.msq.kernel; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.query.groupby.GroupingEngine; import org.apache.druid.segment.IndexIO; @@ -30,12 +33,16 @@ import org.apache.druid.segment.incremental.RowIngestionMeters; import org.apache.druid.segment.loading.DataSegmentPusher; +import java.io.Closeable; import java.io.File; /** - * Provides services and objects for the functioning of the frame processors + * Provides services and objects for the functioning of the frame processors. Scoped to a specific stage of a + * specific query, i.e., one {@link WorkOrder}. + * + * Generated by {@link org.apache.druid.msq.exec.WorkerContext#frameContext(QueryDefinition, int, OutputChannelMode)}. */ -public interface FrameContext +public interface FrameContext extends Closeable { SegmentWrangler segmentWrangler(); @@ -59,5 +66,14 @@ public interface FrameContext IndexMergerV9 indexMerger(); + Bouncer processorBouncer(); + WorkerMemoryParameters memoryParameters(); + + WorkerStorageParameters storageParameters(); + + default File tempDir(String name) + { + return new File(tempDir(), name); + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java index 632b8a8106dd..b838092ca714 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java @@ -42,6 +42,8 @@ * This separation of decision-making from the "real world" allows the decision-making to live in one, * easy-to-follow place. * + * Not thread-safe. + * * @see org.apache.druid.msq.kernel.controller.ControllerQueryKernel state machine on the controller side */ public class WorkerStageKernel @@ -51,9 +53,10 @@ public class WorkerStageKernel private WorkerStagePhase phase = WorkerStagePhase.NEW; - // We read this variable in the main thread and the netty threads @Nullable - private volatile ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot; + private ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot; + + private boolean doneReadingInput; @Nullable private ClusterByPartitions resultPartitionBoundaries; @@ -107,25 +110,25 @@ public void startReading() public void startPreshuffleWaitingForResultPartitionBoundaries() { - assertPreshuffleStatisticsNeeded(); + assertPreshuffleStatisticsNeeded(true); transitionTo(WorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES); } public void startPreshuffleWritingOutput() { - assertPreshuffleStatisticsNeeded(); transitionTo(WorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT); } - public void setResultKeyStatisticsSnapshot(final ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot) + public void setResultKeyStatisticsSnapshot(@Nullable final ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot) { - assertPreshuffleStatisticsNeeded(); + assertPreshuffleStatisticsNeeded(resultKeyStatisticsSnapshot != null); this.resultKeyStatisticsSnapshot = resultKeyStatisticsSnapshot; + this.doneReadingInput = true; } public void setResultPartitionBoundaries(final ClusterByPartitions resultPartitionBoundaries) { - assertPreshuffleStatisticsNeeded(); + assertPreshuffleStatisticsNeeded(true); this.resultPartitionBoundaries = resultPartitionBoundaries; } @@ -134,6 +137,11 @@ public boolean hasResultKeyStatisticsSnapshot() return resultKeyStatisticsSnapshot != null; } + public boolean isDoneReadingInput() + { + return doneReadingInput; + } + public boolean hasResultPartitionBoundaries() { return resultPartitionBoundaries != null; @@ -152,10 +160,10 @@ public ClusterByPartitions getResultPartitionBoundaries() @Nullable public Object getResultObject() { - if (phase == WorkerStagePhase.RESULTS_READY || phase == WorkerStagePhase.FINISHED) { + if (phase == WorkerStagePhase.RESULTS_COMPLETE) { return resultObject; } else { - throw new ISE("Results are not ready yet"); + throw new ISE("Results are not ready in phase[%s]", phase); } } @@ -174,7 +182,7 @@ public void setResultsComplete(Object resultObject) throw new NullPointerException("resultObject must not be null"); } - transitionTo(WorkerStagePhase.RESULTS_READY); + transitionTo(WorkerStagePhase.RESULTS_COMPLETE); this.resultObject = resultObject; } @@ -196,16 +204,18 @@ public void fail(Throwable t) } } - public boolean addPostedResultsComplete(Pair stageIdAndWorkerNumber) + public boolean addPostedResultsComplete(StageId stageId, int workerNumber) { - return postedResultsComplete.add(stageIdAndWorkerNumber); + return postedResultsComplete.add(Pair.of(stageId, workerNumber)); } - private void assertPreshuffleStatisticsNeeded() + private void assertPreshuffleStatisticsNeeded(final boolean delivered) { - if (!workOrder.getStageDefinition().mustGatherResultKeyStatistics()) { + if (delivered != workOrder.getStageDefinition().mustGatherResultKeyStatistics()) { throw new ISE( - "Result partitioning is not necessary for stage [%s]", + "Result key statistics %s, but %s, for stage[%s]", + delivered ? "delivered" : "not delivered", + workOrder.getStageDefinition().mustGatherResultKeyStatistics() ? "expected" : "not expected", workOrder.getStageDefinition().getId() ); } @@ -222,7 +232,12 @@ private void transitionTo(final WorkerStagePhase newPhase) ); phase = newPhase; } else { - throw new IAE("Cannot transition from [%s] to [%s]", phase, newPhase); + throw new IAE( + "Cannot transition stage[%s] from[%s] to[%s]", + workOrder.getStageDefinition().getId(), + phase, + newPhase + ); } } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java index f54aa52349ea..7e3ac5c7cac4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java @@ -54,11 +54,12 @@ public boolean canTransitionFrom(final WorkerStagePhase priorPhase) @Override public boolean canTransitionFrom(final WorkerStagePhase priorPhase) { - return priorPhase == PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES; + return priorPhase == PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES /* if globally sorting */ + || priorPhase == READING_INPUT /* if locally sorting */; } }, - RESULTS_READY { + RESULTS_COMPLETE { @Override public boolean canTransitionFrom(final WorkerStagePhase priorPhase) { @@ -70,7 +71,7 @@ public boolean canTransitionFrom(final WorkerStagePhase priorPhase) @Override public boolean canTransitionFrom(final WorkerStagePhase priorPhase) { - return priorPhase == RESULTS_READY; + return priorPhase.compareTo(FINISHED) < 0; } }, @@ -84,4 +85,24 @@ public boolean canTransitionFrom(final WorkerStagePhase priorPhase) }; public abstract boolean canTransitionFrom(WorkerStagePhase priorPhase); + + /** + * Whether this phase indicates that the stage is no longer running. + */ + public boolean isTerminal() + { + return this == FINISHED || this == FAILED; + } + + /** + * Whether this phase indicates a stage is running and consuming its full complement of resources. + * + * There are still some resources that can be consumed by stages that are not running. For example, in the + * {@link #FINISHED} state, stages can still have data on disk that has not been cleaned-up yet, some pointers + * to that data that still reside in memory, and some counters in memory available for collection by the controller. + */ + public boolean isRunning() + { + return this != NEW && this != RESULTS_COMPLETE && this != FINISHED && this != FAILED; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java index 8d0fba72a216..fd1a0323d0fb 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java @@ -97,7 +97,7 @@ public ListenableFuture fetchClusterByStatisticsSna "/keyStatistics/%s/%d?sketchEncoding=%s", StringUtils.urlEncode(stageId.getQueryId()), stageId.getStageNumber(), - WorkerChatHandler.SketchEncoding.OCTET_STREAM + WorkerResource.SketchEncoding.OCTET_STREAM ); return getClient(workerId).asyncRequest( @@ -118,7 +118,7 @@ public ListenableFuture fetchClusterByStatisticsSna StringUtils.urlEncode(stageId.getQueryId()), stageId.getStageNumber(), timeChunk, - WorkerChatHandler.SketchEncoding.OCTET_STREAM + WorkerResource.SketchEncoding.OCTET_STREAM ); return getClient(workerId).asyncRequest( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java index d3e9eefa86d2..c6ddb5cd582b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java @@ -82,6 +82,28 @@ public Response httpPostPartialKeyStatistics( return Response.status(Response.Status.ACCEPTED).build(); } + /** + * Used by subtasks to inform the controller that they are done reading their input, in cases where they would + * not be calling {@link #httpPostPartialKeyStatistics(Object, String, int, int, HttpServletRequest)}. + * + * See {@link ControllerClient#postDoneReadingInput(StageId, int)} for the client-side code that calls this API. + */ + @POST + @Path("/doneReadingInput/{queryId}/{stageNumber}/{workerNumber}") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public Response httpPostDoneReadingInput( + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @PathParam("workerNumber") final int workerNumber, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + controller.doneReadingInput(stageNumber, workerNumber); + return Response.status(Response.Status.ACCEPTED).build(); + } + /** * Used by subtasks to post system errors. Note that the errors are organized by taskId, not by query/stage/worker, * because system errors are associated with a task rather than a specific query/stage/worker execution context. @@ -166,7 +188,7 @@ public Response httpPostResultsComplete( } /** - * See {@link ControllerClient#getTaskList()} for the client-side code that calls this API. + * See {@link ControllerClient#getWorkerIds} for the client-side code that calls this API. */ @GET @Path("/taskList") diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java index 30a8179fe0f0..8820b4ead5a0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java @@ -47,4 +47,20 @@ public static void authorizeAdminRequest( throw new ForbiddenException(access.toString()); } } + + public static void authorizeQueryRequest( + final ResourcePermissionMapper permissionMapper, + final AuthorizerMapper authorizerMapper, + final HttpServletRequest request, + final String queryId + ) + { + final List resourceActions = permissionMapper.getQueryPermissions(queryId); + + Access access = AuthorizationUtils.authorizeAllResourceActions(request, resourceActions, authorizerMapper); + + if (!access.isAllowed()) { + throw new ForbiddenException(access.toString()); + } + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java index 8c79f4fa0e05..0a7fb874f6d1 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java @@ -23,11 +23,9 @@ import java.util.List; -/** - * Provides HTTP resources such as {@link ControllerResource} with information about which permissions are needed - * for requests. - */ public interface ResourcePermissionMapper { List getAdminPermissions(); + + List getQueryPermissions(String queryId); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java new file mode 100644 index 000000000000..88dfddaeb7ce --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.rpc; + +import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.frame.file.FrameFileHttpResponseHandler; +import org.apache.druid.frame.key.ClusterByPartitions; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; +import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde; +import org.apache.druid.server.security.AuthorizerMapper; +import org.apache.druid.utils.CloseableUtils; + +import javax.annotation.Nullable; +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.StreamingOutput; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +public class WorkerResource +{ + private static final Logger log = new Logger(WorkerResource.class); + + /** + * Callers must be able to store an entire chunk in memory. It can't be too large. + */ + private static final long CHANNEL_DATA_CHUNK_SIZE = 1_000_000; + private static final long GET_CHANNEL_DATA_TIMEOUT = 30_000L; + + protected final Worker worker; + protected final ResourcePermissionMapper permissionMapper; + protected final AuthorizerMapper authorizerMapper; + + public WorkerResource( + final Worker worker, + final ResourcePermissionMapper permissionMapper, + final AuthorizerMapper authorizerMapper + ) + { + this.worker = worker; + this.permissionMapper = permissionMapper; + this.authorizerMapper = authorizerMapper; + } + + /** + * Returns up to {@link #CHANNEL_DATA_CHUNK_SIZE} bytes of stage output data. + *

+ * See {@link org.apache.druid.msq.exec.WorkerClient#fetchChannelData} for the client-side code that calls this API. + */ + @GET + @Path("/channels/{queryId}/{stageNumber}/{partitionNumber}") + @Produces(MediaType.APPLICATION_OCTET_STREAM) + public Response httpGetChannelData( + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @PathParam("partitionNumber") final int partitionNumber, + @QueryParam("offset") final long offset, + @Context final HttpServletRequest req + ) throws IOException + { + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + + final ListenableFuture dataFuture = + worker.readChannel(new StageId(queryId, stageNumber), partitionNumber, offset); + + final AsyncContext asyncContext = req.startAsync(); + asyncContext.setTimeout(GET_CHANNEL_DATA_TIMEOUT); + asyncContext.addListener( + new AsyncListener() + { + @Override + public void onComplete(AsyncEvent event) + { + } + + @Override + public void onTimeout(AsyncEvent event) + { + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + response.setStatus(HttpServletResponse.SC_OK); + event.getAsyncContext().complete(); + } + + @Override + public void onError(AsyncEvent event) + { + } + + @Override + public void onStartAsync(AsyncEvent event) + { + } + } + ); + + // Save these items, since "req" becomes inaccessible in future exception handlers. + final String remoteAddr = req.getRemoteAddr(); + final String requestURI = req.getRequestURI(); + + Futures.addCallback( + dataFuture, + new FutureCallback() + { + @Override + public void onSuccess(final InputStream inputStream) + { + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + + try (final OutputStream outputStream = response.getOutputStream()) { + if (inputStream == null) { + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + } else { + response.setStatus(HttpServletResponse.SC_OK); + response.setContentType(MediaType.APPLICATION_OCTET_STREAM); + + final byte[] readBuf = new byte[8192]; + final int firstRead = inputStream.read(readBuf); + + if (firstRead == -1) { + // Empty read means we're at the end of the channel. + // Set the last fetch header so the client knows this. + response.setHeader( + FrameFileHttpResponseHandler.HEADER_LAST_FETCH_NAME, + FrameFileHttpResponseHandler.HEADER_LAST_FETCH_VALUE + ); + } else { + long bytesReadTotal = 0; + int bytesReadThisCall = firstRead; + do { + final int bytesToWrite = + (int) Math.min(CHANNEL_DATA_CHUNK_SIZE - bytesReadTotal, bytesReadThisCall); + outputStream.write(readBuf, 0, bytesToWrite); + bytesReadTotal += bytesReadThisCall; + } while (bytesReadTotal < CHANNEL_DATA_CHUNK_SIZE + && (bytesReadThisCall = inputStream.read(readBuf)) != -1); + } + } + } + catch (Exception e) { + log.noStackTrace().warn(e, "Could not respond to request from[%s] to[%s]", remoteAddr, requestURI); + } + finally { + CloseableUtils.closeAndSuppressExceptions(inputStream, e -> log.warn("Failed to close output channel")); + asyncContext.complete(); + } + } + + @Override + public void onFailure(Throwable e) + { + if (!dataFuture.isCancelled()) { + try { + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + asyncContext.complete(); + } + catch (Exception e2) { + e.addSuppressed(e2); + } + + log.noStackTrace().warn(e, "Request failed from[%s] to[%s]", remoteAddr, requestURI); + } + } + }, + Execs.directExecutor() + ); + + return null; + } + + /** + * See {@link org.apache.druid.msq.exec.WorkerClient#postWorkOrder} for the client-side code that calls this API. + */ + @POST + @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) + @Path("/workOrder") + public Response httpPostWorkOrder(final WorkOrder workOrder, @Context final HttpServletRequest req) + { + final String queryId = workOrder.getQueryDefinition().getQueryId(); + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + worker.postWorkOrder(workOrder); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * See {@link org.apache.druid.msq.exec.WorkerClient#postResultPartitionBoundaries} for the client-side code that calls this API. + */ + @POST + @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) + @Path("/resultPartitionBoundaries/{queryId}/{stageNumber}") + public Response httpPostResultPartitionBoundaries( + final ClusterByPartitions stagePartitionBoundaries, + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + if (worker.postResultPartitionBoundaries(new StageId(queryId, stageNumber), stagePartitionBoundaries)) { + return Response.status(Response.Status.ACCEPTED).build(); + } else { + return Response.status(Response.Status.BAD_REQUEST).build(); + } + } + + @POST + @Path("/keyStatistics/{queryId}/{stageNumber}") + @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) + @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM}) + public Response httpFetchKeyStatistics( + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + ClusterByStatisticsSnapshot clusterByStatisticsSnapshot; + StageId stageId = new StageId(queryId, stageNumber); + try { + clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId); + if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) { + return Response.status(Response.Status.ACCEPTED) + .type(MediaType.APPLICATION_OCTET_STREAM) + .entity( + (StreamingOutput) output -> + ClusterByStatisticsSnapshotSerde.serialize(output, clusterByStatisticsSnapshot) + ) + .build(); + } else { + return Response.status(Response.Status.ACCEPTED) + .type(MediaType.APPLICATION_JSON) + .entity(clusterByStatisticsSnapshot) + .build(); + } + } + catch (Exception e) { + String errorMessage = StringUtils.format( + "Invalid request for key statistics for query[%s] and stage[%d]", + queryId, + stageNumber + ); + log.error(e, errorMessage); + return Response.status(Response.Status.BAD_REQUEST) + .entity(ImmutableMap.of("error", errorMessage)) + .build(); + } + } + + @POST + @Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}") + @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) + @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM}) + public Response httpFetchKeyStatisticsWithSnapshot( + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @PathParam("timeChunk") final long timeChunk, + @QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + ClusterByStatisticsSnapshot snapshotForTimeChunk; + StageId stageId = new StageId(queryId, stageNumber); + try { + snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk); + if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) { + return Response.status(Response.Status.ACCEPTED) + .type(MediaType.APPLICATION_OCTET_STREAM) + .entity( + (StreamingOutput) output -> + ClusterByStatisticsSnapshotSerde.serialize(output, snapshotForTimeChunk) + ) + .build(); + } else { + return Response.status(Response.Status.ACCEPTED) + .type(MediaType.APPLICATION_JSON) + .entity(snapshotForTimeChunk) + .build(); + } + } + catch (Exception e) { + String errorMessage = StringUtils.format( + "Invalid request for key statistics for query[%s], stage[%d] and timeChunk[%d]", + queryId, + stageNumber, + timeChunk + ); + log.error(e, errorMessage); + return Response.status(Response.Status.BAD_REQUEST) + .entity(ImmutableMap.of("error", errorMessage)) + .build(); + } + } + + /** + * See {@link org.apache.druid.msq.exec.WorkerClient#postCleanupStage} for the client-side code that calls this API. + */ + @POST + @Path("/cleanupStage/{queryId}/{stageNumber}") + public Response httpPostCleanupStage( + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + worker.postCleanupStage(new StageId(queryId, stageNumber)); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * See {@link org.apache.druid.msq.exec.WorkerClient#postFinish} for the client-side code that calls this API. + */ + @POST + @Path("/finish") + public Response httpPostFinish(@Context final HttpServletRequest req) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + worker.postFinish(); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * See {@link org.apache.druid.msq.exec.WorkerClient#getCounters} for the client-side code that calls this API. + */ + @GET + @Produces({MediaType.APPLICATION_JSON + "; qs=0.9", SmileMediaTypes.APPLICATION_JACKSON_SMILE + "; qs=0.1"}) + @Path("/counters") + public Response httpGetCounters(@Context final HttpServletRequest req) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + return Response.status(Response.Status.OK).entity(worker.getCounters()).build(); + } + + /** + * Determines the encoding of key collectors returned by {@link #httpFetchKeyStatistics} and + * {@link #httpFetchKeyStatisticsWithSnapshot}. + */ + public enum SketchEncoding + { + /** + * The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}. + */ + OCTET_STREAM, + /** + * The key collector is encoded as json + */ + JSON + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/MetaInputChannelFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/MetaInputChannelFactory.java new file mode 100644 index 000000000000..37595050c819 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/MetaInputChannelFactory.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.input; + +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.indexing.InputChannelFactory; +import org.apache.druid.msq.input.stage.StageInputSlice; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.WorkOrder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +/** + * Meta-factory that wraps {@link #inputChannelFactoryProvider}, and can create various other kinds of factories. + */ +public class MetaInputChannelFactory implements InputChannelFactory +{ + private final Int2ObjectMap stageOutputModeMap; + private final Function inputChannelFactoryProvider; + private final Map inputChannelFactoryMap = new HashMap<>(); + + public MetaInputChannelFactory( + final Int2ObjectMap stageOutputModeMap, + final Function inputChannelFactoryProvider + ) + { + this.stageOutputModeMap = stageOutputModeMap; + this.inputChannelFactoryProvider = inputChannelFactoryProvider; + } + + /** + * Create a meta-factory. + * + * @param slices stage slices from {@link WorkOrder#getInputs()} + * @param defaultOutputChannelMode mode to use when {@link StageInputSlice#getOutputChannelMode()} is null; i.e., + * when running with an older controller + * @param inputChannelFactoryProvider provider of {@link InputChannelFactory} for various {@link OutputChannelMode} + */ + public static MetaInputChannelFactory create( + final List slices, + final OutputChannelMode defaultOutputChannelMode, + final Function inputChannelFactoryProvider + ) + { + final Int2ObjectMap stageOutputModeMap = new Int2ObjectOpenHashMap<>(); + + for (final StageInputSlice slice : slices) { + final OutputChannelMode newMode; + + if (slice.getOutputChannelMode() != null) { + newMode = slice.getOutputChannelMode(); + } else { + newMode = defaultOutputChannelMode; + } + + final OutputChannelMode prevMode = stageOutputModeMap.putIfAbsent( + slice.getStageNumber(), + newMode + ); + + if (prevMode != null && prevMode != newMode) { + throw new ISE( + "Inconsistent output modes for stage[%s], got[%s] and[%s]", + slice.getStageNumber(), + prevMode, + newMode + ); + } + } + + return new MetaInputChannelFactory(stageOutputModeMap, inputChannelFactoryProvider); + } + + @Override + public ReadableFrameChannel openChannel( + final StageId stageId, + final int workerNumber, + final int partitionNumber + ) throws IOException + { + final OutputChannelMode outputChannelMode = stageOutputModeMap.get(stageId.getStageNumber()); + + if (outputChannelMode == null) { + throw new ISE("No output mode for stageNumber[%s]", stageId.getStageNumber()); + } + + return inputChannelFactoryMap.computeIfAbsent(outputChannelMode, inputChannelFactoryProvider) + .openChannel(stageId, workerNumber, partitionNumber); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/WorkerOrLocalInputChannelFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/WorkerOrLocalInputChannelFactory.java new file mode 100644 index 000000000000..08c7176b7c2b --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/WorkerOrLocalInputChannelFactory.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.input; + +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.msq.indexing.InputChannelFactory; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.shuffle.output.StageOutputHolder; + +import java.io.IOException; +import java.util.List; +import java.util.function.Supplier; + +/** + * An {@link InputChannelFactory} that loads data locally when possible, and otherwise connects directly to other + * workers. Used when durable shuffle storage is off. + */ +public class WorkerOrLocalInputChannelFactory implements InputChannelFactory +{ + private final String myId; + private final Supplier> workerIdsSupplier; + private final InputChannelFactory workerInputChannelFactory; + private final StageOutputHolderProvider stageOutputHolderProvider; + + public WorkerOrLocalInputChannelFactory( + final String myId, + final Supplier> workerIdsSupplier, + final InputChannelFactory workerInputChannelFactory, + final StageOutputHolderProvider stageOutputHolderProvider + ) + { + this.myId = myId; + this.workerIdsSupplier = workerIdsSupplier; + this.workerInputChannelFactory = workerInputChannelFactory; + this.stageOutputHolderProvider = stageOutputHolderProvider; + } + + @Override + public ReadableFrameChannel openChannel(StageId stageId, int workerNumber, int partitionNumber) throws IOException + { + final String taskId = workerIdsSupplier.get().get(workerNumber); + if (taskId.equals(myId)) { + return stageOutputHolderProvider.getHolder(stageId, partitionNumber).readLocally(); + } else { + return workerInputChannelFactory.openChannel(stageId, workerNumber, partitionNumber); + } + } + + public interface StageOutputHolderProvider + { + StageOutputHolder getHolder(StageId stageId, int partitionNumber); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java new file mode 100644 index 000000000000..4767d818dea4 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import java.io.IOException; +import java.io.InputStream; +import java.util.List; + +/** + * Input stream based on a list of byte arrays. + */ +public class ByteChunksInputStream extends InputStream +{ + private final List chunks; + private int chunkNum; + private int positionWithinChunk; + + /** + * Create a new stream wrapping a list of chunks. + * + * @param chunks byte arrays + * @param positionWithinFirstChunk starting position within the first byte array + */ + public ByteChunksInputStream(final List chunks, final int positionWithinFirstChunk) + { + this.chunks = chunks; + this.positionWithinChunk = positionWithinFirstChunk; + } + + @Override + public int read() throws IOException + { + if (chunkNum >= chunks.size()) { + return -1; + } else { + final byte[] currentChunk = chunks.get(chunkNum); + final byte b = currentChunk[positionWithinChunk++]; + + if (positionWithinChunk == currentChunk.length) { + chunkNum++; + positionWithinChunk = 0; + } + + return b & 0xFF; + } + } + + @Override + public int read(byte[] b) throws IOException + { + return read(b, 0, b.length); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException + { + if (len == 0) { + return 0; + } else if (chunkNum >= chunks.size()) { + return -1; + } else { + int r = 0; + + while (r < len && chunkNum < chunks.size()) { + final byte[] currentChunk = chunks.get(chunkNum); + int toReadFromCurrentChunk = Math.min(len - r, currentChunk.length - positionWithinChunk); + System.arraycopy(currentChunk, positionWithinChunk, b, off + r, toReadFromCurrentChunk); + r += toReadFromCurrentChunk; + positionWithinChunk += toReadFromCurrentChunk; + if (positionWithinChunk == currentChunk.length) { + chunkNum++; + positionWithinChunk = 0; + } + } + + return r; + } + } + + @Override + public void close() throws IOException + { + chunkNum = chunks.size(); + positionWithinChunk = 0; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java new file mode 100644 index 000000000000..eba835cd544b --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import it.unimi.dsi.fastutil.bytes.ByteArrays; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.frame.channel.ByteTracker; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.file.FrameFileWriter; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelUtils; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.LinkedList; + +/** + * Reader for {@link ReadableFrameChannel}. + * + * Because this reader returns an underlying channel directly, it must only be used when it is certain that + * only a single consumer exists, i.e., when using output mode {@link OutputChannelMode#MEMORY}. See + * {@link ControllerQueryKernelUtils#canUseMemoryOutput} for the code that ensures that there is only a single + * consumer in the in-memory case. + */ +public class ChannelStageOutputReader implements StageOutputReader +{ + enum State + { + INIT, + LOCAL, + REMOTE, + CLOSED + } + + private final ReadableFrameChannel channel; + private final FrameFileWriter writer; + + /** + * Pair of chunk size + chunk InputStream. + */ + private final LinkedList chunks = new LinkedList<>(); + + /** + * State of this reader. + */ + private State state = State.INIT; + + /** + * Position within the overall stream. + */ + private long cursor; + + /** + * Offset of the first chunk in {@link #chunks} which corresponds to {@link #cursor}. + */ + private int positionWithinFirstChunk; + + /** + * Whether {@link FrameFileWriter#close()} is called on {@link #writer}. + */ + private boolean didCloseWriter; + + public ChannelStageOutputReader(final ReadableFrameChannel channel) + { + this.channel = channel; + this.writer = FrameFileWriter.open(new ChunkAcceptor(), null, ByteTracker.unboundedTracker()); + } + + @Override + public synchronized ListenableFuture readRemotelyFrom(final long offset) + { + if (state == State.INIT) { + state = State.REMOTE; + } else if (state == State.LOCAL) { + throw new ISE("Cannot read both remotely and locally"); + } else if (state == State.CLOSED) { + throw new ISE("Closed"); + } + + if (offset < cursor) { + return Futures.immediateFailedFuture( + new ISE("Offset[%,d] no longer available, current cursor is[%,d]", offset, cursor)); + } + + while (chunks.isEmpty() || offset > cursor) { + // Fetch additional chunks if needed. + if (chunks.isEmpty()) { + if (didCloseWriter) { + if (offset == cursor) { + return Futures.immediateFuture(new ByteArrayInputStream(ByteArrays.EMPTY_ARRAY)); + } else { + throw DruidException.defensive( + "Channel finished but cursor[%,d] does not match requested offset[%,d]", + cursor, + offset + ); + } + } else if (channel.isFinished()) { + try { + writer.close(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + + didCloseWriter = true; + continue; + } else if (channel.canRead()) { + try { + writer.writeFrame(channel.read(), FrameFileWriter.NO_PARTITION); + } + catch (Exception e) { + try { + writer.abort(); + } + catch (IOException e2) { + e.addSuppressed(e2); + } + + throw new RuntimeException(e); + } + } else { + return FutureUtils.transformAsync(channel.readabilityFuture(), ignored -> readRemotelyFrom(offset)); + } + } + + // Remove first chunk if it is no longer needed. (offset is entirely past it.) + final byte[] chunk = chunks.peek(); + final long amountToAdvance = Math.min(offset - cursor, chunk.length - positionWithinFirstChunk); + cursor += amountToAdvance; + positionWithinFirstChunk += Ints.checkedCast(amountToAdvance); + + if (positionWithinFirstChunk == chunk.length) { + chunks.poll(); + positionWithinFirstChunk = 0; + } + } + + if (chunks.isEmpty() || offset != cursor) { + throw DruidException.defensive( + "Expected cursor[%,d] to be caught up to offset[%,d] by this point, and to have nonzero chunks", + cursor, + offset + ); + } + + return Futures.immediateFuture(new ByteChunksInputStream(ImmutableList.copyOf(chunks), positionWithinFirstChunk)); + } + + @Override + public synchronized ReadableFrameChannel readLocally() + { + if (state == State.INIT) { + state = State.LOCAL; + return channel; + } else if (state == State.REMOTE) { + throw new ISE("Cannot read both remotely and locally"); + } else if (state == State.LOCAL) { + throw new ISE("Cannot read channel multiple times"); + } else { + assert state == State.CLOSED; + throw new ISE("Closed"); + } + } + + @Override + public synchronized void close() throws IOException + { + // Call channel.close() unless readLocally() has been called. In that case, we expect the caller to close it. + if (state != State.LOCAL) { + channel.close(); + } + } + + /** + * Input stream that can have bytes appended to it, and that can have bytes acknowledged. + */ + private class ChunkAcceptor implements WritableByteChannel + { + private boolean open = true; + + @Override + public int write(final ByteBuffer src) throws IOException + { + if (!open) { + throw new IOException("Closed"); + } + + final int len = src.remaining(); + if (len > 0) { + final byte[] bytes = new byte[len]; + src.get(bytes); + chunks.add(bytes); + } + + return len; + } + + @Override + public boolean isOpen() + { + return open; + } + + @Override + public void close() + { + open = false; + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java new file mode 100644 index 000000000000..37f01a7a2544 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import it.unimi.dsi.fastutil.bytes.ByteArrays; +import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.file.FrameFile; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.RandomAccessFile; +import java.nio.channels.Channels; + +/** + * Reader for {@link FrameFile} on disk. + */ +public class FileStageOutputReader implements StageOutputReader +{ + private final FrameFile frameFile; + + public FileStageOutputReader(FrameFile frameFile) + { + this.frameFile = frameFile; + } + + @Override + public ListenableFuture readRemotelyFrom(long offset) + { + try { + final RandomAccessFile randomAccessFile = new RandomAccessFile(frameFile.file(), "r"); + + if (offset >= randomAccessFile.length()) { + randomAccessFile.close(); + return Futures.immediateFuture(new ByteArrayInputStream(ByteArrays.EMPTY_ARRAY)); + } else { + randomAccessFile.seek(offset); + return Futures.immediateFuture(Channels.newInputStream(randomAccessFile.getChannel())); + } + } + catch (Exception e) { + return Futures.immediateFailedFuture(e); + } + } + + @Override + public ReadableFrameChannel readLocally() + { + return new ReadableFileFrameChannel(frameFile.newReference()); + } + + @Override + public void close() throws IOException + { + frameFile.close(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java new file mode 100644 index 000000000000..37500ae5eafd --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.logger.Logger; + +import java.util.NoSuchElementException; + +public class FutureReadableFrameChannel implements ReadableFrameChannel +{ + private static final Logger log = new Logger(FutureReadableFrameChannel.class); + + private final ListenableFuture channelFuture; + private ReadableFrameChannel channel; + + public FutureReadableFrameChannel(final ListenableFuture channelFuture) + { + this.channelFuture = channelFuture; + } + + @Override + public boolean isFinished() + { + if (populateChannel()) { + return channel.isFinished(); + } else { + return false; + } + } + + @Override + public boolean canRead() + { + if (populateChannel()) { + return channel.canRead(); + } else { + return false; + } + } + + @Override + public Frame read() + { + if (populateChannel()) { + return channel.read(); + } else { + throw new NoSuchElementException(); + } + } + + @Override + public ListenableFuture readabilityFuture() + { + if (populateChannel()) { + return channel.readabilityFuture(); + } else { + return FutureUtils.transformAsync(channelFuture, ignored -> readabilityFuture()); + } + } + + @Override + public void close() + { + if (populateChannel()) { + channel.close(); + } else { + channelFuture.cancel(true); + channelFuture.addListener( + () -> { + final ReadableFrameChannel channel; + + try { + channel = FutureUtils.getUncheckedImmediately(channelFuture); + } + catch (Throwable ignored) { + // Some error happened while creating the channel. Suppress it. + return; + } + + try { + channel.close(); + } + catch (Throwable t) { + log.noStackTrace().warn(t, "Failed to close channel"); + } + }, + Execs.directExecutor() + ); + } + } + + private boolean populateChannel() + { + if (channel != null) { + return true; + } else if (channelFuture.isDone()) { + channel = FutureUtils.getUncheckedImmediately(channelFuture); + return true; + } else { + return false; + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java new file mode 100644 index 000000000000..3841cc7d7aee --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.frame.channel.ByteTracker; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.ReadableNilFrameChannel; +import org.apache.druid.frame.file.FrameFileWriter; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.channels.Channels; + +/** + * Reader for empty channel. + */ +public class NilStageOutputReader implements StageOutputReader +{ + public static final NilStageOutputReader INSTANCE = new NilStageOutputReader(); + + private static final byte[] EMPTY_FRAME_FILE; + + static { + try { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + FrameFileWriter.open(Channels.newChannel(baos), null, ByteTracker.unboundedTracker()).close(); + EMPTY_FRAME_FILE = baos.toByteArray(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public ListenableFuture readRemotelyFrom(final long offset) + { + final ByteArrayInputStream in = new ByteArrayInputStream(EMPTY_FRAME_FILE); + + //noinspection ResultOfMethodCallIgnored: OK to ignore since "skip" always works for ByteArrayInputStream. + in.skip(offset); + + return Futures.immediateFuture(in); + } + + @Override + public ReadableFrameChannel readLocally() + { + return ReadableNilFrameChannel.INSTANCE; + } + + @Override + public void close() throws IOException + { + // Nothing to do. + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java new file mode 100644 index 000000000000..215facea3633 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.ReadableNilFrameChannel; +import org.apache.druid.frame.file.FrameFile; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.utils.CloseableUtils; + +import java.io.Closeable; +import java.io.InputStream; + +/** + * Container for a {@link StageOutputReader}, which is used to read the output of a stage. + */ +public class StageOutputHolder implements Closeable +{ + private final SettableFuture channelFuture; + private final ListenableFuture readerFuture; + + public StageOutputHolder() + { + this.channelFuture = SettableFuture.create(); + this.readerFuture = FutureUtils.transform(channelFuture, StageOutputHolder::createReader); + } + + public ListenableFuture readRemotelyFrom(final long offset) + { + return FutureUtils.transformAsync(readerFuture, reader -> reader.readRemotelyFrom(offset)); + } + + public ReadableFrameChannel readLocally() + { + return new FutureReadableFrameChannel(FutureUtils.transform(readerFuture, StageOutputReader::readLocally)); + } + + public void setChannel(final ReadableFrameChannel channel) + { + if (!channelFuture.set(channel)) { + if (FutureUtils.getUncheckedImmediately(channelFuture) == null) { + throw new ISE("Closed"); + } else { + throw new ISE("Channel already set"); + } + } + } + + @Override + public void close() + { + channelFuture.set(null); + + final StageOutputReader reader; + + try { + reader = FutureUtils.getUnchecked(readerFuture, true); + } + catch (Throwable e) { + // Error creating the reader, nothing to close. Suppress. + return; + } + + if (reader != null) { + CloseableUtils.closeAndWrapExceptions(reader); + } + } + + private static StageOutputReader createReader(final ReadableFrameChannel channel) + { + if (channel == null) { + // Happens if close() was called before the channel resolved. + throw new ISE("Closed"); + } + + if (channel instanceof ReadableNilFrameChannel) { + return NilStageOutputReader.INSTANCE; + } + + if (channel instanceof ReadableFileFrameChannel) { + // Optimized implementation when reading an entire file. + final ReadableFileFrameChannel fileChannel = (ReadableFileFrameChannel) channel; + + if (fileChannel.isEntireFile()) { + final FrameFile frameFile = fileChannel.newFrameFileReference(); + + // Close original channel, so we don't leak a frame file reference. + channel.close(); + + return new FileStageOutputReader(frameFile); + } + } + + // Generic implementation for any other type of channel. + return new ChannelStageOutputReader(channel); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java new file mode 100644 index 000000000000..bad319135158 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.shuffle.input.WorkerOrLocalInputChannelFactory; + +import java.io.Closeable; +import java.io.InputStream; + +/** + * Interface for remotely reading output channels for a particular stage. Each instance of this interface represents a + * stream from a single {@link org.apache.druid.msq.kernel.StagePartition} in + * {@link org.apache.druid.frame.file.FrameFile} format. + */ +public interface StageOutputReader extends Closeable +{ + /** + * Returns an {@link InputStream} starting from a particular point in the + * {@link org.apache.druid.frame.file.FrameFile}. Length of the stream is implementation-dependent; it may or may + * not go all the way to the end of the file. Zero-length stream indicates EOF. Any nonzero length means you should + * call this method again with a higher offset. + * + * @param offset offset into the frame file + * + * @see org.apache.druid.msq.exec.WorkerImpl#readChannel(StageId, int, long) + */ + ListenableFuture readRemotelyFrom(long offset); + + /** + * Returns a {@link ReadableFrameChannel} for local reading. + * + * @see WorkerOrLocalInputChannelFactory#openChannel(StageId, int, int) + */ + ReadableFrameChannel readLocally(); +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQDrillWindowQueryTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQDrillWindowQueryTest.java index d3a67fdd659c..1b2eebe7742e 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQDrillWindowQueryTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQDrillWindowQueryTest.java @@ -28,6 +28,7 @@ import org.apache.druid.msq.sql.MSQTaskSqlEngine; import org.apache.druid.msq.test.CalciteMSQTestsHelper; import org.apache.druid.msq.test.ExtractResultsFactory; +import org.apache.druid.msq.test.MSQTestBase; import org.apache.druid.msq.test.MSQTestOverlordServiceClient; import org.apache.druid.msq.test.MSQTestTaskActionClient; import org.apache.druid.msq.test.VerifyMSQSupportedNativeQueriesPredicate; @@ -63,15 +64,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java deleted file mode 100644 index 171f476ebfe0..000000000000 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.druid.msq.exec; - -import org.apache.druid.java.util.common.ISE; -import org.apache.druid.msq.indexing.MSQWorkerTask; -import org.apache.druid.msq.kernel.StageId; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; - -import java.util.HashMap; - - -@RunWith(MockitoJUnitRunner.class) -public class WorkerImplTest -{ - @Mock - WorkerContext workerContext; - - @Test - public void testFetchStatsThrows() - { - WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>(), 0), workerContext, WorkerStorageParameters.createInstanceForTests(Long.MAX_VALUE)); - Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshot(new StageId("xx", 1))); - } - - @Test - public void testFetchStatsWithTimeChunkThrows() - { - WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>(), 0), workerContext, WorkerStorageParameters.createInstanceForTests(Long.MAX_VALUE)); - Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshotForTimeChunk(new StageId("xx", 1), 1L)); - } - -} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java index 29614fc07347..d4dd4b47e688 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java @@ -32,34 +32,54 @@ public class WorkerMemoryParametersTest @Test public void test_oneWorkerInJvm_alone() { - Assert.assertEquals(params(335_500_000, 1, 41, 75_000_000), create(1_000_000_000, 1, 1, 1, 0, 0)); - Assert.assertEquals(params(223_000_000, 2, 13, 75_000_000), create(1_000_000_000, 1, 2, 1, 0, 0)); - Assert.assertEquals(params(133_000_000, 4, 3, 75_000_000), create(1_000_000_000, 1, 4, 1, 0, 0)); - Assert.assertEquals(params(73_000_000, 3, 2, 75_000_000), create(1_000_000_000, 1, 8, 1, 0, 0)); - Assert.assertEquals(params(49_923_076, 2, 2, 75_000_000), create(1_000_000_000, 1, 12, 1, 0, 0)); + Assert.assertEquals(params(335_500_000, 1, 41, 75_000_000), create(1_000_000_000, 1, 1, 1, 1, 0, 0)); + Assert.assertEquals(params(223_000_000, 2, 13, 75_000_000), create(1_000_000_000, 1, 2, 1, 1, 0, 0)); + Assert.assertEquals(params(133_000_000, 4, 3, 75_000_000), create(1_000_000_000, 1, 4, 1, 1, 0, 0)); + Assert.assertEquals(params(73_000_000, 3, 2, 75_000_000), create(1_000_000_000, 1, 8, 1, 1, 0, 0)); + Assert.assertEquals(params(49_923_076, 2, 2, 75_000_000), create(1_000_000_000, 1, 12, 1, 1, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, - () -> create(1_000_000_000, 1, 32, 1, 0, 0) + () -> create(1_000_000_000, 1, 32, 1, 1, 0, 0) ); Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32), e.getFault()); - final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 0, 0)) + final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 1, 0, 0)) .getFault(); Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32), fault); + } + + @Test + public void test_oneWorkerInJvm_alone_twoConcurrentStages() + { + Assert.assertEquals(params(166_750_000, 1, 20, 37_500_000), create(1_000_000_000, 1, 1, 2, 1, 0, 0)); + Assert.assertEquals(params(110_500_000, 2, 6, 37_500_000), create(1_000_000_000, 1, 2, 2, 1, 0, 0)); + Assert.assertEquals(params(65_500_000, 2, 3, 37_500_000), create(1_000_000_000, 1, 4, 2, 1, 0, 0)); + Assert.assertEquals(params(35_500_000, 1, 3, 37_500_000), create(1_000_000_000, 1, 8, 2, 1, 0, 0)); + + final MSQException e = Assert.assertThrows( + MSQException.class, + () -> create(1_000_000_000, 1, 12, 2, 1, 0, 0) + ); + + Assert.assertEquals(new NotEnoughMemoryFault(1_736_034_666, 1_000_000_000, 750_000_000, 1, 12), e.getFault()); + + final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 2, 1, 0, 0)) + .getFault(); + Assert.assertEquals(new NotEnoughMemoryFault(4_048_090_666L, 1_000_000_000, 750_000_000, 2, 32), fault); } @Test public void test_oneWorkerInJvm_twoHundredWorkersInCluster() { - Assert.assertEquals(params(474_000_000, 1, 83, 150_000_000), create(2_000_000_000, 1, 1, 200, 0, 0)); - Assert.assertEquals(params(249_000_000, 2, 27, 150_000_000), create(2_000_000_000, 1, 2, 200, 0, 0)); + Assert.assertEquals(params(474_000_000, 1, 83, 150_000_000), create(2_000_000_000, 1, 1, 1, 200, 0, 0)); + Assert.assertEquals(params(249_000_000, 2, 27, 150_000_000), create(2_000_000_000, 1, 2, 1, 200, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, - () -> create(1_000_000_000, 1, 4, 200, 0, 0) + () -> create(1_000_000_000, 1, 4, 1, 200, 0, 0) ); Assert.assertEquals(new TooManyWorkersFault(200, 109), e.getFault()); @@ -68,50 +88,76 @@ public void test_oneWorkerInJvm_twoHundredWorkersInCluster() @Test public void test_fourWorkersInJvm_twoHundredWorkersInCluster() { - Assert.assertEquals(params(1_014_000_000, 1, 150, 168_750_000), create(9_000_000_000L, 4, 1, 200, 0, 0)); - Assert.assertEquals(params(811_500_000, 2, 62, 168_750_000), create(9_000_000_000L, 4, 2, 200, 0, 0)); - Assert.assertEquals(params(558_375_000, 4, 22, 168_750_000), create(9_000_000_000L, 4, 4, 200, 0, 0)); - Assert.assertEquals(params(305_250_000, 4, 14, 168_750_000), create(9_000_000_000L, 4, 8, 200, 0, 0)); - Assert.assertEquals(params(102_750_000, 4, 8, 168_750_000), create(9_000_000_000L, 4, 16, 200, 0, 0)); + Assert.assertEquals(params(1_014_000_000, 1, 150, 168_750_000), create(9_000_000_000L, 4, 1, 1, 200, 0, 0)); + Assert.assertEquals(params(811_500_000, 2, 62, 168_750_000), create(9_000_000_000L, 4, 2, 1, 200, 0, 0)); + Assert.assertEquals(params(558_375_000, 4, 22, 168_750_000), create(9_000_000_000L, 4, 4, 1, 200, 0, 0)); + Assert.assertEquals(params(305_250_000, 4, 14, 168_750_000), create(9_000_000_000L, 4, 8, 1, 200, 0, 0)); + Assert.assertEquals(params(102_750_000, 4, 8, 168_750_000), create(9_000_000_000L, 4, 16, 1, 200, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, - () -> create(8_000_000_000L, 4, 32, 200, 0, 0) + () -> create(8_000_000_000L, 4, 32, 1, 200, 0, 0) ); Assert.assertEquals(new TooManyWorkersFault(200, 124), e.getFault()); // Make sure 124 actually works, and 125 doesn't. (Verify the error message above.) - Assert.assertEquals(params(25_000_000, 4, 3, 150_000_000), create(8_000_000_000L, 4, 32, 124, 0, 0)); + Assert.assertEquals(params(25_000_000, 4, 3, 150_000_000), create(8_000_000_000L, 4, 32, 1, 124, 0, 0)); final MSQException e2 = Assert.assertThrows( MSQException.class, - () -> create(8_000_000_000L, 4, 32, 125, 0, 0) + () -> create(8_000_000_000L, 4, 32, 1, 125, 0, 0) ); Assert.assertEquals(new TooManyWorkersFault(125, 124), e2.getFault()); } + @Test + public void test_fourWorkersInJvm_twoHundredWorkersInCluster_twoConcurrentStages() + { + Assert.assertEquals(params(406_500_000, 1, 74, 84_375_000), create(9_000_000_000L, 4, 1, 2, 200, 0, 0)); + Assert.assertEquals(params(305_250_000, 2, 30, 84_375_000), create(9_000_000_000L, 4, 2, 2, 200, 0, 0)); + Assert.assertEquals(params(178_687_500, 4, 10, 84_375_000), create(9_000_000_000L, 4, 4, 2, 200, 0, 0)); + Assert.assertEquals(params(52_125_000, 4, 6, 84_375_000), create(9_000_000_000L, 4, 8, 2, 200, 0, 0)); + + final MSQException e = Assert.assertThrows( + MSQException.class, + () -> create(8_000_000_000L, 4, 16, 2, 200, 0, 0) + ); + + Assert.assertEquals(new TooManyWorkersFault(200, 109), e.getFault()); + + // Make sure 109 actually works, and 110 doesn't. (Verify the error message above.) + Assert.assertEquals(params(25_000_000, 4, 3, 75_000_000), create(8_000_000_000L, 4, 16, 2, 109, 0, 0)); + + final MSQException e2 = Assert.assertThrows( + MSQException.class, + () -> create(8_000_000_000L, 4, 16, 2, 110, 0, 0) + ); + + Assert.assertEquals(new TooManyWorkersFault(110, 109), e2.getFault()); + } + @Test public void test_oneWorkerInJvm_smallWorkerCapacity() { // Supersorter max channels per processer are one less than they are usually to account for extra frames that are required while creating composing output channels - Assert.assertEquals(params(41_200_000, 1, 3, 9_600_000), create(128_000_000, 1, 1, 1, 0, 0)); - Assert.assertEquals(params(26_800_000, 1, 1, 9_600_000), create(128_000_000, 1, 2, 1, 0, 0)); + Assert.assertEquals(params(41_200_000, 1, 3, 9_600_000), create(128_000_000, 1, 1, 1, 1, 0, 0)); + Assert.assertEquals(params(26_800_000, 1, 1, 9_600_000), create(128_000_000, 1, 2, 1, 1, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, - () -> create(1_000_000_000, 1, 32, 1, 0, 0) + () -> create(1_000_000_000, 1, 32, 1, 1, 0, 0) ); Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32), e.getFault()); final MSQException e2 = Assert.assertThrows( MSQException.class, - () -> create(128_000_000, 1, 4, 1, 0, 0) + () -> create(128_000_000, 1, 4, 1, 1, 0, 0) ); Assert.assertEquals(new NotEnoughMemoryFault(580_006_666, 12_8000_000, 96_000_000, 1, 4), e2.getFault()); - final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 0, 0)) + final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 1, 0, 0)) .getFault(); Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32), fault); @@ -120,24 +166,24 @@ public void test_oneWorkerInJvm_smallWorkerCapacity() @Test public void test_fourWorkersInJvm_twoHundredWorkersInCluster_hashPartitions() { - Assert.assertEquals(params(814_000_000, 1, 150, 168_750_000), create(9_000_000_000L, 4, 1, 200, 200, 0)); - Assert.assertEquals(params(611_500_000, 2, 62, 168_750_000), create(9_000_000_000L, 4, 2, 200, 200, 0)); - Assert.assertEquals(params(358_375_000, 4, 22, 168_750_000), create(9_000_000_000L, 4, 4, 200, 200, 0)); - Assert.assertEquals(params(105_250_000, 4, 14, 168_750_000), create(9_000_000_000L, 4, 8, 200, 200, 0)); + Assert.assertEquals(params(814_000_000, 1, 150, 168_750_000), create(9_000_000_000L, 4, 1, 1, 200, 200, 0)); + Assert.assertEquals(params(611_500_000, 2, 62, 168_750_000), create(9_000_000_000L, 4, 2, 1, 200, 200, 0)); + Assert.assertEquals(params(358_375_000, 4, 22, 168_750_000), create(9_000_000_000L, 4, 4, 1, 200, 200, 0)); + Assert.assertEquals(params(105_250_000, 4, 14, 168_750_000), create(9_000_000_000L, 4, 8, 1, 200, 200, 0)); final MSQException e = Assert.assertThrows( MSQException.class, - () -> create(9_000_000_000L, 4, 16, 200, 200, 0) + () -> create(9_000_000_000L, 4, 16, 1, 200, 200, 0) ); Assert.assertEquals(new TooManyWorkersFault(200, 138), e.getFault()); // Make sure 138 actually works, and 139 doesn't. (Verify the error message above.) - Assert.assertEquals(params(26_750_000, 4, 8, 168_750_000), create(9_000_000_000L, 4, 16, 138, 138, 0)); + Assert.assertEquals(params(26_750_000, 4, 8, 168_750_000), create(9_000_000_000L, 4, 16, 1, 138, 138, 0)); final MSQException e2 = Assert.assertThrows( MSQException.class, - () -> create(9_000_000_000L, 4, 16, 139, 139, 0) + () -> create(9_000_000_000L, 4, 16, 1, 139, 139, 0) ); Assert.assertEquals(new TooManyWorkersFault(139, 138), e2.getFault()); @@ -148,7 +194,7 @@ public void test_oneWorkerInJvm_oneByteUsableMemory() { final MSQException e = Assert.assertThrows( MSQException.class, - () -> WorkerMemoryParameters.createInstance(1, 1, 1, 32, 1, 1) + () -> WorkerMemoryParameters.createInstance(1, 1, 1, 1, 32, 1, 1) ); Assert.assertEquals(new NotEnoughMemoryFault(554669334, 1, 1, 1, 1), e.getFault()); @@ -179,6 +225,7 @@ private static WorkerMemoryParameters create( final long maxMemoryInJvm, final int numWorkersInJvm, final int numProcessingThreadsInJvm, + final int maxConcurrentStages, final int numInputWorkers, final int numHashOutputPartitions, final int totalLookUpFootprint @@ -188,6 +235,7 @@ private static WorkerMemoryParameters create( maxMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm, + maxConcurrentStages, numInputWorkers, numHashOutputPartitions, totalLookUpFootprint diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java index 583c21d3407c..dfb88d17b216 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java @@ -19,6 +19,7 @@ package org.apache.druid.msq.indexing; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.Futures; import com.google.inject.Injector; import org.apache.druid.indexing.common.SegmentCacheManagerFactory; @@ -30,6 +31,7 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import org.mockito.quality.Strictness; import java.util.Collections; @@ -44,12 +46,19 @@ public void setup() Mockito.when(injectorMock.getInstance(SegmentCacheManagerFactory.class)) .thenReturn(Mockito.mock(SegmentCacheManagerFactory.class)); + final MSQWorkerTask task = + Mockito.mock(MSQWorkerTask.class, Mockito.withSettings().strictness(Strictness.STRICT_STUBS)); + Mockito.when(task.getContext()).thenReturn(ImmutableMap.of()); + indexerWorkerContext = new IndexerWorkerContext( + task, Mockito.mock(TaskToolbox.class), injectorMock, null, null, null, + null, + null, null ); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java index 5d86abd129ce..9e61c9dd7b83 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java @@ -19,12 +19,8 @@ package org.apache.druid.msq.indexing; -import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.util.concurrent.ListenableFuture; import org.apache.druid.frame.key.ClusterByPartitions; -import org.apache.druid.indexer.TaskStatus; -import org.apache.druid.indexing.common.TaskToolbox; -import org.apache.druid.indexing.common.task.NoopTestTaskReportFileWriter; -import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.java.util.common.ISE; import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.exec.Worker; @@ -32,12 +28,9 @@ import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; -import org.apache.druid.segment.IndexIO; -import org.apache.druid.segment.IndexMergerV9; -import org.apache.druid.segment.column.ColumnConfig; -import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; import org.apache.druid.server.security.AuthConfig; import org.apache.druid.server.security.AuthenticationResult; +import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.sql.calcite.util.CalciteTests; import org.junit.After; import org.junit.Assert; @@ -51,15 +44,16 @@ import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; import java.io.InputStream; -import java.util.HashMap; public class WorkerChatHandlerTest { private static final StageId TEST_STAGE = new StageId("123", 0); + private static final String DATASOURCE = "foo"; + @Mock private HttpServletRequest req; - private TaskToolbox toolbox; + private AuthorizerMapper authorizerMapper; private AutoCloseable mocks; private final TestWorker worker = new TestWorker(); @@ -67,29 +61,16 @@ public class WorkerChatHandlerTest @Before public void setUp() { - ObjectMapper mapper = new DefaultObjectMapper(); - IndexIO indexIO = new IndexIO(mapper, ColumnConfig.DEFAULT); - IndexMergerV9 indexMerger = new IndexMergerV9( - mapper, - indexIO, - OffHeapMemorySegmentWriteOutMediumFactory.instance() - ); - + authorizerMapper = CalciteTests.TEST_AUTHORIZER_MAPPER; mocks = MockitoAnnotations.openMocks(this); Mockito.when(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) .thenReturn(new AuthenticationResult("druid", "druid", null, null)); - TaskToolbox.Builder builder = new TaskToolbox.Builder(); - toolbox = builder.authorizerMapper(CalciteTests.TEST_AUTHORIZER_MAPPER) - .indexIO(indexIO) - .indexMergerV9(indexMerger) - .taskReportFileWriter(new NoopTestTaskReportFileWriter()) - .build(); } @Test public void testFetchSnapshot() { - WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker); + WorkerChatHandler chatHandler = new WorkerChatHandler(worker, authorizerMapper, DATASOURCE); Assert.assertEquals( ClusterByStatisticsSnapshot.empty(), chatHandler.httpFetchKeyStatistics(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), null, req) @@ -100,7 +81,7 @@ public void testFetchSnapshot() @Test public void testFetchSnapshot404() { - WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker); + WorkerChatHandler chatHandler = new WorkerChatHandler(worker, authorizerMapper, DATASOURCE); Assert.assertEquals( Response.Status.BAD_REQUEST.getStatusCode(), chatHandler.httpFetchKeyStatistics("123", 2, null, req) @@ -111,7 +92,7 @@ public void testFetchSnapshot404() @Test public void testFetchSnapshotWithTimeChunk() { - WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker); + WorkerChatHandler chatHandler = new WorkerChatHandler(worker, authorizerMapper, DATASOURCE); Assert.assertEquals( ClusterByStatisticsSnapshot.empty(), chatHandler.httpFetchKeyStatisticsWithSnapshot(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), 1, null, req) @@ -122,7 +103,7 @@ public void testFetchSnapshotWithTimeChunk() @Test public void testFetchSnapshotWithTimeChunk404() { - WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker); + WorkerChatHandler chatHandler = new WorkerChatHandler(worker, authorizerMapper, DATASOURCE); Assert.assertEquals( Response.Status.BAD_REQUEST.getStatusCode(), chatHandler.httpFetchKeyStatisticsWithSnapshot("123", 2, 1, null, req) @@ -133,7 +114,6 @@ public void testFetchSnapshotWithTimeChunk404() private static class TestWorker implements Worker { - @Override public String id() { @@ -141,25 +121,25 @@ public String id() } @Override - public MSQWorkerTask task() + public void run() { - return new MSQWorkerTask("controller", "ds", 1, new HashMap<>(), 0); + } @Override - public TaskStatus run() + public void stop() { - return null; + } @Override - public void stopGracefully() + public void controllerFailed() { } @Override - public void controllerFailed() + public void awaitStop() { } @@ -192,9 +172,8 @@ public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId s @Override public boolean postResultPartitionBoundaries( - ClusterByPartitions stagePartitionBoundaries, - String queryId, - int stageNumber + StageId stageId, + ClusterByPartitions stagePartitionBoundaries ) { return false; @@ -202,7 +181,7 @@ public boolean postResultPartitionBoundaries( @Nullable @Override - public InputStream readChannel(String queryId, int stageNumber, int partitionNumber, long offset) + public ListenableFuture readChannel(StageId stageId, int partitionNumber, long offset) { return null; } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStreamTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStreamTest.java new file mode 100644 index 000000000000..bc349d56c8fd --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStreamTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.collect.ImmutableList; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.List; + +public class ByteChunksInputStreamTest +{ + private final List chunks = ImmutableList.of( + new byte[]{-128, -127, -1, 0, 1, 126, 127}, + new byte[]{0}, + new byte[]{3, 4, 5} + ); + + @Test + public void test_read_fromStart() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 0)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + int c; + while ((c = in.read()) != -1) { + MatcherAssert.assertThat("InputStream#read contract", c, Matchers.greaterThanOrEqualTo(0)); + baos.write(c); + } + + Assert.assertArrayEquals(chunksSubset(0), baos.toByteArray()); + } + } + + @Test + public void test_read_fromSecondByte() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 1)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + int c; + while ((c = in.read()) != -1) { + MatcherAssert.assertThat("InputStream#read contract", c, Matchers.greaterThanOrEqualTo(0)); + baos.write(c); + } + + Assert.assertArrayEquals(chunksSubset(1), baos.toByteArray()); + } + } + + @Test + public void test_read_array1_fromStart() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 0)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final byte[] buf = new byte[2]; + + int r; + while ((r = in.read(buf, 1, 1)) != -1) { + Assert.assertEquals("InputStream#read bytes read", 1, r); + baos.write(buf, 1, 1); + } + + Assert.assertArrayEquals(chunksSubset(0), baos.toByteArray()); + } + } + + @Test + public void test_read_array1_fromSecondByte() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 1)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final byte[] buf = new byte[2]; + + int r; + while ((r = in.read(buf, 1, 1)) != -1) { + Assert.assertEquals("InputStream#read bytes read", 1, r); + baos.write(buf, 1, 1); + } + + Assert.assertArrayEquals(chunksSubset(1), baos.toByteArray()); + } + } + + @Test + public void test_read_array3_fromStart() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 0)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final byte[] buf = new byte[5]; + + int r; + while ((r = in.read(buf, 2, 3)) != -1) { + baos.write(buf, 2, r); + } + + Assert.assertArrayEquals(chunksSubset(0), baos.toByteArray()); + } + } + + @Test + public void test_read_array3_fromSecondByte() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 1)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final byte[] buf = new byte[6]; + + int r; + while ((r = in.read(buf, 2, 3)) != -1) { + baos.write(buf, 2, r); + } + + Assert.assertArrayEquals(chunksSubset(1), baos.toByteArray()); + } + } + + private byte[] chunksSubset(final int positionInFirstChunk) + { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + for (int chunk = 0, p = positionInFirstChunk; chunk < chunks.size(); chunk++, p = 0) { + baos.write(chunks.get(chunk), p, chunks.get(chunk).length - p); + } + + return baos.toByteArray(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java index b60c6c71d2e2..124b4fce2588 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java @@ -64,15 +64,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteNestedDataQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteNestedDataQueryMSQTest.java index 2d8067e900e9..5d4c0994ea06 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteNestedDataQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteNestedDataQueryMSQTest.java @@ -67,15 +67,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java index 317fe30a646d..6bbf9c6da5e4 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java @@ -136,15 +136,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java index 3008f9d43b47..2de9229b4adc 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java @@ -73,15 +73,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java index b5d8368b068f..e4b678402a8b 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java @@ -79,15 +79,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 5f0bd545b7c6..2136d96d6d11 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -333,16 +333,7 @@ public class MSQTestBase extends BaseCalciteQueryTest private SegmentCacheManager segmentCacheManager; private TestGroupByBuffers groupByBuffers; - protected final WorkerMemoryParameters workerMemoryParameters = Mockito.spy( - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 1, - 0 - ) - ); + protected final WorkerMemoryParameters workerMemoryParameters = Mockito.spy(makeTestWorkerMemoryParameters()); protected static class MSQBaseComponentSupplier extends StandardComponentSupplier { @@ -753,6 +744,19 @@ public static ObjectMapper setupObjectMapper(Injector injector) return mapper; } + public static WorkerMemoryParameters makeTestWorkerMemoryParameters() + { + return WorkerMemoryParameters.createInstance( + WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, + 2, + 10, + 1, + 2, + 1, + 0 + ); + } + private String runMultiStageQuery(String query, Map context) { final DirectStatement stmt = sqlStatementFactory.directStatement( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java index 96e26cba77e1..1bfb7177f9dd 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java @@ -87,7 +87,7 @@ public void postWorkerWarning(List MSQErrorReports) } @Override - public List getTaskList() + public List getWorkerIds() { return controller.getTaskIds(); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java index 20d31fbd4cfe..e65104302032 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java @@ -156,32 +156,33 @@ public ListenableFuture runTask(String taskId, Object taskObject) Worker worker = new WorkerImpl( task, new MSQTestWorkerContext( + task.getId(), inMemoryWorkers, controller, mapper, injector, - workerMemoryParameters - ), - workerStorageParameters + workerMemoryParameters, + workerStorageParameters + ) ); inMemoryWorkers.put(task.getId(), worker); statusMap.put(task.getId(), TaskStatus.running(task.getId())); - ListenableFuture future = executor.submit(() -> { + ListenableFuture future = executor.submit(() -> { try { - return worker.run(); + worker.run(); } catch (Exception e) { throw new RuntimeException(e); } }); - Futures.addCallback(future, new FutureCallback() + Futures.addCallback(future, new FutureCallback() { @Override - public void onSuccess(@Nullable TaskStatus result) + public void onSuccess(@Nullable Object result) { - statusMap.put(task.getId(), result); + statusMap.put(task.getId(), TaskStatus.success(task.getId())); } @Override @@ -261,7 +262,7 @@ public ListenableFuture cancelTask(String workerId) { final Worker worker = inMemoryWorkers.remove(workerId); if (worker != null) { - worker.stopGracefully(); + worker.stop(); } return Futures.immediateFuture(null); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java index 72cb246a43e1..2459a83ecfe1 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java @@ -80,11 +80,7 @@ public ListenableFuture postResultPartitionBoundaries( ) { try { - inMemoryWorkers.get(workerTaskId).postResultPartitionBoundaries( - partitionBoundaries, - stageId.getQueryId(), - stageId.getStageNumber() - ); + inMemoryWorkers.get(workerTaskId).postResultPartitionBoundaries(stageId, partitionBoundaries); return Futures.immediateFuture(null); } catch (Exception e) { @@ -122,8 +118,7 @@ public ListenableFuture fetchChannelData( ) { try (InputStream inputStream = - inMemoryWorkers.get(workerTaskId) - .readChannel(stageId.getQueryId(), stageId.getStageNumber(), partitionNumber, offset)) { + inMemoryWorkers.get(workerTaskId).readChannel(stageId, partitionNumber, offset).get()) { byte[] buffer = new byte[8 * 1024]; boolean didRead = false; int bytesRead; @@ -138,12 +133,11 @@ public ListenableFuture fetchChannelData( catch (Exception e) { throw new ISE(e, "Error reading frame file channel"); } - } @Override public void close() { - inMemoryWorkers.forEach((k, v) -> v.stopGracefully()); + inMemoryWorkers.forEach((k, v) -> v.stop()); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java index 14f6f73b24ab..4d309db7a81c 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java @@ -22,59 +22,70 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Injector; import org.apache.druid.frame.processor.Bouncer; -import org.apache.druid.indexer.report.TaskReportFileWriter; -import org.apache.druid.indexing.common.TaskToolbox; -import org.apache.druid.indexing.common.task.NoopTestTaskReportFileWriter; import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.ControllerClient; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.exec.WorkerClient; import org.apache.druid.msq.exec.WorkerContext; import org.apache.druid.msq.exec.WorkerMemoryParameters; -import org.apache.druid.msq.indexing.IndexerFrameContext; -import org.apache.druid.msq.indexing.IndexerWorkerContext; +import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.kernel.FrameContext; import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.query.groupby.GroupingEngine; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMergerV9; +import org.apache.druid.segment.SegmentWrangler; import org.apache.druid.segment.column.ColumnConfig; import org.apache.druid.segment.incremental.NoopRowIngestionMeters; +import org.apache.druid.segment.incremental.RowIngestionMeters; import org.apache.druid.segment.loading.DataSegmentPusher; -import org.apache.druid.segment.realtime.NoopChatHandlerProvider; import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; import org.apache.druid.server.DruidNode; -import org.apache.druid.server.coordination.DataSegmentAnnouncer; -import org.apache.druid.server.security.AuthTestUtils; import java.io.File; +import java.io.IOException; import java.util.Map; public class MSQTestWorkerContext implements WorkerContext { + private final String workerId; private final Controller controller; private final ObjectMapper mapper; private final Injector injector; private final Map inMemoryWorkers; private final File file = FileUtils.createTempDir(); + private final Bouncer bouncer = new Bouncer(1); private final WorkerMemoryParameters workerMemoryParameters; + private final WorkerStorageParameters workerStorageParameters; public MSQTestWorkerContext( + String workerId, Map inMemoryWorkers, Controller controller, ObjectMapper mapper, Injector injector, - WorkerMemoryParameters workerMemoryParameters + WorkerMemoryParameters workerMemoryParameters, + WorkerStorageParameters workerStorageParameters ) { + this.workerId = workerId; this.inMemoryWorkers = inMemoryWorkers; this.controller = controller; this.mapper = mapper; this.injector = injector; this.workerMemoryParameters = workerMemoryParameters; + this.workerStorageParameters = workerStorageParameters; + } + + @Override + public String queryId() + { + return controller.queryId(); } @Override @@ -96,7 +107,13 @@ public void registerWorker(Worker worker, Closer closer) } @Override - public ControllerClient makeControllerClient(String controllerId) + public String workerId() + { + return workerId; + } + + @Override + public ControllerClient makeControllerClient() { return new MSQTestControllerClient(controller); } @@ -114,42 +131,9 @@ public File tempDir() } @Override - public FrameContext frameContext(QueryDefinition queryDef, int stageNumber) + public FrameContext frameContext(QueryDefinition queryDef, int stageNumber, OutputChannelMode outputChannelMode) { - IndexIO indexIO = new IndexIO(mapper, ColumnConfig.DEFAULT); - IndexMergerV9 indexMerger = new IndexMergerV9( - mapper, - indexIO, - OffHeapMemorySegmentWriteOutMediumFactory.instance(), - true - ); - final TaskReportFileWriter reportFileWriter = new NoopTestTaskReportFileWriter(); - - return new IndexerFrameContext( - new IndexerWorkerContext( - new TaskToolbox.Builder() - .segmentPusher(injector.getInstance(DataSegmentPusher.class)) - .segmentAnnouncer(injector.getInstance(DataSegmentAnnouncer.class)) - .jsonMapper(mapper) - .taskWorkDir(tempDir()) - .indexIO(indexIO) - .indexMergerV9(indexMerger) - .taskReportFileWriter(reportFileWriter) - .authorizerMapper(AuthTestUtils.TEST_AUTHORIZER_MAPPER) - .chatHandlerProvider(new NoopChatHandlerProvider()) - .rowIngestionMetersFactory(NoopRowIngestionMeters::new) - .build(), - injector, - indexIO, - null, - null, - null - ), - indexIO, - injector.getInstance(DataSegmentProvider.class), - injector.getInstance(DataServerQueryHandlerFactory.class), - workerMemoryParameters - ); + return new FrameContextImpl(new File(tempDir(), queryDef.getStageDefinition(stageNumber).getId().toString())); } @Override @@ -165,9 +149,9 @@ public DruidNode selfNode() } @Override - public Bouncer processorBouncer() + public int maxConcurrentStages() { - return injector.getInstance(Bouncer.class); + return 1; } @Override @@ -175,4 +159,109 @@ public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() { return injector.getInstance(DataServerQueryHandlerFactory.class); } + + class FrameContextImpl implements FrameContext + { + private final File tempDir; + + public FrameContextImpl(File tempDir) + { + this.tempDir = tempDir; + } + + @Override + public SegmentWrangler segmentWrangler() + { + return injector.getInstance(SegmentWrangler.class); + } + + @Override + public GroupingEngine groupingEngine() + { + return injector.getInstance(GroupingEngine.class); + } + + @Override + public RowIngestionMeters rowIngestionMeters() + { + return new NoopRowIngestionMeters(); + } + + @Override + public DataSegmentProvider dataSegmentProvider() + { + return injector.getInstance(DataSegmentProvider.class); + } + + @Override + public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() + { + return injector.getInstance(DataServerQueryHandlerFactory.class); + } + + @Override + public File tempDir() + { + return new File(tempDir, "tmp"); + } + + @Override + public ObjectMapper jsonMapper() + { + return mapper; + } + + @Override + public IndexIO indexIO() + { + return new IndexIO(mapper, ColumnConfig.DEFAULT); + } + + @Override + public File persistDir() + { + return new File(tempDir, "persist"); + } + + @Override + public DataSegmentPusher segmentPusher() + { + return injector.getInstance(DataSegmentPusher.class); + } + + @Override + public IndexMergerV9 indexMerger() + { + return new IndexMergerV9( + mapper, + indexIO(), + OffHeapMemorySegmentWriteOutMediumFactory.instance(), + true + ); + } + + @Override + public Bouncer processorBouncer() + { + return bouncer; + } + + @Override + public WorkerMemoryParameters memoryParameters() + { + return workerMemoryParameters; + } + + @Override + public WorkerStorageParameters storageParameters() + { + return workerStorageParameters; + } + + @Override + public void close() throws IOException + { + + } + } } diff --git a/processing/src/main/java/org/apache/druid/frame/channel/ReadableFileFrameChannel.java b/processing/src/main/java/org/apache/druid/frame/channel/ReadableFileFrameChannel.java index 963a001ad6db..7da6550ccca7 100644 --- a/processing/src/main/java/org/apache/druid/frame/channel/ReadableFileFrameChannel.java +++ b/processing/src/main/java/org/apache/druid/frame/channel/ReadableFileFrameChannel.java @@ -104,6 +104,14 @@ public void close() } } + /** + * Returns whether this channel represents the entire underlying {@link FrameFile}. + */ + public boolean isEntireFile() + { + return currentFrame == 0 && endFrame == frameFile.numFrames(); + } + /** * Returns a new reference to the {@link FrameFile} that this channel is reading from. Callers should close this * reference when done reading. From b6016a2ca361bc4ec8e85ead9f5f93ac9e6d7653 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Wed, 24 Jul 2024 11:32:35 -0700 Subject: [PATCH 02/13] Updates for static checks, test coverage. --- .../msq/exec/WorkerMemoryParameters.java | 3 +- .../msq/indexing/IndexerFrameContext.java | 3 +- .../apache/druid/msq/kernel/WorkOrder.java | 2 +- .../shuffle/output/ByteChunksInputStream.java | 9 +- .../output/ChannelStageOutputReader.java | 7 +- .../shuffle/output/NilStageOutputReader.java | 2 +- .../druid/msq/test/MSQTestWorkerContext.java | 3 +- .../ReadableFileFrameChannelTest.java | 104 ++++++++++++++++++ 8 files changed, 118 insertions(+), 15 deletions(-) create mode 100644 processing/src/test/java/org/apache/druid/frame/processor/ReadableFileFrameChannelTest.java diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java index 8d3f15c09c79..14d30f666531 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java @@ -524,7 +524,8 @@ private static long memoryPerWorker( } /** - * Compute the memory allocated to each processing bundle. Any computation changes done to this method should also be done in its corresponding method {@link WorkerMemoryParameters#estimateUsableMemory(int, int, long)} + * Compute the memory allocated to each processing bundle. Any computation changes done to this method should also be + * done in its corresponding method {@link WorkerMemoryParameters#estimateUsableMemory} */ private static long memoryPerBundle( final long usableMemoryInJvm, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java index 240400aa6d5e..fb6e4a0079f1 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java @@ -36,7 +36,6 @@ import org.apache.druid.segment.loading.DataSegmentPusher; import java.io.File; -import java.io.IOException; public class IndexerFrameContext implements FrameContext { @@ -154,7 +153,7 @@ public WorkerStorageParameters storageParameters() } @Override - public void close() throws IOException + public void close() { // Nothing to close. } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java index 201a1783c05f..0c8578702103 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java @@ -109,7 +109,7 @@ ExtraInfoHolder getExtraInfoHolder() /** * Worker IDs for this query, if known in advance (at the time the work order is created). May be null, in which - * case workers use {@link ControllerClient#getTaskList()} to find worker IDs. + * case workers use {@link ControllerClient#getWorkerIds()} to find worker IDs. */ @Nullable @JsonProperty("workers") diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java index 4767d818dea4..d475bfd03055 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java @@ -19,7 +19,6 @@ package org.apache.druid.msq.shuffle.output; -import java.io.IOException; import java.io.InputStream; import java.util.List; @@ -45,7 +44,7 @@ public ByteChunksInputStream(final List chunks, final int positionWithin } @Override - public int read() throws IOException + public int read() { if (chunkNum >= chunks.size()) { return -1; @@ -63,13 +62,13 @@ public int read() throws IOException } @Override - public int read(byte[] b) throws IOException + public int read(byte[] b) { return read(b, 0, b.length); } @Override - public int read(byte[] b, int off, int len) throws IOException + public int read(byte[] b, int off, int len) { if (len == 0) { return 0; @@ -95,7 +94,7 @@ public int read(byte[] b, int off, int len) throws IOException } @Override - public void close() throws IOException + public void close() { chunkNum = chunks.size(); positionWithinChunk = 0; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java index eba835cd544b..b9cc0caae47c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java @@ -38,7 +38,8 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; -import java.util.LinkedList; +import java.util.ArrayDeque; +import java.util.Deque; /** * Reader for {@link ReadableFrameChannel}. @@ -64,7 +65,7 @@ enum State /** * Pair of chunk size + chunk InputStream. */ - private final LinkedList chunks = new LinkedList<>(); + private final Deque chunks = new ArrayDeque<>(); /** * State of this reader. @@ -190,7 +191,7 @@ public synchronized ReadableFrameChannel readLocally() } @Override - public synchronized void close() throws IOException + public synchronized void close() { // Call channel.close() unless readLocally() has been called. In that case, we expect the caller to close it. if (state != State.LOCAL) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java index 3841cc7d7aee..e12341f99f83 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java @@ -70,7 +70,7 @@ public ReadableFrameChannel readLocally() } @Override - public void close() throws IOException + public void close() { // Nothing to do. } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java index 4d309db7a81c..082429a9d7b1 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java @@ -48,7 +48,6 @@ import org.apache.druid.server.DruidNode; import java.io.File; -import java.io.IOException; import java.util.Map; public class MSQTestWorkerContext implements WorkerContext @@ -259,7 +258,7 @@ public WorkerStorageParameters storageParameters() } @Override - public void close() throws IOException + public void close() { } diff --git a/processing/src/test/java/org/apache/druid/frame/processor/ReadableFileFrameChannelTest.java b/processing/src/test/java/org/apache/druid/frame/processor/ReadableFileFrameChannelTest.java new file mode 100644 index 000000000000..9025d1820864 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/frame/processor/ReadableFileFrameChannelTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.frame.processor; + +import org.apache.druid.frame.FrameType; +import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.frame.file.FrameFile; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.testutil.FrameSequenceBuilder; +import org.apache.druid.frame.testutil.FrameTestUtil; +import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.segment.QueryableIndexStorageAdapter; +import org.apache.druid.segment.StorageAdapter; +import org.apache.druid.segment.TestIndex; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; +import java.util.List; + +public class ReadableFileFrameChannelTest extends InitializedNullHandlingTest +{ + private static final int ROWS_PER_FRAME = 20; + + private List> allRows; + private FrameReader frameReader; + private FrameFile frameFile; + + @Rule + public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Before + public void setUp() throws IOException + { + final StorageAdapter adapter = new QueryableIndexStorageAdapter(TestIndex.getNoRollupMMappedTestIndex()); + final File file = FrameTestUtil.writeFrameFile( + FrameSequenceBuilder.fromAdapter(adapter) + .frameType(FrameType.ROW_BASED) + .maxRowsPerFrame(ROWS_PER_FRAME) + .frames(), + temporaryFolder.newFile() + ); + allRows = FrameTestUtil.readRowsFromAdapter(adapter, adapter.getRowSignature(), false).toList(); + frameReader = FrameReader.create(adapter.getRowSignature()); + frameFile = FrameFile.open(file, null, FrameFile.Flag.DELETE_ON_CLOSE); + } + + @After + public void tearDown() throws Exception + { + frameFile.close(); + } + + @Test + public void test_fullFile() + { + final ReadableFileFrameChannel channel = new ReadableFileFrameChannel(frameFile); + Assert.assertTrue(channel.isEntireFile()); + + FrameTestUtil.assertRowsEqual( + Sequences.simple(allRows), + FrameTestUtil.readRowsFromFrameChannel(channel, frameReader) + ); + + Assert.assertFalse(channel.isEntireFile()); + } + + @Test + public void test_partialFile() + { + final ReadableFileFrameChannel channel = new ReadableFileFrameChannel(frameFile, 1, 2); + Assert.assertFalse(channel.isEntireFile()); + + FrameTestUtil.assertRowsEqual( + Sequences.simple(allRows).skip(ROWS_PER_FRAME).limit(ROWS_PER_FRAME), + FrameTestUtil.readRowsFromFrameChannel(channel, frameReader) + ); + + Assert.assertFalse(channel.isEntireFile()); + } +} From b1445fff796d3faddd743aee0b49f5f5defc615c Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Wed, 24 Jul 2024 12:07:59 -0700 Subject: [PATCH 03/13] Fixes. --- .../src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java | 1 - .../src/main/java/org/apache/druid/msq/exec/Worker.java | 2 +- .../main/java/org/apache/druid/msq/rpc/ControllerResource.java | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java index 45689652a646..19c866e164c3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java @@ -488,7 +488,6 @@ private void writeDurableStorageSuccessFile() throw new ISE( e, "Unable to create success file at location[%s]", - DurableStorageUtils.SUCCESS_MARKER_FILENAME, durableStorageOutputChannelFactory.getSuccessFilePath() ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java index b068796cec70..9277bb5beaa1 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java @@ -107,7 +107,7 @@ public interface Worker * * @throws IOException when the worker output is found but there is an error while reading it. */ - ListenableFuture readChannel(StageId stageId, int partitionNumber, long offset) throws IOException; + ListenableFuture readChannel(StageId stageId, int partitionNumber, long offset); /** * Returns a snapshot of counters. diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java index c6ddb5cd582b..18049745d71f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java @@ -93,7 +93,6 @@ public Response httpPostPartialKeyStatistics( @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) public Response httpPostDoneReadingInput( - @PathParam("queryId") final String queryId, @PathParam("stageNumber") final int stageNumber, @PathParam("workerNumber") final int workerNumber, @Context final HttpServletRequest req From b22b30ed944c784b6deb64e14f8a4cfec4b2b044 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Wed, 24 Jul 2024 13:26:00 -0700 Subject: [PATCH 04/13] Remove exception. --- .../src/main/java/org/apache/druid/msq/rpc/WorkerResource.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java index 88dfddaeb7ce..00a99c524493 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java @@ -54,7 +54,6 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.StreamingOutput; -import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -97,7 +96,7 @@ public Response httpGetChannelData( @PathParam("partitionNumber") final int partitionNumber, @QueryParam("offset") final long offset, @Context final HttpServletRequest req - ) throws IOException + ) { MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); From 4d8d1a0080937bc0aa452b13e3556023673f0138 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Fri, 26 Jul 2024 10:44:18 -0700 Subject: [PATCH 05/13] Changes from review. --- .../org/apache/druid/msq/exec/Controller.java | 4 +- .../druid/msq/exec/ControllerClient.java | 13 +++++-- .../apache/druid/msq/exec/ControllerImpl.java | 10 ++--- .../org/apache/druid/msq/exec/WorkerImpl.java | 2 +- .../druid/msq/rpc/ControllerResource.java | 2 +- .../shuffle/output/ByteChunksInputStream.java | 14 +++++++ .../output/ChannelStageOutputReader.java | 5 +++ .../output/FutureReadableFrameChannel.java | 39 +++++++++---------- .../msq/test/MSQTestControllerClient.java | 2 +- 9 files changed, 57 insertions(+), 34 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java index f04286dd7c42..d2370b057935 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java @@ -117,9 +117,9 @@ void resultsComplete( ); /** - * Returns the current list of task ids, ordered by worker number. The Nth task has worker number N. + * Returns the current list of worker IDs, ordered by worker number. The Nth worker has worker number N. */ - List getTaskIds(); + List getWorkerIds(); @Nullable TaskReport.ReportMap liveReports(); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java index cbc3544c93ae..428ce59cd8fa 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java @@ -23,6 +23,7 @@ import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; import javax.annotation.Nullable; @@ -81,12 +82,9 @@ void postResultsComplete( /** * Client side method to inform the controller that the error has occured in the given worker. - * - * @param queryId query ID, if this error is associated with a specific query - * @param errorWrapper error details */ void postWorkerError( - @Nullable String queryId, + String workerId, MSQErrorReport errorWrapper ) throws IOException; @@ -95,6 +93,13 @@ void postWorkerError( */ void postWorkerWarning(List MSQErrorReports) throws IOException; + /** + * Client side method for retrieving the list of worker IDs from the controller. These IDs can be passed to + * {@link WorkerClient} methods to communicate with other workers. Not necessary when the {@link WorkOrder} has + * {@link WorkOrder#getWorkerIds()} set. + * + * @see Controller#getWorkerIds() for the controller side + */ List getWorkerIds() throws IOException; /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index a30e96860875..bc5dd5d2ac3a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -1172,7 +1172,7 @@ private List generateSegmentIdsWithShardSpecsForReplace( } @Override - public List getTaskIds() + public List getWorkerIds() { if (workerManager == null) { return Collections.emptyList(); @@ -1261,7 +1261,7 @@ private void contactWorkersForStage( { // Sorted copy of target worker numbers to ensure consistent iteration order. final List workersCopy = Ordering.natural().sortedCopy(workers); - final List workerIds = getTaskIds(); + final List workerIds = getWorkerIds(); final List> workerFutures = new ArrayList<>(workersCopy.size()); try { @@ -1486,7 +1486,7 @@ private List findIntervalsToDrop(final Set publishedSegme private CounterSnapshotsTree getCountersFromAllTasks() { final CounterSnapshotsTree retVal = new CounterSnapshotsTree(); - final List taskList = getTaskIds(); + final List taskList = getWorkerIds(); final List> futures = new ArrayList<>(); @@ -1506,7 +1506,7 @@ private CounterSnapshotsTree getCountersFromAllTasks() private void postFinishToAllTasks() { - final List taskList = getTaskIds(); + final List taskList = getWorkerIds(); final List> futures = new ArrayList<>(); @@ -2930,7 +2930,7 @@ private void startQueryResultsReader() } final StageId finalStageId = queryKernel.getStageId(queryDef.getFinalStageDefinition().getStageNumber()); - final List taskIds = getTaskIds(); + final List taskIds = getWorkerIds(); final InputChannelFactory inputChannelFactory; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java index 91e003361628..9cc9e3a32a4f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java @@ -193,7 +193,7 @@ public void run() log.warn("%s", logMessage); if (controllerAlive) { - controllerClient.postWorkerError(context.queryId(), errorReport); + controllerClient.postWorkerError(context.workerId(), errorReport); } if (t != null) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java index 18049745d71f..cc570ec992ad 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java @@ -195,7 +195,7 @@ public Response httpPostResultsComplete( public Response httpGetTaskList(@Context final HttpServletRequest req) { MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); - return Response.ok(new MSQTaskList(controller.getTaskIds())).build(); + return Response.ok(new MSQTaskList(controller.getWorkerIds())).build(); } /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java index d475bfd03055..f623e58f65b3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java @@ -19,6 +19,8 @@ package org.apache.druid.msq.shuffle.output; +import org.apache.druid.error.DruidException; + import java.io.InputStream; import java.util.List; @@ -41,6 +43,8 @@ public ByteChunksInputStream(final List chunks, final int positionWithin { this.chunks = chunks; this.positionWithinChunk = positionWithinFirstChunk; + this.chunkNum = -1; + advanceChunk(); } @Override @@ -99,4 +103,14 @@ public void close() chunkNum = chunks.size(); positionWithinChunk = 0; } + + private void advanceChunk() + { + chunkNum++; + + // Verify nonempty + if (chunkNum < chunks.size() && chunks.get(chunkNum).length == 0) { + throw DruidException.defensive("Empty chunk not allowed"); + } + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java index b9cc0caae47c..d81de9438358 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java @@ -23,6 +23,7 @@ import com.google.common.primitives.Ints; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import it.unimi.dsi.fastutil.bytes.ByteArrays; import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.error.DruidException; @@ -70,21 +71,25 @@ enum State /** * State of this reader. */ + @GuardedBy("this") private State state = State.INIT; /** * Position within the overall stream. */ + @GuardedBy("this") private long cursor; /** * Offset of the first chunk in {@link #chunks} which corresponds to {@link #cursor}. */ + @GuardedBy("this") private int positionWithinFirstChunk; /** * Whether {@link FrameFileWriter#close()} is called on {@link #writer}. */ + @GuardedBy("this") private boolean didCloseWriter; public ChannelStageOutputReader(final ReadableFrameChannel channel) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java index 37500ae5eafd..02f6efdc12bd 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java @@ -23,11 +23,14 @@ import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.frame.Frame; import org.apache.druid.frame.channel.ReadableFrameChannel; -import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.logger.Logger; import java.util.NoSuchElementException; +/** + * Channel that wraps a {@link ListenableFuture} of a {@link ReadableFrameChannel}, but acts like a regular (non-future) + * {@link ReadableFrameChannel}. + */ public class FutureReadableFrameChannel implements ReadableFrameChannel { private static final Logger log = new Logger(FutureReadableFrameChannel.class); @@ -87,27 +90,23 @@ public void close() channel.close(); } else { channelFuture.cancel(true); - channelFuture.addListener( - () -> { - final ReadableFrameChannel channel; - try { - channel = FutureUtils.getUncheckedImmediately(channelFuture); - } - catch (Throwable ignored) { - // Some error happened while creating the channel. Suppress it. - return; - } + // In case of a race where channelFuture resolved between populateChannel() and here, the cancel call above would + // have no effect. Guard against this case by checking if the channelFuture has resolved, and if so, close the + // channel here. + try { + final ReadableFrameChannel channel = FutureUtils.getUncheckedImmediately(channelFuture); - try { - channel.close(); - } - catch (Throwable t) { - log.noStackTrace().warn(t, "Failed to close channel"); - } - }, - Execs.directExecutor() - ); + try { + channel.close(); + } + catch (Throwable t) { + log.noStackTrace().warn(t, "Failed to close channel"); + } + } + catch (Throwable ignored) { + // Some error happened while creating the channel. Suppress it. + } } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java index 1bfb7177f9dd..4c7ca61be023 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java @@ -89,7 +89,7 @@ public void postWorkerWarning(List MSQErrorReports) @Override public List getWorkerIds() { - return controller.getTaskIds(); + return controller.getWorkerIds(); } @Override From aa3d3bf30193fbb7b82a652fbcc1b02d09b312a6 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Fri, 26 Jul 2024 12:41:27 -0700 Subject: [PATCH 06/13] Address static check. --- .../msq/shuffle/output/FutureReadableFrameChannel.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java index 02f6efdc12bd..8dcb8786713b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java @@ -95,17 +95,17 @@ public void close() // have no effect. Guard against this case by checking if the channelFuture has resolved, and if so, close the // channel here. try { - final ReadableFrameChannel channel = FutureUtils.getUncheckedImmediately(channelFuture); + final ReadableFrameChannel theChannel = FutureUtils.getUncheckedImmediately(channelFuture); try { - channel.close(); + theChannel.close(); } catch (Throwable t) { log.noStackTrace().warn(t, "Failed to close channel"); } } catch (Throwable ignored) { - // Some error happened while creating the channel. Suppress it. + // Suppress. } } } From 11ea1aa67a04b66a2c1f69f69d505dddd26789ab Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Mon, 29 Jul 2024 11:43:56 -0700 Subject: [PATCH 07/13] Changes from review. --- .../msq/exec/ControllerMemoryParameters.java | 3 +- .../druid/msq/exec/OutputChannelMode.java | 9 +++-- .../apache/druid/msq/exec/RunWorkOrder.java | 7 ++++ .../msq/exec/WorkerMemoryParameters.java | 11 ++++-- .../indexing/error/NotEnoughMemoryFault.java | 38 +++++++++++++------ .../msq/kernel/worker/WorkerStagePhase.java | 2 + .../output/ChannelStageOutputReader.java | 1 + .../msq/exec/WorkerMemoryParametersTest.java | 16 ++++---- .../msq/indexing/error/MSQFaultSerdeTest.java | 2 +- .../msq/test/MSQTestControllerContext.java | 2 +- .../druid/msq/test/MSQTestWorkerClient.java | 2 +- 11 files changed, 64 insertions(+), 29 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java index 8e6fc72b6aa7..2ab016e10e48 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java @@ -91,7 +91,8 @@ public static ControllerMemoryParameters createProductionInstance( memoryIntrospector.totalMemoryInJvm(), usableMemoryInJvm, numControllersInJvm, - memoryIntrospector.numProcessorsInJvm() + memoryIntrospector.numProcessorsInJvm(), + 0 ) ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java index 7e7fc3d3d6f3..f42d558a76ce 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java @@ -32,9 +32,12 @@ public enum OutputChannelMode { /** - * In-memory output channels. Stage shuffle data does not hit disk. This mode requires a consumer stage to run - * at the same time as its corresponding producer stage. See {@link ControllerQueryKernelUtils#computeStageGroups} for the - * logic that determines when we can use in-memory channels. + * In-memory output channels. Stage shuffle data does not hit disk. In-memory channels do not fully buffer stage + * output. They use a blocking queue; see {@link RunWorkOrder#makeStageOutputChannelFactory()}. + * + * Because stage output is not fully buffered, this mode requires a consumer stage to run at the same time as its + * corresponding producer stage. See {@link ControllerQueryKernelUtils#computeStageGroups} for the logic that + * determines when we can use in-memory channels. */ MEMORY("memory"), diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java index 19c866e164c3..0173979efeed 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java @@ -93,6 +93,7 @@ import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; import org.apache.druid.utils.CloseableUtils; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import javax.annotation.Nullable; import java.io.File; @@ -123,11 +124,17 @@ public class RunWorkOrder private final ByteTracker intermediateSuperSorterLocalStorageTracker; private final AtomicBoolean started = new AtomicBoolean(); + @MonotonicNonNull private InputSliceReader inputSliceReader; + @MonotonicNonNull private OutputChannelFactory workOutputChannelFactory; + @MonotonicNonNull private OutputChannelFactory shuffleOutputChannelFactory; + @MonotonicNonNull private ResultAndChannels workResultAndOutputChannels; + @MonotonicNonNull private SettableFuture stagePartitionBoundariesFuture; + @MonotonicNonNull private ListenableFuture stageOutputChannelsFuture; public RunWorkOrder( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java index 14d30f666531..aeaae030e613 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java @@ -277,7 +277,8 @@ public static WorkerMemoryParameters createInstance( maxMemoryInJvm, usableMemoryInJvm, numWorkersInJvm, - numProcessingThreadsInJvm + numProcessingThreadsInJvm, + maxConcurrentStages ) ); } @@ -300,7 +301,8 @@ public static WorkerMemoryParameters createInstance( maxMemoryInJvm, usableMemoryInJvm, numWorkersInJvm, - numProcessingThreadsInJvm + numProcessingThreadsInJvm, + maxConcurrentStages ) ); } @@ -336,7 +338,8 @@ public static WorkerMemoryParameters createInstance( maxMemoryInJvm, usableMemoryInJvm, numWorkersInJvm, - numProcessingThreadsInJvm + numProcessingThreadsInJvm, + maxConcurrentStages ) ); } @@ -533,6 +536,8 @@ private static long memoryPerBundle( final int numProcessingThreadsInJvm ) { + // One bundle per worker + one per processor. The worker bundles are used for sorting (SuperSorter) and the + // processing bundles are used for reading input and doing per-partition processing. final int bundleCount = numWorkersInJvm + numProcessingThreadsInJvm; // Need to subtract memoryForWorkers off the top of usableMemoryInJvm, since this is reserved for diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java index 5c80f065eef3..6f4b36da1eec 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.indexing.error; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; @@ -35,6 +36,7 @@ public class NotEnoughMemoryFault extends BaseMSQFault private final long usableMemory; private final int serverWorkers; private final int serverThreads; + private final int maxConcurrentStages; @JsonCreator public NotEnoughMemoryFault( @@ -42,19 +44,23 @@ public NotEnoughMemoryFault( @JsonProperty("serverMemory") final long serverMemory, @JsonProperty("usableMemory") final long usableMemory, @JsonProperty("serverWorkers") final int serverWorkers, - @JsonProperty("serverThreads") final int serverThreads + @JsonProperty("serverThreads") final int serverThreads, + @JsonProperty("maxConcurrentStages") final int maxConcurrentStages ) { super( CODE, "Not enough memory. Required at least %,d bytes. (total = %,d bytes; usable = %,d bytes; " - + "worker capacity = %,d; processing threads = %,d). Increase JVM memory with the -Xmx option" - + (serverWorkers > 1 ? " or reduce worker capacity on this server" : ""), + + "worker capacity = %,d; processing threads = %,d; concurrent stages = %,d). " + + "Increase JVM memory with the -Xmx option" + + (serverWorkers > 1 ? ", or reduce worker capacity on this server" : "") + + (maxConcurrentStages > 1 ? ", or reduce maxConcurrentStages for this query" : ""), suggestedServerMemory, serverMemory, usableMemory, serverWorkers, - serverThreads + serverThreads, + maxConcurrentStages ); this.suggestedServerMemory = suggestedServerMemory; @@ -62,6 +68,7 @@ public NotEnoughMemoryFault( this.usableMemory = usableMemory; this.serverWorkers = serverWorkers; this.serverThreads = serverThreads; + this.maxConcurrentStages = maxConcurrentStages; } @JsonProperty @@ -94,6 +101,13 @@ public int getServerThreads() return serverThreads; } + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_DEFAULT) + public int getMaxConcurrentStages() + { + return maxConcurrentStages; + } + @Override public boolean equals(Object o) { @@ -107,12 +121,12 @@ public boolean equals(Object o) return false; } NotEnoughMemoryFault that = (NotEnoughMemoryFault) o; - return - suggestedServerMemory == that.suggestedServerMemory - && serverMemory == that.serverMemory - && usableMemory == that.usableMemory - && serverWorkers == that.serverWorkers - && serverThreads == that.serverThreads; + return suggestedServerMemory == that.suggestedServerMemory + && serverMemory == that.serverMemory + && usableMemory == that.usableMemory + && serverWorkers == that.serverWorkers + && serverThreads == that.serverThreads + && maxConcurrentStages == that.maxConcurrentStages; } @Override @@ -124,7 +138,8 @@ public int hashCode() serverMemory, usableMemory, serverWorkers, - serverThreads + serverThreads, + maxConcurrentStages ); } @@ -137,6 +152,7 @@ public String toString() " bytes, usableMemory=" + usableMemory + " bytes, serverWorkers=" + serverWorkers + ", serverThreads=" + serverThreads + + ", maxConcurrentStages=" + maxConcurrentStages + '}'; } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java index 7e3ac5c7cac4..4e59e7d17a89 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java @@ -71,6 +71,8 @@ public boolean canTransitionFrom(final WorkerStagePhase priorPhase) @Override public boolean canTransitionFrom(final WorkerStagePhase priorPhase) { + // Stages can transition to FINISHED even if they haven't generated all output yet. For example, this is + // possible if the downstream stage is applying a limit. return priorPhase.compareTo(FINISHED) < 0; } }, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java index d81de9438358..8a534f8f1497 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java @@ -200,6 +200,7 @@ public synchronized void close() { // Call channel.close() unless readLocally() has been called. In that case, we expect the caller to close it. if (state != State.LOCAL) { + state = State.CLOSED; channel.close(); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java index d4dd4b47e688..1ead2a181fd9 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java @@ -42,12 +42,12 @@ public void test_oneWorkerInJvm_alone() MSQException.class, () -> create(1_000_000_000, 1, 32, 1, 1, 0, 0) ); - Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32), e.getFault()); + Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32, 1), e.getFault()); final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 1, 0, 0)) .getFault(); - Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32), fault); + Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32, 1), fault); } @Test @@ -63,12 +63,12 @@ public void test_oneWorkerInJvm_alone_twoConcurrentStages() () -> create(1_000_000_000, 1, 12, 2, 1, 0, 0) ); - Assert.assertEquals(new NotEnoughMemoryFault(1_736_034_666, 1_000_000_000, 750_000_000, 1, 12), e.getFault()); + Assert.assertEquals(new NotEnoughMemoryFault(1_736_034_666, 1_000_000_000, 750_000_000, 1, 12, 2), e.getFault()); final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 2, 1, 0, 0)) .getFault(); - Assert.assertEquals(new NotEnoughMemoryFault(4_048_090_666L, 1_000_000_000, 750_000_000, 2, 32), fault); + Assert.assertEquals(new NotEnoughMemoryFault(4_048_090_666L, 1_000_000_000, 750_000_000, 2, 32, 2), fault); } @Test @@ -149,18 +149,18 @@ public void test_oneWorkerInJvm_smallWorkerCapacity() MSQException.class, () -> create(1_000_000_000, 1, 32, 1, 1, 0, 0) ); - Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32), e.getFault()); + Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32, 1), e.getFault()); final MSQException e2 = Assert.assertThrows( MSQException.class, () -> create(128_000_000, 1, 4, 1, 1, 0, 0) ); - Assert.assertEquals(new NotEnoughMemoryFault(580_006_666, 12_8000_000, 96_000_000, 1, 4), e2.getFault()); + Assert.assertEquals(new NotEnoughMemoryFault(580_006_666, 12_8000_000, 96_000_000, 1, 4, 1), e2.getFault()); final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 1, 0, 0)) .getFault(); - Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32), fault); + Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32, 1), fault); } @Test @@ -197,7 +197,7 @@ public void test_oneWorkerInJvm_oneByteUsableMemory() () -> WorkerMemoryParameters.createInstance(1, 1, 1, 1, 32, 1, 1) ); - Assert.assertEquals(new NotEnoughMemoryFault(554669334, 1, 1, 1, 1), e.getFault()); + Assert.assertEquals(new NotEnoughMemoryFault(554669334, 1, 1, 1, 1, 1), e.getFault()); } @Test diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java index c33faa40c14e..cffc0f78a497 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java @@ -74,7 +74,7 @@ public void testFaultSerde() throws IOException )); assertFaultSerde(new InvalidNullByteFault("the source", 1, "the column", "the value", 2)); assertFaultSerde(new InvalidFieldFault("the source", "the column", 1, "the error", "the log msg")); - assertFaultSerde(new NotEnoughMemoryFault(1000, 1000, 900, 1, 2)); + assertFaultSerde(new NotEnoughMemoryFault(1000, 1000, 900, 1, 2, 2)); assertFaultSerde(QueryNotSupportedFault.INSTANCE); assertFaultSerde(new QueryRuntimeFault("new error", "base error")); assertFaultSerde(new QueryRuntimeFault("new error", null)); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java index e65104302032..fdbb00ae1f5e 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java @@ -262,7 +262,7 @@ public ListenableFuture cancelTask(String workerId) { final Worker worker = inMemoryWorkers.remove(workerId); if (worker != null) { - worker.stop(); + worker.awaitStop(); } return Futures.immediateFuture(null); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java index 2459a83ecfe1..619306abb38c 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java @@ -138,6 +138,6 @@ public ListenableFuture fetchChannelData( @Override public void close() { - inMemoryWorkers.forEach((k, v) -> v.stop()); + inMemoryWorkers.forEach((k, v) -> v.awaitStop()); } } From 61c337b57bb866a91c9657503ef10fa2267d382b Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Mon, 29 Jul 2024 13:01:32 -0700 Subject: [PATCH 08/13] Improvements to docs and method names. --- .../org/apache/druid/msq/exec/Worker.java | 11 +++--- .../org/apache/druid/msq/exec/WorkerImpl.java | 2 +- .../apache/druid/msq/rpc/WorkerResource.java | 2 +- .../output/ChannelStageOutputReader.java | 31 ++++++++++++++- .../shuffle/output/FileStageOutputReader.java | 24 +++++++++++- .../shuffle/output/NilStageOutputReader.java | 2 +- .../msq/shuffle/output/StageOutputHolder.java | 23 +++++++++++ .../msq/shuffle/output/StageOutputReader.java | 38 +++++++++++++------ .../msq/indexing/WorkerChatHandlerTest.java | 2 +- .../druid/msq/test/MSQTestWorkerClient.java | 2 +- 10 files changed, 113 insertions(+), 24 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java index 9277bb5beaa1..a90068060d81 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java @@ -26,7 +26,6 @@ import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; -import java.io.IOException; import java.io.InputStream; /** @@ -100,14 +99,16 @@ public interface Worker * If the channel is finished, an empty {@link InputStream} is returned. * * With {@link OutputChannelMode#MEMORY}, once this method is called with a certain offset, workers are free to - * delete data prior to that offset. (It will not be re-requested.) + * delete data prior to that offset. (Already-requested offsets will not be re-requested, because + * {@link OutputChannelMode#MEMORY} requires a single reader.) In this mode, if an already-requested offset is + * re-requested for some reason, an error future is returned. * - * Returns future that resolves to null if worker output for a particular queryId, stageNumber, and + * The returned future resolves to null if stage output for a particular queryId, stageNumber, and * partitionNumber is not found. * - * @throws IOException when the worker output is found but there is an error while reading it. + * Throws an exception when worker output is found, but there is an error while reading it. */ - ListenableFuture readChannel(StageId stageId, int partitionNumber, long offset); + ListenableFuture readStageOutput(StageId stageId, int partitionNumber, long offset); /** * Returns a snapshot of counters. diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java index 9cc9e3a32a4f..b7c1f3b84de3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java @@ -547,7 +547,7 @@ public void controllerFailed() } @Override - public ListenableFuture readChannel( + public ListenableFuture readStageOutput( final StageId stageId, final int partitionNumber, final long offset diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java index 00a99c524493..a0bfecff5427 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java @@ -101,7 +101,7 @@ public Response httpGetChannelData( MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); final ListenableFuture dataFuture = - worker.readChannel(new StageId(queryId, stageNumber), partitionNumber, offset); + worker.readStageOutput(new StageId(queryId, stageNumber), partitionNumber, offset); final AsyncContext asyncContext = req.startAsync(); asyncContext.setTimeout(GET_CHANNEL_DATA_TIMEOUT); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java index 8a534f8f1497..04df2766325f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java @@ -43,7 +43,7 @@ import java.util.Deque; /** - * Reader for {@link ReadableFrameChannel}. + * Reader for the case where stage output is a generic {@link ReadableFrameChannel}. * * Because this reader returns an underlying channel directly, it must only be used when it is certain that * only a single consumer exists, i.e., when using output mode {@link OutputChannelMode#MEMORY}. See @@ -98,6 +98,20 @@ public ChannelStageOutputReader(final ReadableFrameChannel channel) this.writer = FrameFileWriter.open(new ChunkAcceptor(), null, ByteTracker.unboundedTracker()); } + /** + * Returns an input stream starting at the provided offset. + * + * The returned {@link InputStream} is non-blocking, and is slightly buffered (up to one frame). It does not + * necessarily contain the complete remaining dataset; this means that multiple calls to this method are necessary + * to fetch the complete dataset. + * + * The provided offset must be greater than, or equal to, the offset provided to the prior call. + * + * This class supports either remote or local reads, but not both. Calling both this method and {@link #readLocally()} + * on the same instance of this class is an error. + * + * @param offset offset into the stage output stream + */ @Override public synchronized ListenableFuture readRemotelyFrom(final long offset) { @@ -179,6 +193,17 @@ public synchronized ListenableFuture readRemotelyFrom(final long of return Futures.immediateFuture(new ByteChunksInputStream(ImmutableList.copyOf(chunks), positionWithinFirstChunk)); } + /** + * Returns the {@link ReadableFrameChannel} that backs this reader. + * + * Callers are responsible for closing the returned channel. Once this method is called, the caller becomes the + * owner of the channel, and this class's {@link #close()} method will no longer close the channel. + * + * Only a single reader is supported. Once this method is called, it cannot be called again. + * + * This class supports either remote or local reads, but not both. Calling both this method and + * {@link #readRemotelyFrom(long)} on the same instance of this class is an error. + */ @Override public synchronized ReadableFrameChannel readLocally() { @@ -195,6 +220,10 @@ public synchronized ReadableFrameChannel readLocally() } } + /** + * Closes the {@link ReadableFrameChannel} backing this reader, unless {@link #readLocally()} has been called. + * In that case, the caller of {@link #readLocally()} is responsible for closing the channel. + */ @Override public synchronized void close() { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java index 37f01a7a2544..29fb7b17ee78 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java @@ -33,7 +33,7 @@ import java.nio.channels.Channels; /** - * Reader for {@link FrameFile} on disk. + * Reader for the case where stage output is stored in a {@link FrameFile} on disk. */ public class FileStageOutputReader implements StageOutputReader { @@ -44,6 +44,16 @@ public FileStageOutputReader(FrameFile frameFile) this.frameFile = frameFile; } + /** + * Returns an input stream starting at the provided offset. The file is opened and seeked in-line with this method + * call, so the returned future is always immediately resolved. Callers are responsible for closing the returned + * input stream. + * + * This class supports remote and local reads from the same {@link FrameFile}, which, for example, is useful when + * broadcasting the output of a stage. + * + * @param offset offset into the stage output file + */ @Override public ListenableFuture readRemotelyFrom(long offset) { @@ -63,12 +73,24 @@ public ListenableFuture readRemotelyFrom(long offset) } } + /** + * Returns a channel pointing to a fresh {@link FrameFile#newReference()} of the underlying frame file. Callers are + * responsible for closing the returned channel. + * + * This class supports remote and local reads from the same {@link FrameFile}, which, for example, is useful when + * broadcasting the output of a stage. + */ @Override public ReadableFrameChannel readLocally() { return new ReadableFileFrameChannel(frameFile.newReference()); } + /** + * Closes the initial reference to the underlying {@link FrameFile}. Does not close additional references created by + * calls to {@link #readLocally()}; those references are closed when the channel(s) returned by {@link #readLocally()} + * are closed. + */ @Override public void close() throws IOException { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java index e12341f99f83..86530dad1d01 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java @@ -33,7 +33,7 @@ import java.nio.channels.Channels; /** - * Reader for empty channel. + * Reader for the case where stage output is known to be empty. */ public class NilStageOutputReader implements StageOutputReader { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java index 215facea3633..c19519dfb7bb 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java @@ -27,8 +27,12 @@ import org.apache.druid.frame.channel.ReadableNilFrameChannel; import org.apache.druid.frame.file.FrameFile; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.rpc.WorkerResource; import org.apache.druid.utils.CloseableUtils; +import javax.servlet.http.HttpServletRequest; import java.io.Closeable; import java.io.InputStream; @@ -46,16 +50,35 @@ public StageOutputHolder() this.readerFuture = FutureUtils.transform(channelFuture, StageOutputHolder::createReader); } + /** + * Method for remote reads. + * + * Provides the implementation for {@link Worker#readStageOutput(StageId, int, long)}, which is in turn used by + * {@link WorkerResource#httpGetChannelData(String, int, int, long, HttpServletRequest)}. + * + * @see StageOutputReader#readRemotelyFrom(long) for details on behavior + */ public ListenableFuture readRemotelyFrom(final long offset) { return FutureUtils.transformAsync(readerFuture, reader -> reader.readRemotelyFrom(offset)); } + /** + * Method for local reads. + * + * Used instead of {@link #readRemotelyFrom(long)} when a worker is reading a channel from itself, to avoid needless + * HTTP calls to itself. + * + * @see StageOutputReader#readLocally() for details on behavior + */ public ReadableFrameChannel readLocally() { return new FutureReadableFrameChannel(FutureUtils.transform(readerFuture, StageOutputReader::readLocally)); } + /** + * Sets the channel that backs {@link #readLocally()} and {@link #readRemotelyFrom(long)}. + */ public void setChannel(final ReadableFrameChannel channel) { if (!channelFuture.set(channel)) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java index bad319135158..0cf572e6363e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java @@ -21,35 +21,49 @@ import com.google.common.util.concurrent.ListenableFuture; import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.file.FrameFile; +import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.kernel.StageId; -import org.apache.druid.msq.shuffle.input.WorkerOrLocalInputChannelFactory; import java.io.Closeable; import java.io.InputStream; /** - * Interface for remotely reading output channels for a particular stage. Each instance of this interface represents a - * stream from a single {@link org.apache.druid.msq.kernel.StagePartition} in - * {@link org.apache.druid.frame.file.FrameFile} format. + * Interface for reading output channels for a particular stage. Each instance of this interface represents a + * stream from a single {@link org.apache.druid.msq.kernel.StagePartition} in {@link FrameFile} format. + * + * @see FileStageOutputReader implementation backed by {@link FrameFile} + * @see ChannelStageOutputReader implementation backed by {@link ReadableFrameChannel} + * @see NilStageOutputReader implementation for an empty channel */ public interface StageOutputReader extends Closeable { /** - * Returns an {@link InputStream} starting from a particular point in the - * {@link org.apache.druid.frame.file.FrameFile}. Length of the stream is implementation-dependent; it may or may - * not go all the way to the end of the file. Zero-length stream indicates EOF. Any nonzero length means you should - * call this method again with a higher offset. + * Method for remote reads. + * + * This method ultimately backs {@link Worker#readStageOutput(StageId, int, long)}. Refer to that method's + * documentation for details about behavior of the returned future. + * + * Callers are responsible for closing the returned {@link InputStream}. This input stream may encapsulate + * resources that are not closed by this class's {@link #close()} method. * - * @param offset offset into the frame file + * @param offset offset into the stage output file * - * @see org.apache.druid.msq.exec.WorkerImpl#readChannel(StageId, int, long) + * @see StageOutputHolder#readRemotelyFrom(long) which uses this method + * @see Worker#readStageOutput(StageId, int, long) for documentation on behavior of the returned future */ ListenableFuture readRemotelyFrom(long offset); /** - * Returns a {@link ReadableFrameChannel} for local reading. + * Method for local reads. + * + * Depending on implementation, this method may or may not be able to be called multiple times, and may or may not + * be able to be mixed with {@link #readRemotelyFrom(long)}. Refer to the specific implementation for more details. + * + * Callers are responsible for closing the returned channel. The returned channel may encapsulate resources that + * are not closed by this class's {@link #close()} method. * - * @see WorkerOrLocalInputChannelFactory#openChannel(StageId, int, int) + * @see StageOutputHolder#readLocally() which uses this method */ ReadableFrameChannel readLocally(); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java index 9e61c9dd7b83..ccf91acb6667 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java @@ -181,7 +181,7 @@ public boolean postResultPartitionBoundaries( @Nullable @Override - public ListenableFuture readChannel(StageId stageId, int partitionNumber, long offset) + public ListenableFuture readStageOutput(StageId stageId, int partitionNumber, long offset) { return null; } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java index 619306abb38c..f384397965b9 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java @@ -118,7 +118,7 @@ public ListenableFuture fetchChannelData( ) { try (InputStream inputStream = - inMemoryWorkers.get(workerTaskId).readChannel(stageId, partitionNumber, offset).get()) { + inMemoryWorkers.get(workerTaskId).readStageOutput(stageId, partitionNumber, offset).get()) { byte[] buffer = new byte[8 * 1024]; boolean didRead = false; int bytesRead; From 112bd710d493671942d0a7924ca65e109f088fdb Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Mon, 29 Jul 2024 21:18:53 -0700 Subject: [PATCH 09/13] Update comments, add test. --- .../output/ChannelStageOutputReader.java | 5 +- .../output/ChannelStageOutputReaderTest.java | 255 ++++++++++++++++++ 2 files changed, 258 insertions(+), 2 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java index 04df2766325f..ec95ca7af6a7 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java @@ -75,7 +75,7 @@ enum State private State state = State.INIT; /** - * Position within the overall stream. + * Position of {@link #positionWithinFirstChunk} in the first chunk of {@link #chunks}, within the overall stream. */ @GuardedBy("this") private long cursor; @@ -170,12 +170,13 @@ public synchronized ListenableFuture readRemotelyFrom(final long of } } - // Remove first chunk if it is no longer needed. (offset is entirely past it.) + // Advance cursor to the provided offset, or the end of the current chunk, whichever is earlier. final byte[] chunk = chunks.peek(); final long amountToAdvance = Math.min(offset - cursor, chunk.length - positionWithinFirstChunk); cursor += amountToAdvance; positionWithinFirstChunk += Ints.checkedCast(amountToAdvance); + // Remove first chunk if it is no longer needed. (i.e., if the cursor is at the end of it.) if (positionWithinFirstChunk == chunk.length) { chunks.poll(); positionWithinFirstChunk = 0; diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java new file mode 100644 index 000000000000..e27d1b69c607 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.io.ByteStreams; +import com.google.common.math.IntMath; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.FrameType; +import org.apache.druid.frame.channel.BlockingQueueFrameChannel; +import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.frame.file.FrameFile; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.testutil.FrameSequenceBuilder; +import org.apache.druid.frame.testutil.FrameTestUtil; +import org.apache.druid.segment.TestIndex; +import org.apache.druid.segment.incremental.IncrementalIndex; +import org.apache.druid.segment.incremental.IncrementalIndexStorageAdapter; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.internal.matchers.ThrowableCauseMatcher; +import org.junit.internal.matchers.ThrowableMessageMatcher; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.math.RoundingMode; +import java.util.List; + +public class ChannelStageOutputReaderTest extends InitializedNullHandlingTest +{ + private static final int MAX_FRAMES = 10; + private static final int EXPECTED_NUM_ROWS = 1209; + + private final BlockingQueueFrameChannel channel = new BlockingQueueFrameChannel(MAX_FRAMES); + private final ChannelStageOutputReader reader = new ChannelStageOutputReader(channel.readable()); + + @Rule + public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + private FrameReader frameReader; + private List frameList; + + @Before + public void setUp() throws Exception + { + final IncrementalIndex index = TestIndex.getIncrementalTestIndex(); + final IncrementalIndexStorageAdapter adapter = new IncrementalIndexStorageAdapter(index); + frameReader = FrameReader.create(adapter.getRowSignature()); + frameList = FrameSequenceBuilder.fromAdapter(adapter) + .frameType(FrameType.ROW_BASED) + .maxRowsPerFrame(IntMath.divide(index.size(), MAX_FRAMES, RoundingMode.CEILING)) + .frames() + .toList(); + } + + @After + public void tearDown() throws Exception + { + reader.close(); + } + + @Test + public void test_readLocally() throws IOException + { + writeAllFramesToChannel(); + + Assert.assertSame(channel.readable(), reader.readLocally()); + reader.close(); // Won't close the channel, because it's already been returned by readLocally + + final int numRows = FrameTestUtil.readRowsFromFrameChannel(channel.readable(), frameReader).toList().size(); + Assert.assertEquals(EXPECTED_NUM_ROWS, numRows); + } + + @Test + public void test_readLocally_closePriorToRead() throws IOException + { + writeAllFramesToChannel(); + + reader.close(); + + // Can't read the channel after closing the reader + Assert.assertThrows( + IllegalStateException.class, + reader::readLocally + ); + } + + @Test + public void test_readLocally_thenReadRemotely() throws IOException + { + writeAllFramesToChannel(); + + Assert.assertSame(channel.readable(), reader.readLocally()); + + // Can't read remotely after reading locally + Assert.assertThrows( + IllegalStateException.class, + () -> reader.readRemotelyFrom(0) + ); + + // Can still read locally after this error + final int numRows = FrameTestUtil.readRowsFromFrameChannel(channel.readable(), frameReader).toList().size(); + Assert.assertEquals(EXPECTED_NUM_ROWS, numRows); + } + + @Test + public void test_readRemotely_strideBasedOnReturnedChunk() throws IOException + { + // Test that reads entire chunks from readRemotelyFrom. This is a typical usage pattern. + + writeAllFramesToChannel(); + + final File tmpFile = temporaryFolder.newFile(); + + try (final FileOutputStream tmpOut = new FileOutputStream(tmpFile)) { + int numReads = 0; + long offset = 0; + + while (true) { + try (final InputStream in = FutureUtils.getUnchecked(reader.readRemotelyFrom(offset), true)) { + numReads++; + final long bytesWritten = ByteStreams.copy(in, tmpOut); + offset += bytesWritten; + + if (bytesWritten == 0) { + break; + } + } + } + + MatcherAssert.assertThat(numReads, Matchers.greaterThan(1)); + } + + final FrameFile frameFile = FrameFile.open(tmpFile, null); + final int numRows = + FrameTestUtil.readRowsFromFrameChannel(new ReadableFileFrameChannel(frameFile), frameReader).toList().size(); + + Assert.assertEquals(EXPECTED_NUM_ROWS, numRows); + } + + @Test + public void test_readRemotely_strideOneByte() throws IOException + { + // Test that reads one byte at a time from readRemotelyFrom. This helps ensure that there are no edge cases + // in the chunk-reading logic. + + writeAllFramesToChannel(); + + final File tmpFile = temporaryFolder.newFile(); + + try (final FileOutputStream tmpOut = new FileOutputStream(tmpFile)) { + int numReads = 0; + long offset = 0; + + while (true) { + try (final InputStream in = FutureUtils.getUnchecked(reader.readRemotelyFrom(offset), true)) { + numReads++; + final int nextByte = in.read(); + + if (nextByte < 0) { + break; + } + + tmpOut.write(nextByte); + offset++; + } + } + + Assert.assertEquals(numReads, offset + 1); + } + + final FrameFile frameFile = FrameFile.open(tmpFile, null); + final int numRows = + FrameTestUtil.readRowsFromFrameChannel(new ReadableFileFrameChannel(frameFile), frameReader).toList().size(); + + Assert.assertEquals(EXPECTED_NUM_ROWS, numRows); + } + + @Test + public void test_readRemotely_thenLocally() throws IOException + { + writeAllFramesToChannel(); + + // Read remotely + FutureUtils.getUnchecked(reader.readRemotelyFrom(0), true); + + // Then read locally + Assert.assertThrows( + IllegalStateException.class, + reader::readLocally + ); + } + + @Test + public void test_readRemotely_cannotReverse() throws IOException + { + writeAllFramesToChannel(); + + // Read remotely from offset = 1. + final InputStream in = FutureUtils.getUnchecked(reader.readRemotelyFrom(1), true); + final int offset = ByteStreams.toByteArray(in).length; + MatcherAssert.assertThat(offset, Matchers.greaterThan(0)); + + // Then read again from offset = 0; should get an error. + final RuntimeException e = Assert.assertThrows( + RuntimeException.class, + () -> FutureUtils.getUnchecked(reader.readRemotelyFrom(0), true) + ); + + MatcherAssert.assertThat( + e, + ThrowableCauseMatcher.hasCause( + Matchers.allOf( + CoreMatchers.instanceOf(IllegalStateException.class), + ThrowableMessageMatcher.hasMessage(CoreMatchers.startsWith("Offset[0] no longer available")) + ) + ) + ); + } + + private void writeAllFramesToChannel() throws IOException + { + for (Frame frame : frameList) { + channel.writable().write(frame); + } + channel.writable().close(); + } +} From 193be9e4486fbbe9a0b1f319cae8272f8ef1afba Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Mon, 29 Jul 2024 21:41:05 -0700 Subject: [PATCH 10/13] Additional javadocs. --- .../druid/msq/shuffle/output/StageOutputReader.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java index 0cf572e6363e..36b993611ca4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java @@ -47,6 +47,10 @@ public interface StageOutputReader extends Closeable * Callers are responsible for closing the returned {@link InputStream}. This input stream may encapsulate * resources that are not closed by this class's {@link #close()} method. * + * It is implementation-dependent whether calls to this method must have monotonically increasing offsets. + * In particular, {@link ChannelStageOutputReader} requires monotonically increasing offsets, but + * {@link FileStageOutputReader} and {@link NilStageOutputReader} do not. + * * @param offset offset into the stage output file * * @see StageOutputHolder#readRemotelyFrom(long) which uses this method @@ -63,6 +67,10 @@ public interface StageOutputReader extends Closeable * Callers are responsible for closing the returned channel. The returned channel may encapsulate resources that * are not closed by this class's {@link #close()} method. * + * It is implementation-dependent whether this method can be called multiple times. In particular, + * {@link ChannelStageOutputReader#readLocally()} can only be called one time, but the implementations in + * {@link FileStageOutputReader} and {@link NilStageOutputReader} can be called multiple times. + * * @see StageOutputHolder#readLocally() which uses this method */ ReadableFrameChannel readLocally(); From f9f93213892c5ab884c80358608eb53c4fb27f15 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Mon, 29 Jul 2024 22:11:10 -0700 Subject: [PATCH 11/13] Fix throws. --- .../msq/shuffle/output/ChannelStageOutputReaderTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java index e27d1b69c607..927372a3a6ae 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java @@ -68,7 +68,7 @@ public class ChannelStageOutputReaderTest extends InitializedNullHandlingTest private List frameList; @Before - public void setUp() throws Exception + public void setUp() { final IncrementalIndex index = TestIndex.getIncrementalTestIndex(); final IncrementalIndexStorageAdapter adapter = new IncrementalIndexStorageAdapter(index); @@ -81,7 +81,7 @@ public void setUp() throws Exception } @After - public void tearDown() throws Exception + public void tearDown() { reader.close(); } From ac49b8e5f355cc91895c9ef56f5440fc00d660ac Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Tue, 30 Jul 2024 08:27:14 -0700 Subject: [PATCH 12/13] Fix worker stopping in tests. --- .../org/apache/druid/msq/test/MSQTestControllerContext.java | 1 + .../java/org/apache/druid/msq/test/MSQTestWorkerClient.java | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java index fdbb00ae1f5e..72a9cdfd70b2 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java @@ -262,6 +262,7 @@ public ListenableFuture cancelTask(String workerId) { final Worker worker = inMemoryWorkers.remove(workerId); if (worker != null) { + worker.stop(); worker.awaitStop(); } return Futures.immediateFuture(null); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java index f384397965b9..e6048fbf600c 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java @@ -138,6 +138,9 @@ public ListenableFuture fetchChannelData( @Override public void close() { - inMemoryWorkers.forEach((k, v) -> v.awaitStop()); + inMemoryWorkers.forEach((k, v) -> { + v.stop(); + v.awaitStop(); + }); } } From ff34edee774391852fad63402a8b404209ec84cd Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Tue, 30 Jul 2024 14:19:44 -0700 Subject: [PATCH 13/13] Fix stuck test. --- .../org/apache/druid/msq/exec/WorkerImpl.java | 18 +++++++++++++++++- .../msq/test/MSQTestControllerContext.java | 1 - .../druid/msq/test/MSQTestWorkerClient.java | 5 +---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java index b7c1f3b84de3..7d2964eb2f8c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java @@ -96,6 +96,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -128,6 +129,11 @@ public class WorkerImpl implements Worker */ private final ConcurrentHashMap, CounterTracker> stageCounters = new ConcurrentHashMap<>(); + /** + * Atomic that is set to true when {@link #run()} starts (or when {@link #stop()} is called before {@link #run()}). + */ + private final AtomicBoolean didRun = new AtomicBoolean(); + /** * Future that resolves when {@link #run()} completes. */ @@ -165,6 +171,10 @@ public String id() @Override public void run() { + if (!didRun.compareAndSet(false, true)) { + throw new ISE("already run"); + } + try (final Closer closer = Closer.create()) { final KernelHolders kernelHolders = KernelHolders.create(context, closer); controllerClient = kernelHolders.getControllerClient(); @@ -526,7 +536,13 @@ public void stop() { // stopGracefully() is called when the containing process is terminated, or when the task is canceled. log.info("Worker id[%s] canceled.", context.workerId()); - doCancel(); + + if (didRun.compareAndSet(false, true)) { + // run() hasn't been called yet. Set runFuture so awaitStop() still works. + runFuture.set(null); + } else { + doCancel(); + } } @Override diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java index 72a9cdfd70b2..e65104302032 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java @@ -263,7 +263,6 @@ public ListenableFuture cancelTask(String workerId) final Worker worker = inMemoryWorkers.remove(workerId); if (worker != null) { worker.stop(); - worker.awaitStop(); } return Futures.immediateFuture(null); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java index e6048fbf600c..65145b5f5c01 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java @@ -138,9 +138,6 @@ public ListenableFuture fetchChannelData( @Override public void close() { - inMemoryWorkers.forEach((k, v) -> { - v.stop(); - v.awaitStop(); - }); + inMemoryWorkers.forEach((k, v) -> v.stop()); } }