From 5d1950d4512ff464975773d8c3eead44bd6b8930 Mon Sep 17 00:00:00 2001
From: Gian Merlino
Date: Tue, 30 Apr 2024 21:30:27 -0700
Subject: [PATCH] MSQ controller: Support in-memory shuffles; towards JVM
 reuse. (#16168)

* MSQ controller: Support in-memory shuffles; towards JVM reuse.

This patch contains two controller changes that make progress towards a
lower-latency MSQ.

First, support for in-memory shuffles. The main feature of in-memory shuffles,
as far as the controller is concerned, is that they are not fully buffered.
That means that whenever a producer stage uses in-memory output, its consumer
must run concurrently. The controller determines which stages run
concurrently, and when they start and stop.

"Leapfrogging" allows any chain of sort-based stages to use in-memory shuffles
even if we can only run two stages at once. For example, in a linear chain of
stages 0 -> 1 -> 2 where all do sort-based shuffles, we can use in-memory
shuffling for each one while only running two at once. (When stage 1 is done
reading input and about to start writing its output, we can stop 0 and
start 2.)

1) New OutputChannelMode enum attached to WorkOrders that tells workers
   whether stage output should be in memory (MEMORY), or use local or durable
   storage.

2) New logic in the ControllerQueryKernel to determine which stages can use
   in-memory shuffling (ControllerQueryKernelUtils#computeStageGroups) and to
   launch them at the appropriate time (ControllerQueryKernel#createNewKernels).

3) New "doneReadingInput" method on Controller (passed down to the stage
   kernels) which allows stages to transition to POST_READING even if they are
   not gathering statistics. This is important because it enables
   "leapfrogging" for HASH_LOCAL_SORT shuffles, and for GLOBAL_SORT shuffles
   with 1 partition.

4) Moved result-reading from ControllerContext#writeReports to the new
   QueryListener interface, to which ControllerImpl feeds results row by row
   while the query is still running. Important so we can read query results
   from the final stage using an in-memory channel.

5) New class ControllerQueryKernelConfig holds configs that control kernel
   behavior (such as whether to pipeline, the maximum number of concurrent
   stages, etc.). Generated by the ControllerContext.

Second, a refactor towards running workers in persistent JVMs that are able to
cache data across queries. This is helpful because I believe we'll want to
reuse JVMs and cached data for latency reasons.

1) Move creation of WorkerManager and TableInputSpecSlicer to the
   ControllerContext, rather than ControllerImpl. This allows managing workers
   and work assignment differently when JVMs are reusable.

2) Lift the Controller Jersey resource out from ControllerChatHandler to a
   reusable resource.

3) Move memory introspection to a MemoryIntrospector interface, and introduce
   ControllerMemoryParameters that uses it. This makes it easier to run MSQ in
   process types other than Indexer and Peon.

Both of these areas will have follow-ups that make similar changes on the
worker side.

* Address static checks.

* Address static checks.

* Fixes.

* Report writer tests.

* Adjustments.

* Fix reports.

* Review updates.

* Adjust name.

* Small changes.
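As a concrete illustration of the leapfrog rule described above, here is a
minimal, self-contained Java sketch. The class and method names are
hypothetical and are not part of this patch; the real grouping and launch
logic lives in ControllerQueryKernelUtils#computeStageGroups and
ControllerQueryKernel#createNewKernels.

    import java.util.List;

    class LeapfrogSketch
    {
      /**
       * Given a linear chain of sort-based stages and a budget of two
       * concurrently-running stages, decide which stage may start when
       * "current" finishes reading its input. Returns the stage number to
       * start next, or -1 if the chain is exhausted.
       */
      static int nextStageToStart(final List<Integer> chain, final int current)
      {
        final int idx = chain.indexOf(current);

        // When stage N finishes reading input, its producer (N - 1) can be
        // stopped, freeing a slot so the consumer (N + 1) can start. This is
        // why a sort-based chain of any length fits in two concurrent stages.
        if (idx >= 0 && idx + 1 < chain.size()) {
          return chain.get(idx + 1);
        } else {
          return -1;
        }
      }

      public static void main(String[] args)
      {
        // Chain 0 -> 1 -> 2 from the example above: when stage 1 is done
        // reading, stop stage 0 and start stage 2.
        System.out.println(nextStageToStart(List.of(0, 1, 2), 1)); // prints 2
      }
    }

Under this rule a chain of any length stays within the two-stage concurrency
budget, because each stage's producer is stopped exactly when its consumer
needs the slot.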
---
 extensions-core/multi-stage-query/pom.xml | 12 +
 .../org/apache/druid/msq/exec/Controller.java | 68 +-
 .../druid/msq/exec/ControllerClient.java | 16 +
 .../druid/msq/exec/ControllerContext.java | 63 +-
 .../apache/druid/msq/exec/ControllerImpl.java | 1308 ++++++++---------
 .../msq/exec/ControllerMemoryParameters.java | 109 ++
 .../exec/ControllerQueryResultsReader.java | 158 ++
 .../exec/ExceptionWrappingWorkerClient.java | 10 +-
 .../org/apache/druid/msq/exec/Limits.java | 7 +-
 .../druid/msq/exec/MemoryIntrospector.java | 65 +
 .../msq/exec/MemoryIntrospectorImpl.java | 140 ++
 .../druid/msq/exec/OutputChannelMode.java | 92 ++
 .../apache/druid/msq/exec/QueryListener.java | 71 +
 .../apache/druid/msq/exec/ResultsContext.java | 86 ++
 .../msq/exec/RetryCapableWorkerManager.java | 45 +
 .../msq/exec/SegmentLoadStatusFetcher.java | 6 +-
 .../apache/druid/msq/exec/WorkerClient.java | 41 +-
 .../WorkerFailureListener.java} | 15 +-
 .../apache/druid/msq/exec/WorkerManager.java | 92 ++
 .../msq/exec/WorkerMemoryParameters.java | 17 +-
 .../druid/msq/exec/WorkerSketchFetcher.java | 48 +-
 .../apache/druid/msq/exec/WorkerStats.java | 105 ++
 .../guice/IndexerMemoryManagementModule.java | 78 +
 .../druid/msq/guice/MSQIndexingModule.java | 20 -
 .../apache/druid/msq/guice/MSQSqlModule.java | 5 +-
 .../druid/msq/guice/MultiStageQuery.java | 3 +
 .../msq/guice/PeonMemoryManagementModule.java | 85 ++
 .../apache/druid/msq/guice/SqlTaskModule.java | 10 -
 .../indexing/IndexerControllerContext.java | 199 ++-
 .../IndexerResourcePermissionMapper.java | 55 +
 .../druid/msq/indexing/MSQControllerTask.java | 39 +-
 .../msq/indexing/MSQWorkerTaskLauncher.java | 231 ++-
 .../msq/indexing/TaskReportQueryListener.java | 214 +++
 .../client/ControllerChatHandler.java | 175 +--
 .../client/IndexerControllerClient.java | 16 +
 .../indexing/client/IndexerWorkerClient.java | 252 +---
 .../destination/DataSourceMSQDestination.java | 12 +
 .../DurableStorageMSQDestination.java | 13 +
 .../destination/ExportMSQDestination.java | 12 +
 .../indexing/destination/MSQDestination.java | 26 +
 .../destination/MSQSelectDestination.java | 31 +-
 .../destination/TaskReportMSQDestination.java | 12 +
 .../indexing/error/NotEnoughMemoryFault.java | 5 +-
 .../msq/indexing/report/MSQResultsReport.java | 76 +-
 .../msq/indexing/report/MSQStagesReport.java | 29 +-
 .../msq/indexing/report/MSQStatusReport.java | 29 +-
 .../indexing/report/MSQTaskReportPayload.java | 13 +-
 .../druid/msq/input/InputSpecSlicer.java | 3 +-
 .../msq/input/InputSpecSlicerFactory.java | 6 +-
 .../msq/input/stage/StageInputSlice.java | 25 +-
 .../msq/input/stage/StageInputSpecSlicer.java | 25 +-
 .../msq/input/table/TableInputSpecSlicer.java | 121 +-
 .../kernel/GlobalSortMaxCountShuffleSpec.java | 6 +-
 .../msq/kernel/GlobalSortShuffleSpec.java | 6 +
 .../druid/msq/kernel/HashShuffleSpec.java | 7 +-
 .../druid/msq/kernel/MixShuffleSpec.java | 6 -
 .../druid/msq/kernel/QueryDefinition.java | 6 +-
 .../msq/kernel/QueryDefinitionBuilder.java | 15 +-
 .../apache/druid/msq/kernel/ShuffleKind.java | 35 +-
 .../apache/druid/msq/kernel/ShuffleSpec.java | 12 +-
 .../druid/msq/kernel/StageDefinition.java | 2 +-
 .../org/apache/druid/msq/kernel/StageId.java | 8 +-
 .../apache/druid/msq/kernel/WorkOrder.java | 72 +-
 .../msq/kernel/WorkerAssignmentStrategy.java | 6 +-
 .../controller/ControllerQueryKernel.java | 536 ++++---
 .../ControllerQueryKernelConfig.java | 260 ++++
 .../ControllerQueryKernelUtils.java | 406 +++++
 .../controller/ControllerStagePhase.java | 124 +-
 .../controller/ControllerStageTracker.java | 99 +-
 .../ControllerWorkerStagePhase.java | 3 +-
 .../msq/kernel/controller/StageGroup.java | 133 ++
 .../BaseLeafFrameProcessorFactory.java | 3 +-
 .../druid/msq/querykit/DataSourcePlan.java | 6 +-
 .../msq/querykit/ShuffleSpecFactory.java | 3 +-
 .../msq/querykit/WindowOperatorQueryKit.java | 4 +-
 .../msq/querykit/groupby/GroupByQueryKit.java | 4 +-
 .../druid/msq/querykit/scan/ScanQueryKit.java | 2 +-
 .../druid/msq/rpc/BaseWorkerClientImpl.java | 270 ++++
 .../druid/msq/rpc/ControllerResource.java | 196 +++
 .../druid/msq/rpc/MSQResourceUtils.java | 50 +
 .../rpc/ResourcePermissionMapper.java} | 21 +-
 .../DurableStorageInputChannelFactory.java | 6 +-
 .../sql/resources/SqlStatementResource.java | 45 +-
 .../msq/sql/resources/SqlTaskResource.java | 4 +-
 .../PartialKeyStatisticsInformation.java | 32 +
 .../msq/util/MultiStageQueryContext.java | 19 +
 .../msq/util/SqlStatementResourceHelper.java | 46 +-
 ...rg.apache.druid.initialization.DruidModule | 4 +-
 .../src/main/resources/log4j2.xml | 3 +
 .../druid/msq/exec/ControllerImplTest.java | 4 +
 .../exec/ControllerMemoryParametersTest.java | 121 ++
 .../apache/druid/msq/exec/MSQFaultsTest.java | 12 +-
 .../apache/druid/msq/exec/MSQInsertTest.java | 2 +-
 .../apache/druid/msq/exec/MSQReplaceTest.java | 24 +-
 .../apache/druid/msq/exec/MSQSelectTest.java | 5 +-
 .../apache/druid/msq/exec/MSQTasksTest.java | 74 +-
 .../druid/msq/exec/QueryValidatorTest.java | 5 +-
 .../msq/exec/WorkerSketchFetcherTest.java | 39 +-
 .../indexing/MSQWorkerTaskLauncherTest.java | 4 +-
 .../indexing/TaskReportQueryListenerTest.java | 206 +++
 .../msq/indexing/WorkerChatHandlerTest.java | 20 +-
 .../client/ControllerChatHandlerTest.java | 17 +-
 .../indexing/report/MSQTaskReportTest.java | 20 +-
 .../msq/input/stage/StageInputSliceTest.java | 4 +-
 .../input/stage/StageInputSpecSlicerTest.java | 29 +-
 .../input/table/TableInputSpecSlicerTest.java | 45 +-
 .../druid/msq/kernel/QueryDefinitionTest.java | 4 +-
 .../BaseControllerQueryKernelTest.java | 65 +-
 .../controller/ControllerQueryKernelTest.java | 248 +++-
 .../ControllerQueryKernelUtilsTest.java | 551 +++++++
 .../MockQueryDefinitionBuilder.java | 117 +-
 ...onShufflingWorkersWithRetryKernelTest.java | 16 +-
 .../ShufflingWorkersWithRetryKernelTest.java | 17 +-
 .../kernel/controller/WorkerInputsTest.java | 16 +-
 .../resources/SqlStatementResourceTest.java | 11 +-
 ...tialKeyStatisticsInformationSerdeTest.java | 62 +
 .../druid/msq/test/CalciteMSQTestsHelper.java | 2 +
 .../test/CalciteSelectJoinQueryMSQTest.java | 9 +-
 .../apache/druid/msq/test/MSQTestBase.java | 31 +-
 .../msq/test/MSQTestControllerClient.java | 6 +
 .../msq/test/MSQTestControllerContext.java | 159 +-
 .../test/MSQTestOverlordServiceClient.java | 141 +-
 .../druid/msq/test/MSQTestWorkerClient.java | 25 +-
 .../druid/msq/test/MSQTestWorkerContext.java | 17 +-
 .../druid/msq/test/NoopQueryListener.java | 61 +
 .../util/SqlStatementResourceHelperTest.java | 25 +-
 .../MultipleFileTaskReportFileWriter.java | 36 +-
 .../MultipleFileTaskReportFileWriterTest.java | 64 +
 .../SingleFileTaskReportFileWriterTest.java | 61 +
 .../task/NoopTestTaskReportFileWriter.java | 10 +
 .../druid/testsEx/msq/ITMultiStageQuery.java | 6 +-
 .../testing/utils/MsqTestQueryHelper.java | 6 +-
 .../druid/frame/util/DurableStorageUtils.java | 6 +-
 .../SingleFileTaskReportFileWriter.java | 25 +-
 .../indexer/report/TaskReportFileWriter.java | 5 +
 .../org/apache/druid/rpc/RequestBuilder.java | 48 +-
 .../org/apache/druid/rpc/ServiceLocation.java | 58 +-
 137 files changed, 6981 insertions(+), 2412 deletions(-)
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerQueryResultsReader.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospector.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospectorImpl.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/QueryListener.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ResultsContext.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RetryCapableWorkerManager.java
 rename extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/{indexing/RetryTask.java => exec/WorkerFailureListener.java} (72%)
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStats.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/IndexerMemoryManagementModule.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/PeonMemoryManagementModule.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelConfig.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtils.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/StageGroup.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java
 create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java
 rename extensions-core/multi-stage-query/src/main/java/org/apache/druid/{guice/annotations/MSQ.java => msq/rpc/ResourcePermissionMapper.java} (59%)
 create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerMemoryParametersTest.java
 create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/TaskReportQueryListenerTest.java
 create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtilsTest.java
 create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/SendPartialKeyStatisticsInformationSerdeTest.java
 create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/NoopQueryListener.java
 create mode 100644 indexing-service/src/test/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriterTest.java
 create mode 100644 indexing-service/src/test/java/org/apache/druid/indexing/common/SingleFileTaskReportFileWriterTest.java

diff --git a/extensions-core/multi-stage-query/pom.xml b/extensions-core/multi-stage-query/pom.xml
index 8939018661ce..9e637acff009 100644
--- a/extensions-core/multi-stage-query/pom.xml
+++ b/extensions-core/multi-stage-query/pom.xml
@@ -186,6 +186,11 @@
       <artifactId>datasketches-memory</artifactId>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>it.unimi.dsi</groupId>
+      <artifactId>fastutil</artifactId>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>it.unimi.dsi</groupId>
       <artifactId>fastutil-core</artifactId>
@@ -288,6 +293,13 @@
       <type>test-jar</type>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.druid</groupId>
+      <artifactId>druid-indexing-service</artifactId>
+      <version>${project.parent.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.druid</groupId>
       <artifactId>druid-sql</artifactId>
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java
index 5e23a42b2fa1..f04286dd7c42 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java
@@ -19,64 +19,42 @@
 package org.apache.druid.msq.exec;
 
-import com.fasterxml.jackson.annotation.JsonCreator;
-import com.fasterxml.jackson.annotation.JsonProperty;
-import org.apache.druid.indexer.TaskStatus;
 import org.apache.druid.indexer.report.TaskReport;
 import org.apache.druid.msq.counters.CounterSnapshots;
 import org.apache.druid.msq.counters.CounterSnapshotsTree;
 import org.apache.druid.msq.indexing.MSQControllerTask;
 import org.apache.druid.msq.indexing.client.ControllerChatHandler;
 import org.apache.druid.msq.indexing.error.MSQErrorReport;
+import org.apache.druid.msq.kernel.StageId;
 import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
 
 import javax.annotation.Nullable;
 import java.util.List;
 
 /**
- * Interface for the controller of a multi-stage query.
+ * Interface for the controller of a multi-stage query. Each Controller is specific to a particular query.
+ *
+ * @see ControllerImpl the production implementation
  */
 public interface Controller
 {
-  /**
-   * POJO for capturing the status of a controller task that is currently running.
-   */
-  class RunningControllerStatus
-  {
-    private final String id;
-
-    @JsonCreator
-    public RunningControllerStatus(String id)
-    {
-      this.id = id;
-    }
-
-    @JsonProperty("id")
-    public String getId()
-    {
-      return id;
-    }
-  }
-
   /**
    * Unique task/query ID for the batch query run by this controller.
+   *
+   * Controller IDs must be globally unique. For tasks, this is the task ID from {@link MSQControllerTask#getId()}.
    */
-  String id();
-
-  /**
-   * The task which this controller runs.
-   */
-  MSQControllerTask task();
+  String queryId();
 
   /**
    * Runs the controller logic in the current thread. Surrounding classes provide the execution thread.
    */
-  TaskStatus run() throws Exception;
+  void run(QueryListener listener) throws Exception;
 
   /**
-   * Terminate the query DAG upon a cancellation request.
+   * Terminate the controller upon a cancellation request. Causes a concurrently-running {@link #run} method in
+   * a separate thread to cancel all outstanding work and exit.
    */
-  void stopGracefully();
+  void stop();
 
   // Worker-to-controller messages
 
@@ -84,13 +62,29 @@ public String getId()
    * Accepts a {@link PartialKeyStatisticsInformation} and updates the controller key statistics information. If all key
    * statistics have been gathered, enqueues the task with the {@link WorkerSketchFetcher} to generate partition boundaries.
    * This is intended to be called by the {@link ControllerChatHandler}.
+ * + * @see ControllerClient#postPartialKeyStatistics(StageId, int, PartialKeyStatisticsInformation) + */ + void updatePartialKeyStatisticsInformation( + int stageNumber, + int workerNumber, + Object partialKeyStatisticsInformationObject + ); + + /** + * Sent by workers when they finish reading their input, in cases where they would not otherwise be calling + * {@link #updatePartialKeyStatisticsInformation(int, int, Object)}. + * + * @see ControllerClient#postDoneReadingInput(StageId, int) */ - void updatePartialKeyStatisticsInformation(int stageNumber, int workerNumber, Object partialKeyStatisticsInformationObject); + void doneReadingInput(int stageNumber, int workerNumber); /** * System error reported by a subtask. Note that the errors are organized by * taskId, not by query/stage/worker, because system errors are associated * with a task rather than a specific query/stage/worker execution context. + * + * @see ControllerClient#postWorkerError(String, MSQErrorReport) */ void workerError(MSQErrorReport errorReport); @@ -98,16 +92,22 @@ public String getId() * System warning reported by a subtask. Indicates that the worker has encountered a non-lethal error. Worker should * continue its execution in such a case. If the worker wants to report an error and stop its execution, * please use {@link Controller#workerError} + * + * @see ControllerClient#postWorkerWarning(List) */ void workerWarning(List errorReports); /** * Periodic update of {@link CounterSnapshots} from subtasks. + * + * @see ControllerClient#postCounters(String, CounterSnapshotsTree) */ void updateCounters(String taskId, CounterSnapshotsTree snapshotsTree); /** * Reports that results are ready for a subtask. + * + * @see ControllerClient#postResultsComplete(StageId, int, Object) */ void resultsComplete( String queryId, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java index afd1ece4dad1..405ff4fb9026 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java @@ -21,6 +21,7 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; @@ -43,6 +44,21 @@ void postPartialKeyStatistics( PartialKeyStatisticsInformation partialKeyStatisticsInformation ) throws IOException; + /** + * Client side method to tell the controller that a particular stage and worker is done reading its input. + * + * The main purpose of this call is to let the controller know when it can stop running the input stage. This helps + * execution roll smoothly from stage to stage during pipelined execution. For backwards-compatibility reasons + * (this is a newer method, only really useful when pipelining), this call should be skipped if the query is not + * pipelining stages. + * + * Only used when {@link StageDefinition#doesSortDuringShuffle()} and *not* + * {@link StageDefinition#mustGatherResultKeyStatistics()}. 
When the stage gathers result key statistics, workers + * call {@link #postPartialKeyStatistics(StageId, int, PartialKeyStatisticsInformation)} instead, which has the same + * effect of telling the controller that the worker is done reading its input. + */ + void postDoneReadingInput(StageId stageId, int workerNumber) throws IOException; + /** * Client-side method to update the controller with counters for a particular stage and worker. The controller uses * this to compile live reports, track warnings generated etc. diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java index 0aa90688b910..40b114511c28 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java @@ -21,24 +21,44 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Injector; -import org.apache.druid.client.coordinator.CoordinatorClient; -import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.indexing.common.actions.TaskActionClient; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.emitter.service.ServiceEmitter; +import org.apache.druid.msq.indexing.MSQSpec; +import org.apache.druid.msq.input.InputSpecSlicer; +import org.apache.druid.msq.input.table.SegmentsInputSlice; +import org.apache.druid.msq.input.table.TableInputSpec; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; import org.apache.druid.server.DruidNode; /** - * Context used by multi-stage query controllers. - * - * Useful because it allows test fixtures to provide their own implementations. + * Context used by multi-stage query controllers. Useful because it allows test fixtures to provide their own + * implementations. */ public interface ControllerContext { - ServiceEmitter emitter(); + /** + * Configuration for {@link org.apache.druid.msq.kernel.controller.ControllerQueryKernel}. + */ + ControllerQueryKernelConfig queryKernelConfig(MSQSpec querySpec, QueryDefinition queryDef); + /** + * Callback from the controller implementation to "register" the controller. Used in the indexing task implementation + * to set up the task chat web service. + */ + void registerController(Controller controller, Closer closer); + + /** + * JSON-enabled object mapper. + */ ObjectMapper jsonMapper(); + /** + * Emit a metric using a {@link ServiceEmitter}. + */ + void emitMetric(String metric, Number value); + /** * Provides a way for tasks to request injectable objects. Useful because tasks are not able to request injection * at the time of server startup, because the server doesn't know what tasks it will be running. @@ -51,32 +71,33 @@ public interface ControllerContext DruidNode selfNode(); /** - * Provide access to the Coordinator service. + * Provides an {@link InputSpecSlicer} that slices {@link TableInputSpec} into {@link SegmentsInputSlice}. */ - CoordinatorClient coordinatorClient(); + InputSpecSlicer newTableInputSpecSlicer(); /** - * Provide access to segment actions in the Overlord. + * Provide access to segment actions in the Overlord. Only called for ingestion queries, i.e., where + * {@link MSQSpec#getDestination()} is {@link org.apache.druid.msq.indexing.destination.DataSourceMSQDestination}. 
*/ TaskActionClient taskActionClient(); /** * Provides services about workers: starting, canceling, obtaining status. + * + * @param queryId query ID + * @param querySpec query spec + * @param queryKernelConfig config from {@link #queryKernelConfig(MSQSpec, QueryDefinition)} + * @param workerFailureListener listener that receives callbacks when workers fail */ - WorkerManagerClient workerManager(); - - /** - * Callback from the controller implementation to "register" the controller. Used in the indexing task implementation - * to set up the task chat web service. - */ - void registerController(Controller controller, Closer closer); + WorkerManager newWorkerManager( + String queryId, + MSQSpec querySpec, + ControllerQueryKernelConfig queryKernelConfig, + WorkerFailureListener workerFailureListener + ); /** * Client for communicating with workers. */ - WorkerClient taskClientFor(Controller controller); - /** - * Writes controller task report. - */ - void writeReports(String controllerTaskId, TaskReport.ReportMap reports); + WorkerClient newWorkerClient(); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 682e2b484e4e..b10fbe76ecfa 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -25,23 +25,17 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; +import com.google.common.collect.Ordering; import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; -import com.google.common.util.concurrent.SettableFuture; import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntArraySet; import it.unimi.dsi.fastutil.ints.IntList; import it.unimi.dsi.fastutil.ints.IntSet; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.client.indexing.ClientCompactionTaskTransformSpec; import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.data.input.StringTuple; @@ -51,7 +45,7 @@ import org.apache.druid.discovery.BrokerClient; import org.apache.druid.error.DruidException; import org.apache.druid.frame.allocation.ArenaMemoryAllocator; -import org.apache.druid.frame.channel.FrameChannelSequence; +import org.apache.druid.frame.channel.ReadableConcatFrameChannel; import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterByPartition; import org.apache.druid.frame.key.ClusterByPartitions; @@ -64,11 +58,9 @@ import org.apache.druid.frame.write.InvalidFieldException; import org.apache.druid.frame.write.InvalidNullByteException; import org.apache.druid.indexer.TaskState; -import org.apache.druid.indexer.TaskStatus; import org.apache.druid.indexer.partitions.DimensionRangePartitionsSpec; import 
org.apache.druid.indexer.partitions.DynamicPartitionsSpec; import org.apache.druid.indexer.partitions.PartitionsSpec; -import org.apache.druid.indexer.report.TaskContextReport; import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.indexing.common.LockGranularity; import org.apache.druid.indexing.common.TaskLock; @@ -76,7 +68,6 @@ import org.apache.druid.indexing.common.actions.LockListAction; import org.apache.druid.indexing.common.actions.LockReleaseAction; import org.apache.druid.indexing.common.actions.MarkSegmentsAsUnusedAction; -import org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; import org.apache.druid.indexing.common.actions.SegmentAllocateAction; import org.apache.druid.indexing.common.actions.SegmentTransactionalAppendAction; import org.apache.druid.indexing.common.actions.SegmentTransactionalInsertAction; @@ -88,6 +79,7 @@ import org.apache.druid.indexing.common.task.batch.parallel.TombstoneHelper; import org.apache.druid.indexing.overlord.SegmentPublishResult; import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.Either; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; @@ -97,9 +89,6 @@ import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; -import org.apache.druid.java.util.common.guava.Sequences; -import org.apache.druid.java.util.common.guava.Yielder; -import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.counters.CounterSnapshots; @@ -109,13 +98,11 @@ import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.MSQTuningConfig; -import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; import org.apache.druid.msq.indexing.WorkerCount; import org.apache.druid.msq.indexing.client.ControllerChatHandler; import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination; import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; import org.apache.druid.msq.indexing.destination.ExportMSQDestination; -import org.apache.druid.msq.indexing.destination.MSQSelectDestination; import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; import org.apache.druid.msq.indexing.error.CanceledFault; import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault; @@ -129,16 +116,13 @@ import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.apache.druid.msq.indexing.error.MSQException; import org.apache.druid.msq.indexing.error.MSQFault; -import org.apache.druid.msq.indexing.error.MSQFaultUtils; import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher; -import org.apache.druid.msq.indexing.error.MSQWarnings; import org.apache.druid.msq.indexing.error.QueryNotSupportedFault; import org.apache.druid.msq.indexing.error.TooManyBucketsFault; import org.apache.druid.msq.indexing.error.TooManyWarningsFault; import org.apache.druid.msq.indexing.error.UnknownFault; import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault; import org.apache.druid.msq.indexing.processor.SegmentGeneratorFrameProcessorFactory; -import org.apache.druid.msq.indexing.report.MSQResultsReport; import 
org.apache.druid.msq.indexing.report.MSQSegmentReport; import org.apache.druid.msq.indexing.report.MSQStagesReport; import org.apache.druid.msq.indexing.report.MSQStatusReport; @@ -160,9 +144,7 @@ import org.apache.druid.msq.input.stage.StageInputSlice; import org.apache.druid.msq.input.stage.StageInputSpec; import org.apache.druid.msq.input.stage.StageInputSpecSlicer; -import org.apache.druid.msq.input.table.DataSegmentWithLocation; import org.apache.druid.msq.input.table.TableInputSpec; -import org.apache.druid.msq.input.table.TableInputSpecSlicer; import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.kernel.QueryDefinitionBuilder; import org.apache.druid.msq.kernel.StageDefinition; @@ -170,9 +152,9 @@ import org.apache.druid.msq.kernel.StagePartition; import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.kernel.controller.ControllerQueryKernel; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; import org.apache.druid.msq.kernel.controller.ControllerStagePhase; import org.apache.druid.msq.kernel.controller.WorkerInputs; -import org.apache.druid.msq.querykit.DataSegmentTimelineView; import org.apache.druid.msq.querykit.MultiQueryKit; import org.apache.druid.msq.querykit.QueryKit; import org.apache.druid.msq.querykit.QueryKitUtils; @@ -191,8 +173,6 @@ import org.apache.druid.msq.util.MSQFutureUtils; import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.msq.util.PassthroughAggregatorFactory; -import org.apache.druid.msq.util.SqlStatementResourceHelper; -import org.apache.druid.query.DruidMetrics; import org.apache.druid.query.Query; import org.apache.druid.query.QueryContext; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -213,19 +193,17 @@ import org.apache.druid.segment.realtime.appenderator.SegmentIdWithShardSpec; import org.apache.druid.segment.transform.TransformSpec; import org.apache.druid.server.DruidNode; -import org.apache.druid.server.coordination.DruidServerMetadata; -import org.apache.druid.sql.calcite.planner.ColumnMapping; import org.apache.druid.sql.calcite.planner.ColumnMappings; import org.apache.druid.sql.calcite.rel.DruidQuery; import org.apache.druid.sql.http.ResultFormat; import org.apache.druid.storage.ExportStorageProvider; import org.apache.druid.timeline.CompactionState; import org.apache.druid.timeline.DataSegment; -import org.apache.druid.timeline.SegmentTimeline; import org.apache.druid.timeline.partition.DimensionRangeShardSpec; import org.apache.druid.timeline.partition.NumberedPartialShardSpec; import org.apache.druid.timeline.partition.NumberedShardSpec; import org.apache.druid.timeline.partition.ShardSpec; +import org.apache.druid.utils.CloseableUtils; import org.apache.druid.utils.CollectionUtils; import org.joda.time.DateTime; import org.joda.time.Interval; @@ -234,7 +212,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -251,8 +228,6 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -264,8 +239,11 @@ public class ControllerImpl implements Controller { private static final Logger log = 
new Logger(ControllerImpl.class); - private final MSQControllerTask task; + private final String queryId; + private final MSQSpec querySpec; + private final ResultsContext resultsContext; private final ControllerContext context; + private volatile ControllerQueryKernelConfig queryKernelConfig; /** * Queue of "commands" to run on the {@link ControllerQueryKernel}. Various threads insert into the queue @@ -308,88 +286,61 @@ public class ControllerImpl implements Controller // For live reports. Written by the main controller thread, read by HTTP threads. private final ConcurrentHashMap<Integer, Integer> stagePartitionCountsForLiveReports = new ConcurrentHashMap<>(); - private WorkerSketchFetcher workerSketchFetcher; - // Time at which the query started. + // Stage number -> output channel mode. Only set for stages that have started. // For live reports. Written by the main controller thread, read by HTTP threads. + private final ConcurrentHashMap<Integer, OutputChannelMode> stageOutputChannelModesForLiveReports = + new ConcurrentHashMap<>(); + + private WorkerSketchFetcher workerSketchFetcher; // WorkerNumber -> WorkOrders which need to be retried and are determined by the controller. // Map is always populated in the main controller thread by addToRetryQueue, and pruned in retryFailedTasks. private final Map<Integer, Set<WorkOrder>> workOrdersToRetry = new HashMap<>(); + + // Time at which the query started. + // For live reports. Written by the main controller thread, read by HTTP threads. private volatile DateTime queryStartTime = null; private volatile DruidNode selfDruidNode; - private volatile MSQWorkerTaskLauncher workerTaskLauncher; + private volatile WorkerManager workerManager; private volatile WorkerClient netClient; private volatile FaultsExceededChecker faultsExceededChecker = null; private Map<Integer, ClusterStatisticsMergeMode> stageToStatsMergingMode; - private WorkerMemoryParameters workerMemoryParameters; - private boolean isDurableStorageEnabled; - private final boolean isFaultToleranceEnabled; - private final boolean isFailOnEmptyInsertEnabled; private volatile SegmentLoadStatusFetcher segmentLoadWaiter; @Nullable private MSQSegmentReport segmentReport; public ControllerImpl( - final MSQControllerTask task, - final ControllerContext context + final String queryId, + final MSQSpec querySpec, + final ResultsContext resultsContext, + final ControllerContext controllerContext ) { - this.task = task; - this.context = context; - this.isDurableStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled( - task.getQuerySpec().getQuery().context() - ); - this.isFaultToleranceEnabled = MultiStageQueryContext.isFaultToleranceEnabled( - task.getQuerySpec().getQuery().context() - ); - this.isFailOnEmptyInsertEnabled = MultiStageQueryContext.isFailOnEmptyInsertEnabled( - task.getQuerySpec().getQuery().context() - ); - } - - @Override - public String id() - { - return task.getId(); + this.queryId = Preconditions.checkNotNull(queryId, "queryId"); + this.querySpec = Preconditions.checkNotNull(querySpec, "querySpec"); + this.resultsContext = Preconditions.checkNotNull(resultsContext, "resultsContext"); + this.context = Preconditions.checkNotNull(controllerContext, "controllerContext"); } @Override - public MSQControllerTask task() + public String queryId() { - return task; + return queryId; } @Override - public TaskStatus run() throws Exception + public void run(final QueryListener queryListener) throws Exception { - final Closer closer = Closer.create(); - - try { - return runTask(closer); - } - catch (Throwable e) { - try { - closer.close(); - } - catch (Throwable e2) { -
e.addSuppressed(e2); - } - - // We really don't expect this to error out. runTask should handle everything nicely. If it doesn't, something - // strange happened, so log it. - log.warn(e, "Encountered unhandled controller exception."); - return TaskStatus.failure(id(), e.toString()); - } - finally { - closer.close(); + try (final Closer closer = Closer.create()) { + runInternal(queryListener, closer); } } @Override - public void stopGracefully() + public void stop() { final QueryDefinition queryDef = queryDefRef.get(); @@ -403,18 +354,17 @@ public void stopGracefully() } ); - if (workerTaskLauncher != null) { - workerTaskLauncher.stop(true); + if (workerManager != null) { + workerManager.stop(true); } } - public TaskStatus runTask(final Closer closer) + private void runInternal(final QueryListener queryListener, final Closer closer) { QueryDefinition queryDef = null; ControllerQueryKernel queryKernel = null; ListenableFuture workerTaskRunnerFuture = null; CounterSnapshotsTree countersSnapshot = null; - Yielder resultsYielder = null; Throwable exceptionEncountered = null; final TaskState taskStateForReport; @@ -423,17 +373,24 @@ public TaskStatus runTask(final Closer closer) try { // Planning-related: convert the native query from MSQSpec into a multi-stage QueryDefinition. this.queryStartTime = DateTimes.nowUtc(); + context.registerController(this, closer); queryDef = initializeQueryDefAndState(closer); - final InputSpecSlicerFactory inputSpecSlicerFactory = makeInputSpecSlicerFactory(makeDataSegmentTimelineView()); - // Execution-related: run the multi-stage QueryDefinition. + final InputSpecSlicerFactory inputSpecSlicerFactory = + makeInputSpecSlicerFactory(context.newTableInputSpecSlicer()); + final Pair> queryRunResult = - new RunQueryUntilDone(queryDef, inputSpecSlicerFactory, closer).run(); + new RunQueryUntilDone( + queryDef, + queryKernelConfig, + inputSpecSlicerFactory, + queryListener, + closer + ).run(); queryKernel = Preconditions.checkNotNull(queryRunResult.lhs); workerTaskRunnerFuture = Preconditions.checkNotNull(queryRunResult.rhs); - resultsYielder = getFinalResultsYielder(queryDef, queryKernel); handleQueryResults(queryDef, queryKernel); } catch (Throwable e) { @@ -458,20 +415,24 @@ public TaskStatus runTask(final Closer closer) } else { // Query failure. Generate an error report and log the error(s) we encountered. final String selfHost = MSQTasks.getHostFromSelfNode(selfDruidNode); - final MSQErrorReport controllerError = - exceptionEncountered != null - ? MSQErrorReport.fromException( - id(), - selfHost, - null, - exceptionEncountered, - task.getQuerySpec().getColumnMappings() - ) - : null; + final MSQErrorReport controllerError; + + if (exceptionEncountered != null) { + controllerError = MSQErrorReport.fromException( + queryId(), + selfHost, + null, + exceptionEncountered, + querySpec.getColumnMappings() + ); + } else { + controllerError = null; + } + MSQErrorReport workerError = workerErrorRef.get(); taskStateForReport = TaskState.FAILED; - errorForReport = MSQTasks.makeErrorReport(id(), selfHost, controllerError, workerError); + errorForReport = MSQTasks.makeErrorReport(queryId(), selfHost, controllerError, workerError); // Log the errors we encountered. 
if (controllerError != null) { @@ -482,33 +443,14 @@ public TaskStatus runTask(final Closer closer) log.warn("Worker: %s", MSQTasks.errorReportToLogMessage(workerError)); } } - MSQResultsReport resultsReport = null; if (queryKernel != null && queryKernel.isSuccess()) { // If successful, encourage the tasks to exit successfully. - // get results before posting finish to the tasks. - if (resultsYielder != null) { - resultsReport = makeResultsTaskReport( - queryDef, - resultsYielder, - task.getQuerySpec().getColumnMappings(), - task.getSqlTypeNames(), - MultiStageQueryContext.getSelectDestination(task.getQuerySpec().getQuery().context()) - ); - try { - resultsYielder.close(); - } - catch (IOException e) { - throw new RuntimeException("Unable to fetch results of various worker tasks successfully", e); - } - } else { - resultsReport = null; - } postFinishToAllTasks(); - workerTaskLauncher.stop(false); + workerManager.stop(false); } else { // If not successful, cancel running tasks. - if (workerTaskLauncher != null) { - workerTaskLauncher.stop(true); + if (workerManager != null) { + workerManager.stop(true); } } @@ -523,10 +465,13 @@ public TaskStatus runTask(final Closer closer) } } - boolean shouldWaitForSegmentLoad = MultiStageQueryContext.shouldWaitForSegmentLoad(task.getQuerySpec().getQuery().context()); + boolean shouldWaitForSegmentLoad = MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getQuery().context()); try { - releaseTaskLocks(); + if (MSQControllerTask.isIngestion(querySpec)) { + releaseTaskLocks(); + } + cleanUpDurableStorageIfNeeded(); if (queryKernel != null && queryKernel.isSuccess()) { @@ -534,7 +479,7 @@ public TaskStatus runTask(final Closer closer) // If successful, there are segments created and segment load is enabled, segmentLoadWaiter should wait // for them to become available. log.info("Controller will now wait for segments to be loaded. The query has already finished executing," - + " and results will be included once the segments are loaded, even if this query is cancelled now."); + + " and results will be included once the segments are loaded, even if this query is canceled now."); segmentLoadWaiter.waitForSegmentsToLoad(); } } @@ -544,70 +489,54 @@ public TaskStatus runTask(final Closer closer) log.warn(e, "Exception thrown during cleanup. Ignoring it and writing task report."); } - try { - // Write report even if something went wrong. - final MSQStagesReport stagesReport; - - if (queryDef != null) { - final Map stagePhaseMap; - - if (queryKernel != null) { - // Once the query finishes, cleanup would have happened for all the stages that were successful - // Therefore we mark it as done to make the reports prettier and more accurate - queryKernel.markSuccessfulTerminalStagesAsFinished(); - stagePhaseMap = queryKernel.getActiveStages() - .stream() - .collect( - Collectors.toMap(StageId::getStageNumber, queryKernel::getStagePhase) - ); - } else { - stagePhaseMap = Collections.emptyMap(); - } + // Generate report even if something went wrong. 
+ final MSQStagesReport stagesReport; - stagesReport = makeStageReport( - queryDef, - stagePhaseMap, - stageRuntimesForLiveReports, - stageWorkerCountsForLiveReports, - stagePartitionCountsForLiveReports - ); + if (queryDef != null) { + final Map stagePhaseMap; + + if (queryKernel != null) { + // Once the query finishes, cleanup would have happened for all the stages that were successful + // Therefore we mark it as done to make the reports prettier and more accurate + queryKernel.markSuccessfulTerminalStagesAsFinished(); + stagePhaseMap = queryKernel.getActiveStages() + .stream() + .collect( + Collectors.toMap(StageId::getStageNumber, queryKernel::getStagePhase) + ); } else { - stagesReport = null; - } - - final MSQTaskReportPayload taskReportPayload = new MSQTaskReportPayload( - makeStatusReport( - taskStateForReport, - errorForReport, - workerWarnings, - queryStartTime, - new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis(), - workerTaskLauncher, - segmentLoadWaiter, - segmentReport - ), - stagesReport, - countersSnapshot, - resultsReport - ); - context.writeReports( - id(), - TaskReport.buildTaskReports( - new MSQTaskReport(id(), taskReportPayload), - new TaskContextReport(id(), task.getContext()) - ) - ); - } - catch (Throwable e) { - log.warn(e, "Error encountered while writing task report. Skipping."); - } + stagePhaseMap = Collections.emptyMap(); + } - if (taskStateForReport == TaskState.SUCCESS) { - return TaskStatus.success(id()); + stagesReport = makeStageReport( + queryDef, + stagePhaseMap, + stageRuntimesForLiveReports, + stageWorkerCountsForLiveReports, + stagePartitionCountsForLiveReports, + stageOutputChannelModesForLiveReports + ); } else { - // errorForReport is nonnull when taskStateForReport != SUCCESS. Use that message. - return TaskStatus.failure(id(), MSQFaultUtils.generateMessageWithErrorCode(errorForReport.getFault())); - } + stagesReport = null; + } + + final MSQTaskReportPayload taskReportPayload = new MSQTaskReportPayload( + makeStatusReport( + taskStateForReport, + errorForReport, + workerWarnings, + queryStartTime, + new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis(), + workerManager, + segmentLoadWaiter, + segmentReport + ), + stagesReport, + countersSnapshot, + null + ); + + queryListener.onQueryComplete(taskReportPayload); } /** @@ -644,105 +573,59 @@ public void addToKernelManipulationQueue(Consumer kernelC private QueryDefinition initializeQueryDefAndState(final Closer closer) { - final QueryContext queryContext = task.getQuerySpec().getQuery().context(); - if (isFaultToleranceEnabled) { - if (!queryContext.containsKey(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE)) { - // if context key not set, enable durableStorage automatically. - isDurableStorageEnabled = true; - } else { - // if context key is set, and durableStorage is turned on. - if (MultiStageQueryContext.isDurableStorageEnabled(queryContext)) { - isDurableStorageEnabled = true; - } else { - throw new MSQException( - UnknownFault.forMessage( - StringUtils.format( - "Context param[%s] cannot be explicitly set to false when context param[%s] is" - + " set to true. 
Either remove the context param[%s] or explicitly set it to true.", - MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, - MultiStageQueryContext.CTX_FAULT_TOLERANCE, - MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE - ))); - } - } - } else { - isDurableStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(queryContext); - } - - log.debug("Task [%s] durable storage mode is set to %s.", task.getId(), isDurableStorageEnabled); - log.debug("Task [%s] fault tolerance mode is set to %s.", task.getId(), isFaultToleranceEnabled); - this.selfDruidNode = context.selfNode(); - context.registerController(this, closer); - - this.netClient = new ExceptionWrappingWorkerClient(context.taskClientFor(this)); - closer.register(netClient::close); + this.netClient = new ExceptionWrappingWorkerClient(context.newWorkerClient()); + closer.register(netClient); final QueryDefinition queryDef = makeQueryDefinition( - id(), + queryId(), makeQueryControllerToolKit(), - task.getQuerySpec(), + querySpec, context.jsonMapper() ); - QueryValidator.validateQueryDef(queryDef); - queryDefRef.set(queryDef); - - final long maxParseExceptions = task.getQuerySpec().getQuery().context().getLong( - MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, - MSQWarnings.DEFAULT_MAX_PARSE_EXCEPTIONS_ALLOWED - ); - - ImmutableMap.Builder taskContextOverridesBuilder = ImmutableMap.builder(); - taskContextOverridesBuilder - .put(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, isDurableStorageEnabled) - .put(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, maxParseExceptions); - - if (!MSQControllerTask.isIngestion(task.getQuerySpec())) { - if (MSQControllerTask.writeResultsToDurableStorage(task.getQuerySpec())) { - taskContextOverridesBuilder.put( - MultiStageQueryContext.CTX_SELECT_DESTINATION, - MSQSelectDestination.DURABLESTORAGE.getName() - ); - } else { - // we need not pass the value 'TaskReport' to the worker since the worker impl does not do anything in such a case. 
- // but we are passing it anyway for completeness - taskContextOverridesBuilder.put( - MultiStageQueryContext.CTX_SELECT_DESTINATION, - MSQSelectDestination.TASKREPORT.getName() + if (log.isDebugEnabled()) { + try { + log.debug( + "Query[%s] definition: %s", + queryDef.getQueryId(), + context.jsonMapper().writerWithDefaultPrettyPrinter().writeValueAsString(queryDef) ); } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } } - taskContextOverridesBuilder.put( - MultiStageQueryContext.CTX_IS_REINDEX, - MSQControllerTask.isReplaceInputDataSourceTask(task) - ); - - // propagate the controller's tags to the worker task for enhanced metrics reporting - Map tags = task.getContextValue(DruidMetrics.TAGS); - if (tags != null) { - taskContextOverridesBuilder.put(DruidMetrics.TAGS, tags); - } + QueryValidator.validateQueryDef(queryDef); + queryDefRef.set(queryDef); - this.workerTaskLauncher = new MSQWorkerTaskLauncher( - id(), - task.getDataSource(), - context, + queryKernelConfig = context.queryKernelConfig(querySpec, queryDef); + workerManager = context.newWorkerManager( + queryId, + querySpec, + queryKernelConfig, (failedTask, fault) -> { - if (isFaultToleranceEnabled && ControllerQueryKernel.isRetriableFault(fault)) { - addToKernelManipulationQueue((kernel) -> { + if (queryKernelConfig.isFaultTolerant() && ControllerQueryKernel.isRetriableFault(fault)) { + addToKernelManipulationQueue(kernel -> { addToRetryQueue(kernel, failedTask.getWorkerNumber(), fault); }); } else { throw new MSQException(fault); } - }, - taskContextOverridesBuilder.build(), - // 10 minutes +- 2 minutes jitter - TimeUnit.SECONDS.toMillis(600 + ThreadLocalRandom.current().nextInt(-4, 5) * 30L) + } ); + if (queryKernelConfig.isFaultTolerant() && !(workerManager instanceof RetryCapableWorkerManager)) { + // Not expected to happen, since all WorkerManager impls are currently retry-capable. Defensive check + // for future-proofing. + throw DruidException.defensive( + "Cannot run with fault tolerance since workerManager class[%s] does not support retrying", + workerManager.getClass().getName() + ); + } + + final long maxParseExceptions = MultiStageQueryContext.getMaxParseExceptions(querySpec.getQuery().context()); this.faultsExceededChecker = new FaultsExceededChecker( ImmutableMap.of(CannotParseExternalDataFault.CODE, maxParseExceptions) ); @@ -754,15 +637,14 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer) stageDefinition.getId().getStageNumber(), finalizeClusterStatisticsMergeMode( stageDefinition, - MultiStageQueryContext.getClusterStatisticsMergeMode(queryContext) + MultiStageQueryContext.getClusterStatisticsMergeMode(querySpec.getQuery().context()) ) ) ); - this.workerMemoryParameters = WorkerMemoryParameters.createProductionInstanceForController(context.injector()); this.workerSketchFetcher = new WorkerSketchFetcher( netClient, - workerTaskLauncher, - isFaultToleranceEnabled + workerManager, + queryKernelConfig.isFaultTolerant() ); closer.register(workerSketchFetcher::close); @@ -777,10 +659,14 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer) */ private void addToRetryQueue(ControllerQueryKernel kernel, int worker, MSQFault fault) { + // Blind cast to RetryCapableWorkerManager is safe, since we verified that workerManager is retry-capable + // when initially creating it. 
+ final RetryCapableWorkerManager retryCapableWorkerManager = (RetryCapableWorkerManager) workerManager; + List retriableWorkOrders = kernel.getWorkInCaseWorkerEligibleForRetryElseThrow(worker, fault); - if (retriableWorkOrders.size() != 0) { + if (!retriableWorkOrders.isEmpty()) { log.info("Submitting worker[%s] for relaunch because of fault[%s]", worker, fault); - workerTaskLauncher.submitForRelaunch(worker); + retryCapableWorkerManager.submitForRelaunch(worker); workOrdersToRetry.compute(worker, (workerNumber, workOrders) -> { if (workOrders == null) { return new HashSet<>(retriableWorkOrders); @@ -790,11 +676,11 @@ private void addToRetryQueue(ControllerQueryKernel kernel, int worker, MSQFault } }); } else { - log.info( + log.debug( "Worker[%d] has no active workOrders that need relaunch therefore not relaunching", worker ); - workerTaskLauncher.reportFailedInactiveWorker(worker); + retryCapableWorkerManager.reportFailedInactiveWorker(worker); } } @@ -813,6 +699,11 @@ public void updatePartialKeyStatisticsInformation( addToKernelManipulationQueue( queryKernel -> { final StageId stageId = queryKernel.getStageId(stageNumber); + + if (queryKernel.isStageFinished(stageId)) { + return; + } + final PartialKeyStatisticsInformation partialKeyStatisticsInformation; try { @@ -835,19 +726,41 @@ public void updatePartialKeyStatisticsInformation( ); } + @Override + public void doneReadingInput(int stageNumber, int workerNumber) + { + addToKernelManipulationQueue( + queryKernel -> { + final StageId stageId = queryKernel.getStageId(stageNumber); + + if (queryKernel.isStageFinished(stageId)) { + return; + } + + queryKernel.setDoneReadingInputForStageAndWorker(stageId, workerNumber); + } + ); + } @Override public void workerError(MSQErrorReport errorReport) { - if (workerTaskLauncher.isTaskCanceledByController(errorReport.getTaskId()) || - !workerTaskLauncher.isTaskLatest(errorReport.getTaskId())) { - log.info("Ignoring task %s", errorReport.getTaskId()); - } else { - workerErrorRef.compareAndSet( - null, - mapQueryColumnNameToOutputColumnName(errorReport) - ); + if (queryKernelConfig.isFaultTolerant()) { + // Blind cast to RetryCapableWorkerManager in fault-tolerant mode is safe, since when fault-tolerance is + // enabled, we verify that workerManager is retry-capable when initially creating it. 
+ final RetryCapableWorkerManager retryCapableWorkerManager = (RetryCapableWorkerManager) workerManager; + + if (retryCapableWorkerManager.isTaskCanceledByController(errorReport.getTaskId()) || + !retryCapableWorkerManager.isWorkerActive(errorReport.getTaskId())) { + log.debug( + "Ignoring error report for worker[%s] because it was intentionally shut down.", + errorReport.getTaskId() + ); + return; + } } + + workerErrorRef.compareAndSet(null, mapQueryColumnNameToOutputColumnName(errorReport)); } /** @@ -920,6 +833,11 @@ public void resultsComplete( addToKernelManipulationQueue( queryKernel -> { final StageId stageId = new StageId(queryId, stageNumber); + + if (queryKernel.isStageFinished(stageId)) { + return; + } + final Object convertedResultObject; try { convertedResultObject = context.jsonMapper().convertValue( @@ -936,7 +854,6 @@ public void resultsComplete( ); } - queryKernel.setResultsCompleteForStageAndWorker(stageId, workerNumber, convertedResultObject); } ); @@ -954,7 +871,7 @@ public TaskReport.ReportMap liveReports() return TaskReport.buildTaskReports( new MSQTaskReport( - id(), + queryId(), new MSQTaskReportPayload( makeStatusReport( TaskState.RUNNING, @@ -962,7 +879,7 @@ public TaskReport.ReportMap liveReports() workerWarnings, queryStartTime, queryStartTime == null ? -1L : new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis(), - workerTaskLauncher, + workerManager, segmentLoadWaiter, segmentReport ), @@ -971,7 +888,8 @@ public TaskReport.ReportMap liveReports() stagePhasesForLiveReports, stageRuntimesForLiveReports, stageWorkerCountsForLiveReports, - stagePartitionCountsForLiveReports + stagePartitionCountsForLiveReports, + stageOutputChannelModesForLiveReports ), makeCountersSnapshotForLiveReports(), null @@ -982,9 +900,9 @@ public TaskReport.ReportMap liveReports() /** * @param isStageOutputEmpty {@code true} if the stage output is empty, {@code false} if the stage output is non-empty, - * {@code null} for stages where cluster key statistics are not gathered or is incomplete. + * {@code null} for stages where cluster key statistics are not gathered or is incomplete. * - * @return the segments that will be generated by this job. Delegates to + * @return the segments that will be generated by this job. Delegates to * {@link #generateSegmentIdsWithShardSpecsForAppend} or {@link #generateSegmentIdsWithShardSpecsForReplace} as * appropriate. This is a potentially expensive call, since it requires calling Overlord APIs. * @@ -1014,7 +932,7 @@ private List generateSegmentIdsWithShardSpecs( destination, partitionBoundaries, keyReader, - MultiStageQueryContext.validateAndGetTaskLockType(QueryContext.of(task.getQuerySpec().getQuery().getContext()), false), + MultiStageQueryContext.validateAndGetTaskLockType(QueryContext.of(querySpec.getQuery().getContext()), false), isStageOutputEmpty ); } @@ -1024,7 +942,7 @@ private List generateSegmentIdsWithShardSpecs( * Used by {@link #generateSegmentIdsWithShardSpecs}. * * @param isStageOutputEmpty {@code true} if the stage output is empty, {@code false} if the stage output is non-empty, - * {@code null} for stages where cluster key statistics are not gathered or is incomplete. + * {@code null} for stages where cluster key statistics are not gathered or is incomplete. 
*/ private List generateSegmentIdsWithShardSpecsForAppend( final DataSourceMSQDestination destination, @@ -1055,13 +973,13 @@ private List generateSegmentIdsWithShardSpecsForAppend( try { allocation = context.taskActionClient().submit( new SegmentAllocateAction( - task.getDataSource(), + destination.getDataSource(), timestamp, // Same granularity for queryGranularity, segmentGranularity because we don't have insight here // into what queryGranularity "actually" is. (It depends on what time floor function was used.) segmentGranularity, segmentGranularity, - id(), + queryId(), previousSegmentId, false, NumberedPartialShardSpec.instance(), @@ -1081,7 +999,7 @@ private List generateSegmentIdsWithShardSpecsForAppend( if (allocation == null) { throw new MSQException( new InsertCannotAllocateSegmentFault( - task.getDataSource(), + destination.getDataSource(), segmentGranularity.bucket(timestamp), null ) @@ -1095,7 +1013,7 @@ private List generateSegmentIdsWithShardSpecsForAppend( if (!IntervalUtils.isAligned(allocation.getInterval(), segmentGranularity)) { throw new MSQException( new InsertCannotAllocateSegmentFault( - task.getDataSource(), + destination.getDataSource(), segmentGranularity.bucket(timestamp), allocation.getInterval() ) @@ -1113,8 +1031,7 @@ private List generateSegmentIdsWithShardSpecsForAppend( * Used by {@link #generateSegmentIdsWithShardSpecs}. * * @param isStageOutputEmpty {@code true} if the stage output is empty, {@code false} if the stage output is non-empty, - * {@code null} for stages where cluster key statistics are not gathered or is incomplete. - * + * {@code null} for stages where cluster key statistics are not gathered or is incomplete. */ private List generateSegmentIdsWithShardSpecsForReplace( final DataSourceMSQDestination destination, @@ -1135,10 +1052,17 @@ private List generateSegmentIdsWithShardSpecsForReplace( final List shardColumns; final Pair, String> shardReasonPair; - shardReasonPair = computeShardColumns(signature, clusterBy, task.getQuerySpec().getColumnMappings(), mayHaveMultiValuedClusterByFields); + shardReasonPair = computeShardColumns( + signature, + clusterBy, + querySpec.getColumnMappings(), + mayHaveMultiValuedClusterByFields + ); + shardColumns = shardReasonPair.lhs; String reason = shardReasonPair.rhs; - log.info(StringUtils.format("ShardSpec chosen: %s", reason)); + log.info("ShardSpec chosen: %s", reason); + if (shardColumns.isEmpty()) { segmentReport = new MSQSegmentReport(NumberedShardSpec.class.getSimpleName(), reason); } else { @@ -1194,26 +1118,21 @@ private List generateSegmentIdsWithShardSpecsForReplace( shardSpec = new DimensionRangeShardSpec(shardColumns, start, end, segmentNumber, ranges.size()); } - retVal[partitionNumber] = new SegmentIdWithShardSpec(task.getDataSource(), interval, version, shardSpec); + retVal[partitionNumber] = new SegmentIdWithShardSpec(destination.getDataSource(), interval, version, shardSpec); } } return Arrays.asList(retVal); } - /** - * Returns a complete list of task ids, ordered by worker number. The Nth task has worker number N. - *

- * If the currently-running set of tasks is incomplete, returns an absent Optional. - */ @Override public List getTaskIds() { - if (workerTaskLauncher == null) { + if (workerManager == null) { return Collections.emptyList(); } - return workerTaskLauncher.getActiveTasks(); + return workerManager.getWorkerIds(); } @SuppressWarnings({"unchecked", "rawtypes"}) @@ -1225,7 +1144,7 @@ private Int2ObjectMap makeWorkerFactoryInfosForStage( @Nullable final List segmentsToGenerate ) { - if (MSQControllerTask.isIngestion(task.getQuerySpec()) && + if (MSQControllerTask.isIngestion(querySpec) && stageNumber == queryDef.getFinalStageDefinition().getStageNumber()) { // noinspection unchecked,rawtypes return (Int2ObjectMap) makeSegmentGeneratorWorkerFactoryInfos(workerInputs, segmentsToGenerate); @@ -1247,94 +1166,6 @@ private QueryKit makeQueryControllerToolKit() return new MultiQueryKit(kitMap); } - private DataSegmentTimelineView makeDataSegmentTimelineView() - { - final SegmentSource includeSegmentSource = MultiStageQueryContext.getSegmentSources( - task.getQuerySpec() - .getQuery() - .context() - ); - - final boolean includeRealtime = SegmentSource.shouldQueryRealtimeServers(includeSegmentSource); - - return (dataSource, intervals) -> { - final Iterable realtimeAndHistoricalSegments; - - // Fetch the realtime segments and segments loaded on the historical. Do this first so that we don't miss any - // segment if they get handed off between the two calls. Segments loaded on historicals are deduplicated below, - // since we are only interested in realtime segments for now. - if (includeRealtime) { - realtimeAndHistoricalSegments = context.coordinatorClient().fetchServerViewSegments(dataSource, intervals); - } else { - realtimeAndHistoricalSegments = ImmutableList.of(); - } - - // Fetch all published, used segments (all non-realtime segments) from the metadata store. - // If the task is operating with a REPLACE lock, - // any segment created after the lock was acquired for its interval will not be considered. - final Collection publishedUsedSegments; - try { - // Additional check as the task action does not accept empty intervals - if (intervals.isEmpty()) { - publishedUsedSegments = Collections.emptySet(); - } else { - publishedUsedSegments = context.taskActionClient().submit(new RetrieveUsedSegmentsAction( - dataSource, - intervals - )); - } - } - catch (IOException e) { - throw new MSQException(e, UnknownFault.forException(e)); - } - - int realtimeCount = 0; - - // Deduplicate segments, giving preference to published used segments. - // We do this so that if any segments have been handed off in between the two metadata calls above, - // we directly fetch it from deep storage. - Set unifiedSegmentView = new HashSet<>(publishedUsedSegments); - - // Iterate over the realtime segments and segments loaded on the historical - for (ImmutableSegmentLoadInfo segmentLoadInfo : realtimeAndHistoricalSegments) { - ImmutableSet servers = segmentLoadInfo.getServers(); - // Filter out only realtime servers. We don't want to query historicals for now, but we can in the future. - // This check can be modified then. 
- Set realtimeServerMetadata - = servers.stream() - .filter(druidServerMetadata -> includeSegmentSource.getUsedServerTypes() - .contains(druidServerMetadata.getType()) - ) - .collect(Collectors.toSet()); - if (!realtimeServerMetadata.isEmpty()) { - realtimeCount += 1; - DataSegmentWithLocation dataSegmentWithLocation = new DataSegmentWithLocation( - segmentLoadInfo.getSegment(), - realtimeServerMetadata - ); - unifiedSegmentView.add(dataSegmentWithLocation); - } else { - // We don't have any segments of the required segment source, ignore the segment - } - } - - if (includeRealtime) { - log.info( - "Fetched total [%d] segments from coordinator: [%d] from metadata stoure, [%d] from server view", - unifiedSegmentView.size(), - publishedUsedSegments.size(), - realtimeCount - ); - } - - if (unifiedSegmentView.isEmpty()) { - return Optional.empty(); - } else { - return Optional.of(SegmentTimeline.forSegments(unifiedSegmentView)); - } - }; - } - private Int2ObjectMap> makeSegmentGeneratorWorkerFactoryInfos( final WorkerInputs workerInputs, final List segmentsToGenerate @@ -1369,75 +1200,59 @@ private Int2ObjectMap> makeSegmentGeneratorWorkerFa * * @param queryKernel * @param contactFn - * @param workers set of workers to contact - * @param successCallBack After contacting all the tasks, a custom callback is invoked in the main thread for each successfully contacted task. - * @param retryOnFailure If true, after contacting all the tasks, adds this worker to retry queue in the main thread. - * If false, cancel all the futures and propagate the exception to the caller. + * @param workers set of workers to contact + * @param successFn After contacting all the tasks, a custom callback is invoked in the main thread for each successfully contacted task. + * @param retryOnFailure If true, after contacting all the tasks, adds this worker to retry queue in the main thread. + * If false, cancel all the futures and propagate the exception to the caller. */ private void contactWorkersForStage( final ControllerQueryKernel queryKernel, - final TaskContactFn contactFn, final IntSet workers, - final TaskContactSuccess successCallBack, + final TaskContactFn contactFn, + final TaskContactSuccess successFn, final boolean retryOnFailure ) { - final List taskIds = getTaskIds(); - final List> taskFutures = new ArrayList<>(workers.size()); + // Sorted copy of target worker numbers to ensure consistent iteration order. + final List workersCopy = Ordering.natural().sortedCopy(workers); + final List workerIds = getTaskIds(); + final List> workerFutures = new ArrayList<>(workersCopy.size()); try { - workerTaskLauncher.waitUntilWorkersReady(workers); + workerManager.waitForWorkers(workers); } catch (InterruptedException e) { + Thread.currentThread().interrupt(); throw new RuntimeException(e); } - Set failedCalls = ConcurrentHashMap.newKeySet(); - Set successfulCalls = ConcurrentHashMap.newKeySet(); - - for (int workerNumber : workers) { - final String taskId = taskIds.get(workerNumber); - SettableFuture settableFuture = SettableFuture.create(); - ListenableFuture apiFuture = contactFn.contactTask(netClient, taskId, workerNumber); - Futures.addCallback(apiFuture, new FutureCallback() - { - @Override - public void onSuccess(@Nullable Void result) - { - successfulCalls.add(taskId); - settableFuture.set(true); - } - - @Override - public void onFailure(Throwable t) - { - if (retryOnFailure) { - log.info( - t, - "Detected failure while contacting task[%s]. 
Initiating relaunch of worker[%d] if applicable", - taskId, - MSQTasks.workerFromTaskId(taskId) - ); - failedCalls.add(taskId); - settableFuture.set(false); - } else { - settableFuture.setException(t); - } - } - }, MoreExecutors.directExecutor()); - - taskFutures.add(settableFuture); + for (final int workerNumber : workersCopy) { + workerFutures.add(contactFn.contactTask(netClient, workerIds.get(workerNumber), workerNumber)); } - FutureUtils.getUnchecked(MSQFutureUtils.allAsList(taskFutures, true), true); + final List> workerResults = + FutureUtils.getUnchecked(FutureUtils.coalesce(workerFutures), true); - for (String taskId : successfulCalls) { - successCallBack.onSuccess(taskId); - } + for (int i = 0; i < workerResults.size(); i++) { + final int workerNumber = workersCopy.get(i); + final String workerId = workerIds.get(workerNumber); + final Either workerResult = workerResults.get(i); + + if (workerResult.isValue()) { + successFn.onSuccess(workerId, workerNumber); + } else if (retryOnFailure) { + // Possibly retryable failure. + log.info( + workerResult.error(), + "Detected failure while contacting task[%s]. Initiating relaunch of worker[%d] if applicable", + workerId, + workerNumber + ); - if (retryOnFailure) { - for (String taskId : failedCalls) { - addToRetryQueue(queryKernel, MSQTasks.workerFromTaskId(taskId), new WorkerRpcFailedFault(taskId)); + addToRetryQueue(queryKernel, workerNumber, new WorkerRpcFailedFault(workerId)); + } else { + // Nonretryable failure. + throw new RuntimeException(workerResult.error()); } } } @@ -1462,10 +1277,12 @@ private void startWorkForStage( queryKernel.startStage(stageId); contactWorkersForStage( queryKernel, + workOrders.keySet(), (netClient, taskId, workerNumber) -> ( - netClient.postWorkOrder(taskId, workOrders.get(workerNumber))), workOrders.keySet(), - (taskId) -> queryKernel.workOrdersSentForWorker(stageId, MSQTasks.workerFromTaskId(taskId)), - isFaultToleranceEnabled + netClient.postWorkOrder(taskId, workOrders.get(workerNumber))), + (workerId, workerNumber) -> + queryKernel.workOrdersSentForWorker(stageId, workerNumber), + queryKernelConfig.isFaultTolerant() ); } @@ -1481,14 +1298,12 @@ private void postResultPartitionBoundariesForStage( contactWorkersForStage( queryKernel, - (netClient, taskId, workerNumber) -> netClient.postResultPartitionBoundaries( - taskId, - stageId, - resultPartitionBoundaries - ), workers, - (taskId) -> queryKernel.partitionBoundariesSentForWorker(stageId, MSQTasks.workerFromTaskId(taskId)), - isFaultToleranceEnabled + (netClient, workerId, workerNumber) -> + netClient.postResultPartitionBoundaries(workerId, stageId, resultPartitionBoundaries), + (workerId, workerNumber) -> + queryKernel.partitionBoundariesSentForWorker(stageId, workerNumber), + queryKernelConfig.isFaultTolerant() ); } @@ -1499,11 +1314,11 @@ private void postResultPartitionBoundariesForStage( private void publishAllSegments(final Set segments) throws IOException { final DataSourceMSQDestination destination = - (DataSourceMSQDestination) task.getQuerySpec().getDestination(); + (DataSourceMSQDestination) querySpec.getDestination(); final Set segmentsWithTombstones = new HashSet<>(segments); int numTombstones = 0; final TaskLockType taskLockType = MultiStageQueryContext.validateAndGetTaskLockType( - QueryContext.of(task.getQuerySpec().getQuery().getContext()), + QueryContext.of(querySpec.getQuery().getContext()), destination.isReplaceTimeChunks() ); @@ -1516,7 +1331,7 @@ private void publishAllSegments(final Set segments) throws IOExcept Set tombstones 
= tombstoneHelper.computeTombstoneSegmentsForReplace( intervalsToDrop, destination.getReplaceTimeChunks(), - task.getDataSource(), + destination.getDataSource(), destination.getSegmentGranularity(), Limits.MAX_PARTITION_BUCKETS ); @@ -1537,15 +1352,15 @@ private void publishAllSegments(final Set segments) throws IOExcept // This should not need a segment load wait as segments are marked as unused immediately. for (final Interval interval : intervalsToDrop) { context.taskActionClient() - .submit(new MarkSegmentsAsUnusedAction(task.getDataSource(), interval)); + .submit(new MarkSegmentsAsUnusedAction(destination.getDataSource(), interval)); } } else { - if (MultiStageQueryContext.shouldWaitForSegmentLoad(task.getQuerySpec().getQuery().context())) { + if (MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getQuery().context())) { segmentLoadWaiter = new SegmentLoadStatusFetcher( context.injector().getInstance(BrokerClient.class), context.jsonMapper(), - task.getId(), - task.getDataSource(), + queryId, + destination.getDataSource(), segmentsWithTombstones, true ); @@ -1556,12 +1371,12 @@ private void publishAllSegments(final Set segments) throws IOExcept ); } } else if (!segments.isEmpty()) { - if (MultiStageQueryContext.shouldWaitForSegmentLoad(task.getQuerySpec().getQuery().context())) { + if (MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getQuery().context())) { segmentLoadWaiter = new SegmentLoadStatusFetcher( context.injector().getInstance(BrokerClient.class), context.jsonMapper(), - task.getId(), - task.getDataSource(), + queryId, + destination.getDataSource(), segments, true ); @@ -1573,9 +1388,9 @@ private void publishAllSegments(final Set segments) throws IOExcept ); } - task.emitMetric(context.emitter(), "ingest/tombstones/count", numTombstones); + context.emitMetric("ingest/tombstones/count", numTombstones); // Include tombstones in the reported segments count - task.emitMetric(context.emitter(), "ingest/segments/count", segmentsWithTombstones.size()); + context.emitMetric("ingest/segments/count", segmentsWithTombstones.size()); } private static TaskAction createAppendAction( @@ -1614,7 +1429,7 @@ private List findIntervalsToDrop(final Set publishedSegme { // Safe to cast because publishAllSegments is only called for dataSource destinations. 
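    // For illustration (hypothetical intervals, not from any real query): a REPLACE over
    // 2024-01-01/2024-04-01 whose published segments cover only 2024-01-01/2024-02-01 would return
    // 2024-02-01/2024-04-01 here, and publishAllSegments would then fill that gap with tombstones.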
final DataSourceMSQDestination destination = - (DataSourceMSQDestination) task.getQuerySpec().getDestination(); + (DataSourceMSQDestination) querySpec.getDestination(); final List replaceIntervals = new ArrayList<>(JodaUtils.condenseIntervals(destination.getReplaceTimeChunks())); final List publishIntervals = @@ -1671,80 +1486,6 @@ private CounterSnapshotsTree getFinalCountersSnapshot(@Nullable final Controller } } - @Nullable - private Yielder getFinalResultsYielder( - final QueryDefinition queryDef, - final ControllerQueryKernel queryKernel - ) - { - if (queryKernel.isSuccess() && isInlineResults(task.getQuerySpec())) { - final StageId finalStageId = queryKernel.getStageId(queryDef.getFinalStageDefinition().getStageNumber()); - final List taskIds = getTaskIds(); - final Closer closer = Closer.create(); - - final ListeningExecutorService resultReaderExec = - MoreExecutors.listeningDecorator(Execs.singleThreaded("result-reader-%d")); - closer.register(resultReaderExec::shutdownNow); - - final InputChannelFactory inputChannelFactory; - - if (isDurableStorageEnabled || MSQControllerTask.writeResultsToDurableStorage(task.getQuerySpec())) { - inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation( - id(), - MSQTasks.makeStorageConnector( - context.injector()), - closer, - MSQControllerTask.writeResultsToDurableStorage(task.getQuerySpec()) - ); - } else { - inputChannelFactory = new WorkerInputChannelFactory(netClient, () -> taskIds); - } - - final InputChannels inputChannels = new InputChannelsImpl( - queryDef, - queryKernel.getResultPartitionsForStage(finalStageId), - inputChannelFactory, - () -> ArenaMemoryAllocator.createOnHeap(5_000_000), - new FrameProcessorExecutor(resultReaderExec), - null - ); - - return Yielders.each( - Sequences.concat( - StreamSupport.stream(queryKernel.getResultPartitionsForStage(finalStageId).spliterator(), false) - .map( - readablePartition -> { - try { - return new FrameChannelSequence( - inputChannels.openChannel( - new StagePartition( - queryKernel.getStageDefinition(finalStageId).getId(), - readablePartition.getPartitionNumber() - ) - ) - ); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - ).collect(Collectors.toList()) - ).flatMap( - frame -> - SqlStatementResourceHelper.getResultSequence( - task, - queryDef.getFinalStageDefinition(), - frame, - context.jsonMapper() - ) - ) - .withBaggage(resultReaderExec::shutdownNow) - ); - } else { - return null; - } - } - private void handleQueryResults( final QueryDefinition queryDef, final ControllerQueryKernel queryKernel @@ -1753,22 +1494,21 @@ private void handleQueryResults( if (!queryKernel.isSuccess()) { return; } - if (MSQControllerTask.isIngestion(task.getQuerySpec())) { + if (MSQControllerTask.isIngestion(querySpec)) { // Publish segments if needed. 
      final StageId finalStageId = queryKernel.getStageId(queryDef.getFinalStageDefinition().getStageNumber());
-      //noinspection unchecked
       @SuppressWarnings("unchecked")
       Set segments = (Set) queryKernel.getResultObjectForStage(finalStageId);

-      boolean storeCompactionState = QueryContext.of(task.getQuerySpec().getQuery().getContext())
+      boolean storeCompactionState = QueryContext.of(querySpec.getQuery().getContext())
           .getBoolean(
               Tasks.STORE_COMPACTION_STATE_KEY,
               Tasks.DEFAULT_STORE_COMPACTION_STATE
           );

       if (!segments.isEmpty() && storeCompactionState) {
-        DataSourceMSQDestination destination = (DataSourceMSQDestination) task.getQuerySpec().getDestination();
+        DataSourceMSQDestination destination = (DataSourceMSQDestination) querySpec.getDestination();
         if (!destination.isReplaceTimeChunks()) {
           // Store compaction state only for replace queries.
           log.warn(
@@ -1782,7 +1522,7 @@ private void handleQueryResults(
           ShardSpec shardSpec = segments.stream().findFirst().get().getShardSpec();
           Function, Set> compactionStateAnnotateFunction = addCompactionStateToSegments(
-              task(),
+              querySpec,
               context.jsonMapper(),
               dataSchema,
               shardSpec,
@@ -1793,9 +1533,28 @@ private void handleQueryResults(
       }
       log.info("Query [%s] publishing %d segments.", queryDef.getQueryId(), segments.size());
       publishAllSegments(segments);
-    } else if (MSQControllerTask.isExport(task.getQuerySpec())) {
+    } else if (MSQControllerTask.isExport(querySpec)) {
+      // Write manifest file.
+      ExportMSQDestination destination = (ExportMSQDestination) querySpec.getDestination();
+      ExportMetadataManager exportMetadataManager = new ExportMetadataManager(destination.getExportStorageProvider());
+
+      final StageId finalStageId = queryKernel.getStageId(queryDef.getFinalStageDefinition().getStageNumber());
+      //noinspection unchecked
+
+
+      Object resultObjectForStage = queryKernel.getResultObjectForStage(finalStageId);
+      if (!(resultObjectForStage instanceof List)) {
+        // This might occur if all workers are running on an older version. We are not able to write a manifest file in this case.
+        log.warn("Was unable to create manifest file due to ");
+        return;
+      }
+      @SuppressWarnings("unchecked")
+      List exportedFiles = (List) queryKernel.getResultObjectForStage(finalStageId);
+      log.info("Query [%s] exported %d files.", queryDef.getQueryId(), exportedFiles.size());
+      exportMetadataManager.writeMetadata(exportedFiles);
@@ -1816,14 +1575,14 @@ private void handleQueryResults(
   }

   private static Function, Set> addCompactionStateToSegments(
-      MSQControllerTask task,
+      MSQSpec querySpec,
       ObjectMapper jsonMapper,
       DataSchema dataSchema,
       ShardSpec shardSpec,
       String queryId
   )
   {
-    final MSQTuningConfig tuningConfig = task.getQuerySpec().getTuningConfig();
+    final MSQTuningConfig tuningConfig = querySpec.getTuningConfig();
     PartitionsSpec partitionSpec;

     if (Objects.equals(shardSpec.getType(), ShardSpec.Type.RANGE)) {
@@ -1848,7 +1607,7 @@ private static Function, Set> addCompactionStateTo
       )));
     }

-    Granularity segmentGranularity = ((DataSourceMSQDestination) task.getQuerySpec().getDestination())
+    Granularity segmentGranularity = ((DataSourceMSQDestination) querySpec.getDestination())
         .getSegmentGranularity();

     GranularitySpec granularitySpec = new UniformGranularitySpec(
@@ -1895,15 +1654,15 @@ private static Function, Set> addCompactionStateTo
    */
   private void cleanUpDurableStorageIfNeeded()
   {
-    if (isDurableStorageEnabled) {
-      final String controllerDirName = DurableStorageUtils.getControllerDirectory(task.getId());
+    if (queryKernelConfig != null && queryKernelConfig.isDurableStorage()) {
+      final String controllerDirName = DurableStorageUtils.getControllerDirectory(queryId());
       try {
         // Delete all temporary files as a failsafe
         MSQTasks.makeStorageConnector(context.injector()).deleteRecursively(controllerDirName);
       }
       catch (Exception e) {
         // If an error is thrown while cleaning up a file, log it and try to continue with the cleanup
-        log.warn(e, "Error while cleaning up temporary files at path %s", controllerDirName);
+        log.warn(e, "Error while cleaning up temporary files at path[%s]. Skipping.", controllerDirName);
       }
     }
   }
@@ -1945,10 +1704,9 @@ private static QueryDefinition makeQueryDefinition(
         queryToPlan = querySpec.getQuery();
       }
     } else {
-      shuffleSpecFactory = querySpec.getDestination()
-                                    .getShuffleSpecFactory(
-                                        MultiStageQueryContext.getRowsPerPage(querySpec.getQuery().context())
-                                    );
+      shuffleSpecFactory =
+          querySpec.getDestination()
+                   .getShuffleSpecFactory(MultiStageQueryContext.getRowsPerPage(querySpec.getQuery().context()));
       queryToPlan = querySpec.getQuery();
     }
@@ -1992,7 +1750,7 @@ private static QueryDefinition makeQueryDefinition(
     // Add all query stages.
     // Set shuffleCheckHasMultipleValues on the stage that serves as input to the final segment-generation stage.
-    final QueryDefinitionBuilder builder = QueryDefinition.builder();
+    final QueryDefinitionBuilder builder = QueryDefinition.builder(queryId);

     for (final StageDefinition stageDef : queryDef.getStageDefinitions()) {
       if (stageDef.equals(finalShuffleStageDef)) {
@@ -2004,7 +1762,7 @@ private static QueryDefinition makeQueryDefinition(
     // Then, add a segment-generation stage.
final DataSchema dataSchema = - generateDataSchema(querySpec, querySignature, queryClusterBy, columnMappings, jsonMapper); + makeDataSchemaForIngestion(querySpec, querySignature, queryClusterBy, columnMappings, jsonMapper); builder.add( StageDefinition.builder(queryDef.getNextStageNumber()) @@ -2027,7 +1785,7 @@ private static QueryDefinition makeQueryDefinition( // attaching new query results stage if the final stage does sort during shuffle so that results are ordered. StageDefinition finalShuffleStageDef = queryDef.getFinalStageDefinition(); if (finalShuffleStageDef.doesSortDuringShuffle()) { - final QueryDefinitionBuilder builder = QueryDefinition.builder(); + final QueryDefinitionBuilder builder = QueryDefinition.builder(queryId); builder.addAll(queryDef); builder.add(StageDefinition.builder(queryDef.getNextStageNumber()) .inputs(new StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber())) @@ -2063,9 +1821,8 @@ private static QueryDefinition makeQueryDefinition( .build(e, "Exception occurred while connecting to export destination."); } - final ResultFormat resultFormat = exportMSQDestination.getResultFormat(); - final QueryDefinitionBuilder builder = QueryDefinition.builder(); + final QueryDefinitionBuilder builder = QueryDefinition.builder(queryId); builder.addAll(queryDef); builder.add(StageDefinition.builder(queryDef.getNextStageNumber()) .inputs(new StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber())) @@ -2085,7 +1842,12 @@ private static QueryDefinition makeQueryDefinition( } } - private static DataSchema generateDataSchema( + private static String getDataSourceForIngestion(final MSQSpec querySpec) + { + return ((DataSourceMSQDestination) querySpec.getDestination()).getDataSource(); + } + + private static DataSchema makeDataSchemaForIngestion( MSQSpec querySpec, RowSignature querySignature, ClusterBy queryClusterBy, @@ -2187,19 +1949,6 @@ private static boolean isRollupQuery(Query query) && !query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true); } - private static boolean isInlineResults(final MSQSpec querySpec) - { - return querySpec.getDestination() instanceof TaskReportMSQDestination - || querySpec.getDestination() instanceof DurableStorageMSQDestination; - } - - private static boolean isTimeBucketedIngestion(final MSQSpec querySpec) - { - return MSQControllerTask.isIngestion(querySpec) - && !((DataSourceMSQDestination) querySpec.getDestination()).getSegmentGranularity() - .equals(Granularities.ALL); - } - /** * Compute shard columns for {@link DimensionRangeShardSpec}. Returns an empty list if range-based sharding * is not applicable. 
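   * For example (hypothetical query): a REPLACE with CLUSTERED BY (countryName, cityName), where both columns
   * map cleanly to output columns and neither can be multi-valued, shards on [countryName, cityName]; otherwise
   * this returns an empty list and the caller falls back to {@link NumberedShardSpec}.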
@@ -2477,7 +2226,8 @@ private static MSQStagesReport makeStageReport( final Map stagePhaseMap, final Map stageRuntimeMap, final Map stageWorkerCountMap, - final Map stagePartitionCountMap + final Map stagePartitionCountMap, + final Map stageOutputChannelModeMap ) { return MSQStagesReport.create( @@ -2485,35 +2235,8 @@ private static MSQStagesReport makeStageReport( ImmutableMap.copyOf(stagePhaseMap), copyOfStageRuntimesEndingAtCurrentTime(stageRuntimeMap), stageWorkerCountMap, - stagePartitionCountMap - ); - } - - private static MSQResultsReport makeResultsTaskReport( - final QueryDefinition queryDef, - final Yielder resultsYielder, - final ColumnMappings columnMappings, - @Nullable final List sqlTypeNames, - final MSQSelectDestination selectDestination - ) - { - final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature(); - final ImmutableList.Builder mappedSignature = ImmutableList.builder(); - - for (final ColumnMapping mapping : columnMappings.getMappings()) { - mappedSignature.add( - new MSQResultsReport.ColumnAndType( - mapping.getOutputColumn(), - querySignature.getColumnType(mapping.getQueryColumn()).orElse(null) - ) - ); - } - - return MSQResultsReport.createReportAndLimitRowsIfNeeded( - mappedSignature.build(), - sqlTypeNames, - resultsYielder, - selectDestination + stagePartitionCountMap, + stageOutputChannelModeMap ); } @@ -2523,17 +2246,17 @@ private static MSQStatusReport makeStatusReport( final Queue errorReports, @Nullable final DateTime queryStartTime, final long queryDuration, - MSQWorkerTaskLauncher taskLauncher, + final WorkerManager taskLauncher, final SegmentLoadStatusFetcher segmentLoadWaiter, @Nullable MSQSegmentReport msqSegmentReport ) { int pendingTasks = -1; int runningTasks = 1; - Map> workerStatsMap = new HashMap<>(); + Map> workerStatsMap = new HashMap<>(); if (taskLauncher != null) { - WorkerCount workerTaskCount = taskLauncher.getWorkerTaskCount(); + WorkerCount workerTaskCount = taskLauncher.getWorkerCount(); pendingTasks = workerTaskCount.getPendingWorkerCount(); runningTasks = workerTaskCount.getRunningWorkerCount() + 1; // To account for controller. workerStatsMap = taskLauncher.getWorkerStats(); @@ -2557,15 +2280,15 @@ private static MSQStatusReport makeStatusReport( ); } - private static InputSpecSlicerFactory makeInputSpecSlicerFactory(final DataSegmentTimelineView timelineView) + private static InputSpecSlicerFactory makeInputSpecSlicerFactory(final InputSpecSlicer tableInputSpecSlicer) { - return stagePartitionsMap -> new MapInputSpecSlicer( + return (stagePartitionsMap, stageOutputChannelModeMap) -> new MapInputSpecSlicer( ImmutableMap., InputSpecSlicer>builder() - .put(StageInputSpec.class, new StageInputSpecSlicer(stagePartitionsMap)) + .put(StageInputSpec.class, new StageInputSpecSlicer(stagePartitionsMap, stageOutputChannelModeMap)) .put(ExternalInputSpec.class, new ExternalInputSpecSlicer()) .put(InlineInputSpec.class, new InlineInputSpecSlicer()) .put(LookupInputSpec.class, new LookupInputSpecSlicer()) - .put(TableInputSpec.class, new TableInputSpecSlicer(timelineView)) + .put(TableInputSpec.class, tableInputSpecSlicer) .build() ); } @@ -2679,11 +2402,12 @@ private class RunQueryUntilDone { private final QueryDefinition queryDef; private final InputSpecSlicerFactory inputSpecSlicerFactory; + private final QueryListener queryListener; private final Closer closer; private final ControllerQueryKernel queryKernel; /** - * Return value of {@link MSQWorkerTaskLauncher#start()}. Set by {@link #startTaskLauncher()}. 
+   * Return value of {@link WorkerManager#start()}. Set by {@link #startTaskLauncher()}.
     */
    private ListenableFuture workerTaskLauncherFuture;

@@ -2694,20 +2418,26 @@ private class RunQueryUntilDone
     */
    private List segmentsToGenerate;

+    /**
+     * Future that resolves when the reader from {@link #startQueryResultsReader()} finishes. Prior to that method
+     * being called, this future is null.
+     */
+    @Nullable
+    private ListenableFuture queryResultsReaderFuture;
+
    public RunQueryUntilDone(
        final QueryDefinition queryDef,
+        final ControllerQueryKernelConfig queryKernelConfig,
        final InputSpecSlicerFactory inputSpecSlicerFactory,
+        final QueryListener queryListener,
        final Closer closer
    )
    {
      this.queryDef = queryDef;
      this.inputSpecSlicerFactory = inputSpecSlicerFactory;
+      this.queryListener = queryListener;
      this.closer = closer;
-      this.queryKernel = new ControllerQueryKernel(
-          queryDef,
-          workerMemoryParameters.getPartitionStatisticsMaxRetainedBytes(),
-          isFaultToleranceEnabled
-      );
+      this.queryKernel = new ControllerQueryKernel(queryDef, queryKernelConfig);
    }

    /**
@@ -2717,15 +2447,20 @@ private Pair> run() throws IOExceptio
    {
      startTaskLauncher();

+      boolean runAgain;
      while (!queryKernel.isDone()) {
        startStages();
        fetchStatsFromWorkers();
        sendPartitionBoundaries();
        updateLiveReportMaps();
-        cleanUpEffectivelyFinishedStages();
+        readQueryResults();
+        runAgain = cleanUpEffectivelyFinishedStages();
        retryFailedTasks();
        checkForErrorsInSketchFetcher();
-        runKernelCommands();
+
+        if (!runAgain) {
+          runKernelCommands();
+        }
      }

      if (!queryKernel.isSuccess()) {
@@ -2745,11 +2480,21 @@ private void checkForErrorsInSketchFetcher()
      }
    }

+    /**
+     * Read query results, if appropriate and possible.
+     */
+    private void readQueryResults()
+    {
+      // Open query results channel, if appropriate.
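+      // That is: the listener must want results, the final stage's results must be ready to read, and no
+      // reader may already be running.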
+ if (queryListener.readResults() && queryKernel.canReadQueryResults() && queryResultsReaderFuture == null) { + startQueryResultsReader(); + } + } private void retryFailedTasks() throws InterruptedException { // if no work orders to rety skip - if (workOrdersToRetry.size() == 0) { + if (workOrdersToRetry.isEmpty()) { return; } Set workersNeedToBeFullyStarted = new HashSet<>(); @@ -2765,7 +2510,7 @@ private void retryFailedTasks() throws InterruptedException new StageId(queryDef.getQueryId(), workOrder.getStageNumber()), (stageId, workOrders) -> { if (workOrders == null) { - workOrders = new HashMap(); + workOrders = new HashMap<>(); } workOrders.put(workerStages.getKey(), workOrder); return workOrders; @@ -2775,27 +2520,23 @@ private void retryFailedTasks() throws InterruptedException } // wait till the workers identified above are fully ready - workerTaskLauncher.waitUntilWorkersReady(workersNeedToBeFullyStarted); + workerManager.waitForWorkers(workersNeedToBeFullyStarted); for (Map.Entry> stageWorkOrders : stageWorkerOrders.entrySet()) { - contactWorkersForStage( queryKernel, - (netClient, taskId, workerNumber) -> netClient.postWorkOrder( - taskId, - stageWorkOrders.getValue().get(workerNumber) - ), new IntArraySet(stageWorkOrders.getValue().keySet()), - (taskId) -> { - int workerNumber = MSQTasks.workerFromTaskId(taskId); + (netClient, workerId, workerNumber) -> + netClient.postWorkOrder(workerId, stageWorkOrders.getValue().get(workerNumber)), + (workerId, workerNumber) -> { queryKernel.workOrdersSentForWorker(stageWorkOrders.getKey(), workerNumber); // remove successfully contacted workOrders from workOrdersToRetry workOrdersToRetry.compute(workerNumber, (task, workOrderSet) -> { - if (workOrderSet == null || workOrderSet.size() == 0 || !workOrderSet.remove(stageWorkOrders.getValue() - .get( - workerNumber))) { - throw new ISE("Worker[%d] orders not found", workerNumber); + if (workOrderSet == null + || workOrderSet.size() == 0 + || !workOrderSet.remove(stageWorkOrders.getValue().get(workerNumber))) { + throw new ISE("Worker[%s] with number[%d] orders not found", workerId, workerNumber); } if (workOrderSet.size() == 0) { return null; @@ -2803,7 +2544,7 @@ private void retryFailedTasks() throws InterruptedException return workOrderSet; }); }, - isFaultToleranceEnabled + queryKernelConfig.isFaultTolerant() ); } } @@ -2827,16 +2568,16 @@ private void runKernelCommands() throws InterruptedException } /** - * Start up the {@link MSQWorkerTaskLauncher}, such that later on it can be used to launch new tasks - * via {@link MSQWorkerTaskLauncher#launchTasksIfNeeded}. + * Start up the {@link WorkerManager}, such that later on it can be used to launch new tasks + * via {@link WorkerManager#launchWorkersIfNeeded}. */ private void startTaskLauncher() { // Start tasks. 
log.debug("Query [%s] starting task launcher.", queryDef.getQueryId()); - workerTaskLauncherFuture = workerTaskLauncher.start(); - closer.register(() -> workerTaskLauncher.stop(true)); + workerTaskLauncherFuture = workerManager.start(); + closer.register(() -> workerManager.stop(true)); workerTaskLauncherFuture.addListener( () -> @@ -2857,7 +2598,7 @@ private void fetchStatsFromWorkers() for (Map.Entry> stageToWorker : queryKernel.getStagesAndWorkersToFetchClusterStats() .entrySet()) { - List allTasks = workerTaskLauncher.getActiveTasks(); + List allTasks = workerManager.getWorkerIds(); Set tasks = stageToWorker.getValue().stream().map(allTasks::get).collect(Collectors.toSet()); ClusterStatisticsMergeMode clusterStatisticsMergeMode = stageToStatsMergingMode.get(stageToWorker.getKey() @@ -2881,7 +2622,7 @@ private void submitParallelMergeRequests(StageId stageId, Set tasks) // eagerly change state of workers whose state is being fetched so that we do not keep on queuing fetch requests. queryKernel.startFetchingStatsFromWorker( stageId, - tasks.stream().map(MSQTasks::workerFromTaskId).collect(Collectors.toSet()) + tasks.stream().map(workerManager::getWorkerNumber).collect(Collectors.toSet()) ); workerSketchFetcher.inMemoryFullSketchMerging(ControllerImpl.this::addToKernelManipulationQueue, stageId, tasks, @@ -2896,13 +2637,14 @@ private void submitSequentialMergeFetchRequests(StageId stageId, Set tas queryKernel.startFetchingStatsFromWorker( stageId, tasks.stream() - .map(MSQTasks::workerFromTaskId) + .map(workerManager::getWorkerNumber) .collect(Collectors.toSet()) ); workerSketchFetcher.sequentialTimeChunkMerging( ControllerImpl.this::addToKernelManipulationQueue, queryKernel.getCompleteKeyStatisticsInformation(stageId), - stageId, tasks, + stageId, + tasks, ControllerImpl.this::addToRetryQueue ); } @@ -2914,69 +2656,88 @@ private void submitSequentialMergeFetchRequests(StageId stageId, Set tas private void startStages() throws IOException, InterruptedException { final long maxInputBytesPerWorker = - MultiStageQueryContext.getMaxInputBytesPerWorker(task.getQuerySpec().getQuery().context()); + MultiStageQueryContext.getMaxInputBytesPerWorker(querySpec.getQuery().context()); logKernelStatus(queryDef.getQueryId(), queryKernel); - final List newStageIds = queryKernel.createAndGetNewStageIds( - inputSpecSlicerFactory, - task.getQuerySpec().getAssignmentStrategy(), - maxInputBytesPerWorker - ); - for (final StageId stageId : newStageIds) { - - // Allocate segments, if this is the final stage of an ingestion. - if (MSQControllerTask.isIngestion(task.getQuerySpec()) - && stageId.getStageNumber() == queryDef.getFinalStageDefinition().getStageNumber()) { - // We need to find the shuffle details (like partition ranges) to generate segments. Generally this is - // going to correspond to the stage immediately prior to the final segment-generator stage. - int shuffleStageNumber = Iterables.getOnlyElement(queryDef.getFinalStageDefinition().getInputStageNumbers()); - - // The following logic assumes that output of all the stages without a shuffle retain the partition boundaries - // of the input to that stage. This may not always be the case. For example: GROUP BY queries without an - // ORDER BY clause. This works for QueryKit generated queries up until now, but it should be reworked as it - // might not always be the case. 
- while (!queryDef.getStageDefinition(shuffleStageNumber).doesShuffle()) { - shuffleStageNumber = - Iterables.getOnlyElement(queryDef.getStageDefinition(shuffleStageNumber).getInputStageNumbers()); - } + List newStageIds; + + do { + newStageIds = queryKernel.createAndGetNewStageIds( + inputSpecSlicerFactory, + querySpec.getAssignmentStrategy(), + maxInputBytesPerWorker + ); - final StageId shuffleStageId = new StageId(queryDef.getQueryId(), shuffleStageNumber); - final Boolean isShuffleStageOutputEmpty = queryKernel.isStageOutputEmpty(shuffleStageId); - if (isFailOnEmptyInsertEnabled && Boolean.TRUE.equals(isShuffleStageOutputEmpty)) { - throw new MSQException(new InsertCannotBeEmptyFault(task.getDataSource())); + for (final StageId stageId : newStageIds) { + // Allocate segments, if this is the final stage of an ingestion. + if (MSQControllerTask.isIngestion(querySpec) + && stageId.getStageNumber() == queryDef.getFinalStageDefinition().getStageNumber()) { + populateSegmentsToGenerate(); } - final ClusterByPartitions partitionBoundaries = - queryKernel.getResultPartitionBoundariesForStage(shuffleStageId); - - final boolean mayHaveMultiValuedClusterByFields = - !queryKernel.getStageDefinition(shuffleStageId).mustGatherResultKeyStatistics() - || queryKernel.hasStageCollectorEncounteredAnyMultiValueField(shuffleStageId); - - segmentsToGenerate = generateSegmentIdsWithShardSpecs( - (DataSourceMSQDestination) task.getQuerySpec().getDestination(), - queryKernel.getStageDefinition(shuffleStageId).getSignature(), - queryKernel.getStageDefinition(shuffleStageId).getClusterBy(), - partitionBoundaries, - mayHaveMultiValuedClusterByFields, - isShuffleStageOutputEmpty + + final int workerCount = queryKernel.getWorkerInputsForStage(stageId).workerCount(); + final StageDefinition stageDef = queryKernel.getStageDefinition(stageId); + log.info( + "Query [%s] using workers[%d] for stage[%d], writing to[%s], shuffle[%s].", + stageId.getQueryId(), + workerCount, + stageId.getStageNumber(), + queryKernel.getStageOutputChannelMode(stageId), + stageDef.doesShuffle() ? stageDef.getShuffleSpec().kind() : "none" ); - log.info("Query[%s] generating %d segments.", queryDef.getQueryId(), segmentsToGenerate.size()); + workerManager.launchWorkersIfNeeded(workerCount); + stageRuntimesForLiveReports.put(stageId.getStageNumber(), new Interval(DateTimes.nowUtc(), DateTimes.MAX)); + startWorkForStage(queryDef, queryKernel, stageId.getStageNumber(), segmentsToGenerate); } + } while (!newStageIds.isEmpty()); + } - final int workerCount = queryKernel.getWorkerInputsForStage(stageId).workerCount(); - log.info( - "Query [%s] starting %d workers for stage %d.", - stageId.getQueryId(), - workerCount, - stageId.getStageNumber() - ); + /** + * Populate {@link #segmentsToGenerate} for ingestion. + */ + private void populateSegmentsToGenerate() throws IOException + { + // We need to find the shuffle details (like partition ranges) to generate segments. Generally this is + // going to correspond to the stage immediately prior to the final segment-generator stage. + int shuffleStageNumber = Iterables.getOnlyElement(queryDef.getFinalStageDefinition().getInputStageNumbers()); + + // The following logic assumes that output of all the stages without a shuffle retain the partition boundaries + // of the input to that stage. This may not always be the case. For example: GROUP BY queries without an + // ORDER BY clause. This works for QueryKit generated queries up until now, but it should be reworked as it + // might not always be the case. 
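+    // As an illustrative topology (not from any particular query): if the segment generator reads stage 2,
+    // stage 2 does not shuffle and reads stage 1, and stage 1 sorts, this walk settles on stage 1, whose
+    // partition boundaries then determine the generated segments.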
+ while (!queryDef.getStageDefinition(shuffleStageNumber).doesShuffle()) { + shuffleStageNumber = + Iterables.getOnlyElement(queryDef.getStageDefinition(shuffleStageNumber).getInputStageNumbers()); + } - workerTaskLauncher.launchTasksIfNeeded(workerCount); - stageRuntimesForLiveReports.put(stageId.getStageNumber(), new Interval(DateTimes.nowUtc(), DateTimes.MAX)); - startWorkForStage(queryDef, queryKernel, stageId.getStageNumber(), segmentsToGenerate); + final StageId shuffleStageId = new StageId(queryDef.getQueryId(), shuffleStageNumber); + + final boolean isFailOnEmptyInsertEnabled = + MultiStageQueryContext.isFailOnEmptyInsertEnabled(querySpec.getQuery().context()); + final Boolean isShuffleStageOutputEmpty = queryKernel.isStageOutputEmpty(shuffleStageId); + if (isFailOnEmptyInsertEnabled && Boolean.TRUE.equals(isShuffleStageOutputEmpty)) { + throw new MSQException(new InsertCannotBeEmptyFault(getDataSourceForIngestion(querySpec))); } + + final ClusterByPartitions partitionBoundaries = + queryKernel.getResultPartitionBoundariesForStage(shuffleStageId); + + final boolean mayHaveMultiValuedClusterByFields = + !queryKernel.getStageDefinition(shuffleStageId).mustGatherResultKeyStatistics() + || queryKernel.hasStageCollectorEncounteredAnyMultiValueField(shuffleStageId); + + segmentsToGenerate = generateSegmentIdsWithShardSpecs( + (DataSourceMSQDestination) querySpec.getDestination(), + queryKernel.getStageDefinition(shuffleStageId).getSignature(), + queryKernel.getStageDefinition(shuffleStageId).getClusterBy(), + partitionBoundaries, + mayHaveMultiValuedClusterByFields, + isShuffleStageOutputEmpty + ); + + log.info("Query [%s] generating %d segments.", queryDef.getQueryId(), partitionBoundaries.size()); } /** @@ -3033,7 +2794,7 @@ private void updateLiveReportMaps() { logKernelStatus(queryDef.getQueryId(), queryKernel); - // Live reports: update stage phases, worker counts, partition counts. + // Live reports: update stage phases, worker counts, partition counts, output channel modes. for (StageId stageId : queryKernel.getActiveStages()) { final int stageNumber = stageId.getStageNumber(); stagePhasesForLiveReports.put(stageNumber, queryKernel.getStagePhase(stageId)); @@ -3045,15 +2806,20 @@ private void updateLiveReportMaps() ); } - stageWorkerCountsForLiveReports.putIfAbsent( + stageWorkerCountsForLiveReports.computeIfAbsent( stageNumber, - queryKernel.getWorkerInputsForStage(stageId).workerCount() + k -> queryKernel.getWorkerInputsForStage(stageId).workerCount() + ); + + stageOutputChannelModesForLiveReports.computeIfAbsent( + stageNumber, + k -> queryKernel.getStageOutputChannelMode(stageId) ); } // Live reports: update stage end times for any stages that just ended. for (StageId stageId : queryKernel.getActiveStages()) { - if (ControllerStagePhase.isSuccessfulTerminalPhase(queryKernel.getStagePhase(stageId))) { + if (queryKernel.getStagePhase(stageId).isSuccess()) { stageRuntimesForLiveReports.compute( queryKernel.getStageDefinition(stageId).getStageNumber(), (k, currentValue) -> { @@ -3070,21 +2836,144 @@ private void updateLiveReportMaps() /** * Issue cleanup commands to any stages that are effectivley finished, allowing them to delete their outputs. 
+ * + * @return true if any stages were cleaned up */ - private void cleanUpEffectivelyFinishedStages() + private boolean cleanUpEffectivelyFinishedStages() { + final StageId finalStageId = queryDef.getFinalStageDefinition().getId(); + boolean didSomething = false; for (final StageId stageId : queryKernel.getEffectivelyFinishedStageIds()) { + if (finalStageId.equals(stageId) + && queryListener.readResults() + && (queryResultsReaderFuture == null || !queryResultsReaderFuture.isDone())) { + // Don't clean up final stage until results are done being read. + continue; + } + log.info("Query [%s] issuing cleanup order for stage %d.", queryDef.getQueryId(), stageId.getStageNumber()); contactWorkersForStage( queryKernel, - (netClient, taskId, workerNumber) -> netClient.postCleanupStage(taskId, stageId), queryKernel.getWorkerInputsForStage(stageId).workers(), - (ignore1) -> { - }, + (netClient, workerId, workerNumber) -> netClient.postCleanupStage(workerId, stageId), + (workerId, workerNumber) -> {}, false ); queryKernel.finishStage(stageId, true); + didSomething = true; + } + return didSomething; + } + + /** + * Start a {@link ControllerQueryResultsReader} that pushes results to our {@link QueryListener}. + * + * The reader runs in a single-threaded executor that is created by this method, and shut down when results + * are done being read. + */ + private void startQueryResultsReader() + { + if (queryResultsReaderFuture != null) { + throw new ISE("Already started"); + } + + final StageId finalStageId = queryKernel.getStageId(queryDef.getFinalStageDefinition().getStageNumber()); + final List taskIds = getTaskIds(); + + final InputChannelFactory inputChannelFactory; + + if (queryKernelConfig.isDurableStorage() || MSQControllerTask.writeResultsToDurableStorage(querySpec)) { + inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation( + queryId(), + MSQTasks.makeStorageConnector(context.injector()), + closer, + MSQControllerTask.writeResultsToDurableStorage(querySpec) + ); + } else { + inputChannelFactory = new WorkerInputChannelFactory(netClient, () -> taskIds); + } + + final FrameProcessorExecutor resultReaderExec = new FrameProcessorExecutor( + MoreExecutors.listeningDecorator( + Execs.singleThreaded(StringUtils.encodeForFormat("msq-result-reader[" + queryId() + "]"))) + ); + + final String cancellationId = "results-reader"; + ReadableConcatFrameChannel resultsChannel = null; + + try { + final InputChannels inputChannels = new InputChannelsImpl( + queryDef, + queryKernel.getResultPartitionsForStage(finalStageId), + inputChannelFactory, + () -> ArenaMemoryAllocator.createOnHeap(5_000_000), + resultReaderExec, + cancellationId + ); + + resultsChannel = ReadableConcatFrameChannel.open( + StreamSupport.stream(queryKernel.getResultPartitionsForStage(finalStageId).spliterator(), false) + .map( + readablePartition -> { + try { + return inputChannels.openChannel( + new StagePartition( + queryKernel.getStageDefinition(finalStageId).getId(), + readablePartition.getPartitionNumber() + ) + ); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + ) + .iterator() + ); + + final ControllerQueryResultsReader resultsReader = new ControllerQueryResultsReader( + resultsChannel, + queryDef.getFinalStageDefinition().getFrameReader(), + querySpec.getColumnMappings(), + resultsContext, + context.jsonMapper(), + queryListener + ); + + queryResultsReaderFuture = resultReaderExec.runFully(resultsReader, cancellationId); + + // When results are done being read, kick the main 
thread. + // Important: don't use FutureUtils.futureWithBaggage, because we need queryResultsReaderFuture to resolve + // *before* the main thread is kicked. + queryResultsReaderFuture.addListener( + () -> addToKernelManipulationQueue(holder -> {}), + Execs.directExecutor() + ); } + catch (Throwable e) { + // There was some issue setting up the result reader. Shut down the results channel and stop the executor. + final ReadableConcatFrameChannel finalResultsChannel = resultsChannel; + throw CloseableUtils.closeAndWrapInCatch( + e, + () -> CloseableUtils.closeAll( + finalResultsChannel, + () -> resultReaderExec.getExecutorService().shutdownNow() + ) + ); + } + + // Result reader is set up. Register with the query-wide closer. + closer.register(() -> { + try { + resultReaderExec.cancel(cancellationId); + } + catch (Exception e) { + throw new RuntimeException(e); + } + finally { + resultReaderExec.getExecutorService().shutdownNow(); + } + }); } /** @@ -3158,7 +3047,7 @@ private MSQErrorReport mapQueryColumnNameToOutputColumnName( .value(inbf.getValue()) .position(inbf.getPosition()) .build(), - task.getQuerySpec().getColumnMappings() + querySpec.getColumnMappings() ); } else if (workerErrorReport.getFault() instanceof InvalidFieldFault) { InvalidFieldFault iff = (InvalidFieldFault) workerErrorReport.getFault(); @@ -3172,7 +3061,7 @@ private MSQErrorReport mapQueryColumnNameToOutputColumnName( .column(iff.getColumn()) .errorMsg(iff.getErrorMsg()) .build(), - task.getQuerySpec().getColumnMappings() + querySpec.getColumnMappings() ); } else { return workerErrorReport; @@ -3185,7 +3074,7 @@ private MSQErrorReport mapQueryColumnNameToOutputColumnName( */ private interface TaskContactFn { - ListenableFuture contactTask(WorkerClient client, String taskId, int workerNumber); + ListenableFuture contactTask(WorkerClient client, String workerId, int workerNumber); } /** @@ -3193,7 +3082,6 @@ private interface TaskContactFn */ private interface TaskContactSuccess { - void onSuccess(String taskId); - + void onSuccess(String workerId, int workerNumber); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java new file mode 100644 index 000000000000..8e6fc72b6aa7 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.exec; + +import com.google.common.base.Preconditions; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.NotEnoughMemoryFault; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernel; +import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl; + +/** + * Class for determining how much JVM heap to allocate to various purposes for {@link Controller}. + * + * First, look at how much of total JVM heap that is dedicated for MSQ; see + * {@link MemoryIntrospector#usableMemoryInJvm()}. + * + * Then, we split up that total amount of memory into equally-sized portions per {@link Controller}; see + * {@link MemoryIntrospector#numQueriesInJvm()}. The number of controllers is based entirely on server configuration, + * which makes the calculation robust to different queries running simultaneously in the same JVM. + * + * Then, we split that up into a chunk used for input channels, and a chunk used for partition statistics. + */ +public class ControllerMemoryParameters +{ + /** + * Maximum number of bytes that we'll ever use for maxRetainedBytes of {@link ClusterByStatisticsCollectorImpl}. + */ + private static final long PARTITION_STATS_MAX_MEMORY = 300_000_000; + + /** + * Minimum number of bytes that is allowable for maxRetainedBytes of {@link ClusterByStatisticsCollectorImpl}. + */ + private static final long PARTITION_STATS_MIN_MEMORY = 25_000_000; + + /** + * Memory allocated to {@link ClusterByStatisticsCollectorImpl} as part of {@link ControllerQueryKernel}. + */ + private final int partitionStatisticsMaxRetainedBytes; + + public ControllerMemoryParameters(int partitionStatisticsMaxRetainedBytes) + { + this.partitionStatisticsMaxRetainedBytes = partitionStatisticsMaxRetainedBytes; + } + + /** + * Create an instance. 
+ * + * @param memoryIntrospector memory introspector + * @param maxWorkerCount maximum worker count of the final stage + */ + public static ControllerMemoryParameters createProductionInstance( + final MemoryIntrospector memoryIntrospector, + final int maxWorkerCount + ) + { + final long usableMemoryInJvm = memoryIntrospector.usableMemoryInJvm(); + final int numControllersInJvm = memoryIntrospector.numQueriesInJvm(); + Preconditions.checkArgument(usableMemoryInJvm > 0, "Usable memory[%s] must be > 0", usableMemoryInJvm); + Preconditions.checkArgument(numControllersInJvm > 0, "Number of controllers[%s] must be > 0", numControllersInJvm); + Preconditions.checkArgument(maxWorkerCount > 0, "Number of workers[%s] must be > 0", maxWorkerCount); + + final long memoryPerController = usableMemoryInJvm / numControllersInJvm; + final long memoryForInputChannels = WorkerMemoryParameters.memoryNeededForInputChannels(maxWorkerCount); + final int partitionStatisticsMaxRetainedBytes = (int) Math.min( + memoryPerController - memoryForInputChannels, + PARTITION_STATS_MAX_MEMORY + ); + + if (partitionStatisticsMaxRetainedBytes < PARTITION_STATS_MIN_MEMORY) { + final long requiredMemory = memoryForInputChannels + PARTITION_STATS_MIN_MEMORY; + throw new MSQException( + new NotEnoughMemoryFault( + memoryIntrospector.computeJvmMemoryRequiredForUsableMemory(requiredMemory), + memoryIntrospector.totalMemoryInJvm(), + usableMemoryInJvm, + numControllersInJvm, + memoryIntrospector.numProcessorsInJvm() + ) + ); + } + + return new ControllerMemoryParameters(partitionStatisticsMaxRetainedBytes); + } + + /** + * Memory allocated to {@link ClusterByStatisticsCollectorImpl} as part of {@link ControllerQueryKernel}. + */ + public int getPartitionStatisticsMaxRetainedBytes() + { + return partitionStatisticsMaxRetainedBytes; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerQueryResultsReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerQueryResultsReader.java new file mode 100644 index 000000000000..ae24704c9d8e --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerQueryResultsReader.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.exec; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.java.util.common.guava.Yielder; +import org.apache.druid.java.util.common.guava.Yielders; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.util.SqlStatementResourceHelper; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.sql.calcite.planner.ColumnMapping; +import org.apache.druid.sql.calcite.planner.ColumnMappings; +import org.apache.druid.utils.CloseableUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +/** + * Used by {@link ControllerImpl} to read query results and hand them to a {@link QueryListener}. + */ +public class ControllerQueryResultsReader implements FrameProcessor +{ + private static final Logger log = new Logger(ControllerQueryResultsReader.class); + + private final ReadableFrameChannel in; + private final FrameReader frameReader; + private final ColumnMappings columnMappings; + private final ResultsContext resultsContext; + private final ObjectMapper jsonMapper; + private final QueryListener queryListener; + + private boolean wroteResultsStart; + + ControllerQueryResultsReader( + final ReadableFrameChannel in, + final FrameReader frameReader, + final ColumnMappings columnMappings, + final ResultsContext resultsContext, + final ObjectMapper jsonMapper, + final QueryListener queryListener + ) + { + this.in = in; + this.frameReader = frameReader; + this.columnMappings = columnMappings; + this.resultsContext = resultsContext; + this.jsonMapper = jsonMapper; + this.queryListener = queryListener; + } + + @Override + public List inputChannels() + { + return Collections.singletonList(in); + } + + @Override + public List outputChannels() + { + return Collections.emptyList(); + } + + @Override + public ReturnOrAwait runIncrementally(final IntSet readableInputs) + { + if (readableInputs.isEmpty()) { + return ReturnOrAwait.awaitAll(inputChannels().size()); + } + + if (!wroteResultsStart) { + final RowSignature querySignature = frameReader.signature(); + final ImmutableList.Builder mappedSignature = ImmutableList.builder(); + + for (final ColumnMapping mapping : columnMappings.getMappings()) { + mappedSignature.add( + new MSQResultsReport.ColumnAndType( + mapping.getOutputColumn(), + querySignature.getColumnType(mapping.getQueryColumn()).orElse(null) + ) + ); + } + + queryListener.onResultsStart( + mappedSignature.build(), + resultsContext.getSqlTypeNames() + ); + + wroteResultsStart = true; + } + + // Read from query results channel, if it's open. 
+ if (in.isFinished()) { + queryListener.onResultsComplete(); + return ReturnOrAwait.returnObject(null); + } else { + final Frame frame = in.read(); + Yielder<Object[]> rowYielder = Yielders.each( + SqlStatementResourceHelper.getResultSequence( + frame, + frameReader, + columnMappings, + resultsContext, + jsonMapper + ) + ); + + try { + while (!rowYielder.isDone()) { + if (queryListener.onResultRow(rowYielder.get())) { + rowYielder = rowYielder.next(null); + } else { + // Caller wanted to stop reading. + return ReturnOrAwait.returnObject(null); + } + } + } + finally { + CloseableUtils.closeAndSuppressExceptions(rowYielder, e -> log.warn(e, "Failed to close frame yielder")); + } + + return ReturnOrAwait.awaitAll(inputChannels().size()); + } + } + + @Override + public void cleanup() throws IOException + { + FrameProcessors.closeAll(inputChannels(), outputChannels()); + } +}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java index 93dbc0080045..8a7607d3159a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java @@ -60,25 +60,23 @@ public ListenableFuture postWorkOrder(String workerTaskId, WorkOrder workO @Override public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot( String workerTaskId, - String queryId, - int stageNumber + StageId stageId ) { - return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber)); + return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, stageId)); } @Override public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk( String workerTaskId, - String queryId, - int stageNumber, + StageId stageId, long timeChunk ) { return wrap( workerTaskId, client, - c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk) + c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, stageId, timeChunk) ); }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java index 8b6f26770a5d..bb782cb67d9a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java @@ -86,9 +86,10 @@ public class Limits public static final long MAX_WORKERS_FOR_PARALLEL_MERGE = 100; /** - * Max number of rows in the query reports of the SELECT queries run by MSQ. This ensures that the reports donot blow - * up for queries operating on larger datasets. The full result of the select query should be available once the - * MSQ is able to run async queries + * Max number of rows in the query reports of SELECT queries run by MSQ when using + * {@link org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination}. Reports in this mode contain a + * preview of actual query results, but not the full result set. This ensures that the reports do not blow up in + * size for queries operating on larger datasets.
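+ * For example, a SELECT whose complete result set has millions of rows retains at most 3,000 of those rows + * in its report.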
*/ public static final long MAX_SELECT_RESULT_ROWS = 3_000; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospector.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospector.java new file mode 100644 index 000000000000..337e36d14efa --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospector.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import org.apache.druid.msq.kernel.WorkOrder; + +/** + * Introspector used to generate {@link ControllerMemoryParameters}. + */ +public interface MemoryIntrospector +{ + /** + * Amount of total memory in the entire JVM. + */ + long totalMemoryInJvm(); + + /** + * Amount of memory usable for the multi-stage query engine in the entire JVM. + * + * This may be an expensive operation. For example, the production implementation {@link MemoryIntrospectorImpl} + * estimates size of all lookups as part of computing this value. + */ + long usableMemoryInJvm(); + + /** + * Amount of total JVM memory required for a particular amount of usable memory to be available. + * + * This may be an expensive operation. For example, the production implementation {@link MemoryIntrospectorImpl} + * estimates size of all lookups as part of computing this value. + */ + long computeJvmMemoryRequiredForUsableMemory(long usableMemory); + + /** + * Maximum number of queries that run simultaneously in this JVM. + * + * On workers, this is the maximum number of {@link Worker} that run simultaneously in this JVM. See + * {@link WorkerMemoryParameters} for how memory is divided among and within {@link WorkOrder} run by a worker. + * + * On controllers, this is the maximum number of {@link Controller} that run simultaneously. See + * {@link ControllerMemoryParameters} for how memory is used by controllers. + */ + int numQueriesInJvm(); + + /** + * Maximum number of processing threads that can be used at once in this JVM. + */ + int numProcessorsInJvm(); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospectorImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospectorImpl.java new file mode 100644 index 000000000000..f7cd501ed8fd --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/MemoryIntrospectorImpl.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.query.lookup.LookupExtractor; +import org.apache.druid.query.lookup.LookupExtractorFactoryContainer; +import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; + +import java.util.List; + +/** + * Production implementation of {@link MemoryIntrospector}. + */ +public class MemoryIntrospectorImpl implements MemoryIntrospector +{ + private static final Logger log = new Logger(MemoryIntrospectorImpl.class); + + private final LookupExtractorFactoryContainerProvider lookupProvider; + private final long totalMemoryInJvm; + private final int numQueriesInJvm; + private final int numProcessorsInJvm; + private final double usableMemoryFraction; + + /** + * Create an introspector. + * + * @param lookupProvider provider of lookups; we use this to subtract lookup size from total JVM memory when + * computing usable memory + * @param totalMemoryInJvm maximum JVM heap memory + * @param usableMemoryFraction fraction of JVM memory, after subtracting lookup overhead, that we consider usable + * for multi-stage queries + * @param numQueriesInJvm maximum number of {@link Controller} or {@link Worker} that may run concurrently + * @param numProcessorsInJvm size of processing thread pool, typically {@link DruidProcessingConfig#getNumThreads()} + */ + public MemoryIntrospectorImpl( + final LookupExtractorFactoryContainerProvider lookupProvider, + final long totalMemoryInJvm, + final double usableMemoryFraction, + final int numQueriesInJvm, + final int numProcessorsInJvm + ) + { + this.lookupProvider = lookupProvider; + this.totalMemoryInJvm = totalMemoryInJvm; + this.numQueriesInJvm = numQueriesInJvm; + this.numProcessorsInJvm = numProcessorsInJvm; + this.usableMemoryFraction = usableMemoryFraction; + } + + @Override + public long totalMemoryInJvm() + { + return totalMemoryInJvm; + } + + @Override + public long usableMemoryInJvm() + { + final long totalMemory = totalMemoryInJvm(); + final long totalLookupFootprint = computeTotalLookupFootprint(true); + return Math.max( + 0, + (long) ((totalMemory - totalLookupFootprint) * usableMemoryFraction) + ); + } + + @Override + public long computeJvmMemoryRequiredForUsableMemory(long usableMemory) + { + final long totalLookupFootprint = computeTotalLookupFootprint(false); + return (long) Math.ceil(usableMemory / usableMemoryFraction + totalLookupFootprint); + } + + @Override + public int numQueriesInJvm() + { + return numQueriesInJvm; + } + + @Override + public int numProcessorsInJvm() + { + return numProcessorsInJvm; + } + + /** + * Compute and return total estimated lookup footprint. + * + * Correctness of this approach depends on lookups being loaded *before* calling this method. 
Luckily, this is the + * typical mode of operation, since by default druid.lookup.enableLookupSyncOnStartup = true. + * + * @param logFootprint whether footprint should be logged + */ + private long computeTotalLookupFootprint(final boolean logFootprint) + { + final List lookupNames = ImmutableList.copyOf(lookupProvider.getAllLookupNames()); + + long lookupFootprint = 0; + + for (final String lookupName : lookupNames) { + final LookupExtractorFactoryContainer container = lookupProvider.get(lookupName).orElse(null); + + if (container != null) { + try { + final LookupExtractor extractor = container.getLookupExtractorFactory().get(); + lookupFootprint += extractor.estimateHeapFootprint(); + } + catch (Exception e) { + log.noStackTrace().warn(e, "Failed to load lookup[%s] for size estimation. Skipping.", lookupName); + } + } + } + + if (logFootprint) { + log.info("Lookup footprint: lookup count[%d], total bytes[%,d].", lookupNames.size(), lookupFootprint); + } + + return lookupFootprint; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java new file mode 100644 index 000000000000..7e7fc3d3d6f3 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelUtils; + +/** + * Mode for stage output channels. Provided to workers in {@link WorkOrder#getOutputChannelMode()}. + */ +public enum OutputChannelMode +{ + /** + * In-memory output channels. Stage shuffle data does not hit disk. This mode requires a consumer stage to run + * at the same time as its corresponding producer stage. See {@link ControllerQueryKernelUtils#computeStageGroups} for the + * logic that determines when we can use in-memory channels. + */ + MEMORY("memory"), + + /** + * Local file output channels. Stage shuffle data is stored in files on disk on the producer, and served via HTTP + * to the consumer. + */ + LOCAL_STORAGE("localStorage"), + + /** + * Durable storage output channels. Stage shuffle data is written by producers to durable storage (e.g. cloud + * storage), and is read from durable storage by consumers. 
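+ * Because the shuffle data outlives the producer that wrote it, fault-tolerant execution relies on this mode.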
+ */ + DURABLE_STORAGE_INTERMEDIATE("durableStorage"), + + /** + * Like {@link #DURABLE_STORAGE_INTERMEDIATE}, but a special case for the final stage + * {@link QueryDefinition#getFinalStageDefinition()}. The structure of files in deep storage is somewhat different. + */ + DURABLE_STORAGE_QUERY_RESULTS("durableStorageQueryResults"); + + private final String name; + + OutputChannelMode(String name) + { + this.name = name; + } + + @JsonCreator + public static OutputChannelMode fromString(final String s) + { + for (final OutputChannelMode mode : values()) { + if (mode.toString().equals(s)) { + return mode; + } + } + + throw new IAE("No such outputChannelMode[%s]", s); + } + + /** + * Whether this mode involves writing to durable storage. + */ + public boolean isDurable() + { + return this == DURABLE_STORAGE_INTERMEDIATE || this == DURABLE_STORAGE_QUERY_RESULTS; + } + + @Override + @JsonValue + public String toString() + { + return name; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/QueryListener.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/QueryListener.java new file mode 100644 index 000000000000..997fe4c8682d --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/QueryListener.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; + +import javax.annotation.Nullable; +import java.util.List; + +/** + * Object passed to {@link Controller#run(QueryListener)} to enable retrieval of results, status, counters, etc. + */ +public interface QueryListener +{ + /** + * Whether this listener is meant to receive results. + */ + boolean readResults(); + + /** + * Called when results start coming in. + * + * @param signature signature of results + * @param sqlTypeNames SQL type names of results; same length as the signature + */ + void onResultsStart( + List signature, + @Nullable List sqlTypeNames + ); + + /** + * Called for each result row. Follows a call to {@link #onResultsStart(List, List)}. + * + * @param row result row + * + * @return whether the controller should keep reading results + */ + boolean onResultRow(Object[] row); + + /** + * Called after the last result has been delivered by {@link #onResultRow(Object[])}. Only called if results are + * actually complete. If results are truncated due to {@link #readResults()} or {@link #onResultRow(Object[])} + * returning false, this method is not called. 
+ */ + void onResultsComplete(); + + /** + * Called when the query is complete and a report is available. After this method is called, no other methods + * will be called. The report will not include a {@link MSQResultsReport}. + */ + void onQueryComplete(MSQTaskReportPayload report); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ResultsContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ResultsContext.java new file mode 100644 index 000000000000..9e565bb75a5d --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ResultsContext.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.sql.calcite.run.SqlResults; + +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * Holder for objects needed to interpret SQL results. + */ +public class ResultsContext +{ + private final List sqlTypeNames; + private final SqlResults.Context sqlResultsContext; + + public ResultsContext( + final List sqlTypeNames, + final SqlResults.Context sqlResultsContext + ) + { + this.sqlTypeNames = sqlTypeNames; + this.sqlResultsContext = sqlResultsContext; + } + + @Nullable + public List getSqlTypeNames() + { + return sqlTypeNames; + } + + @Nullable + public SqlResults.Context getSqlResultsContext() + { + return sqlResultsContext; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ResultsContext that = (ResultsContext) o; + return Objects.equals(sqlTypeNames, that.sqlTypeNames) + && Objects.equals(sqlResultsContext, that.sqlResultsContext); + } + + @Override + public int hashCode() + { + return Objects.hash(sqlTypeNames, sqlResultsContext); + } + + @Override + public String toString() + { + return "ResultsContext{" + + "sqlTypeNames=" + sqlTypeNames + + ", sqlResultsContext=" + sqlResultsContext + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RetryCapableWorkerManager.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RetryCapableWorkerManager.java new file mode 100644 index 000000000000..d5b3d41d7a2b --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RetryCapableWorkerManager.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +/** + * Expanded {@link WorkerManager} interface with methods to support retrying workers. + */ +public interface RetryCapableWorkerManager extends WorkerManager +{ + /** + * Queues worker for relaunch. A noop if the worker is already in the queue. + */ + void submitForRelaunch(int workerNumber); + + /** + * Reports a worker that failed without active work orders. The worker is retried only if it is required for + * future stages. + */ + void reportFailedInactiveWorker(int workerNumber); + + /** + * Checks if the controller has canceled the input taskId. This method is used in {@link ControllerImpl} + * to figure out if the worker taskId is canceled by the controller. If yes, the errors from that worker taskId + * are ignored for the error reports. + * + * @return true if task is canceled by the controller, else false + */ + boolean isTaskCanceledByController(String taskId); +}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java index 1546766f856f..d4eaef600125 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java @@ -104,7 +104,7 @@ public class SegmentLoadStatusFetcher implements AutoCloseable public SegmentLoadStatusFetcher( BrokerClient brokerClient, ObjectMapper objectMapper, - String taskId, + String queryId, String datasource, Set<DataSegment> dataSegments, boolean doWait @@ -128,7 +128,9 @@ public SegmentLoadStatusFetcher( totalSegmentsGenerated )); this.doWait = doWait; - this.executorService = MoreExecutors.listeningDecorator(Execs.singleThreaded(taskId + "-segment-load-waiter-%d")); + this.executorService = MoreExecutors.listeningDecorator( + Execs.singleThreaded(StringUtils.encodeForFormat(queryId) + "-segment-load-waiter-%d") + ); } /**
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java index 5c02a79f89a3..572051124a74 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java @@ -27,26 +27,27 @@ import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; +import java.io.Closeable; import java.io.IOException; /** - * Client for multi-stage query workers. Used by the controller task. + * Client for {@link Worker}. Each instance is scoped to a single query, and can communicate with all workers for + * that particular query.
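+ * + * A rough call sequence from the controller's side, shown for illustration only (fetch methods omitted; see the + * individual method javadocs for the authoritative contracts): + * + * <pre>{@code + * workerClient.postWorkOrder(workerId, workOrder); // hand a WorkOrder to a worker + * workerClient.postResultPartitionBoundaries(workerId, stageId, partitionBoundaries); + * workerClient.postCleanupStage(workerId, stageId); // stage outputs no longer needed + * workerClient.postFinish(workerId); // all work done; worker may shut down + * }</pre>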
*/ -public interface WorkerClient extends AutoCloseable +public interface WorkerClient extends Closeable { /** * Worker's client method to add a {@link WorkOrder} to the worker to work on */ - ListenableFuture postWorkOrder(String workerTaskId, WorkOrder workOrder); + ListenableFuture postWorkOrder(String workerId, WorkOrder workOrder); /** * Fetches the {@link ClusterByStatisticsSnapshot} from a worker. This is intended to be used by the * {@link WorkerSketchFetcher} under PARALLEL or AUTO modes. */ ListenableFuture fetchClusterByStatisticsSnapshot( - String workerTaskId, - String queryId, - int stageNumber + String workerId, + StageId stageId ); /** @@ -54,9 +55,8 @@ ListenableFuture fetchClusterByStatisticsSnapshot( * This is intended to be used by the {@link WorkerSketchFetcher} under SEQUENTIAL or AUTO modes. */ ListenableFuture fetchClusterByStatisticsSnapshotForTimeChunk( - String workerTaskId, - String queryId, - int stageNumber, + String workerId, + StageId stageId, long timeChunk ); @@ -65,28 +65,26 @@ ListenableFuture fetchClusterByStatisticsSnapshotFo * controller after collating the result statistics from all the workers processing the query */ ListenableFuture postResultPartitionBoundaries( - String workerTaskId, + String workerId, StageId stageId, ClusterByPartitions partitionBoundaries ); /** - * Worker's client method to inform that the work has been done, and it can initiate cleanup and shutdown - * @param workerTaskId + * Fetches counters from a worker. */ - ListenableFuture postFinish(String workerTaskId); + ListenableFuture getCounters(String workerId); /** - * Fetches all the counters gathered by that worker - * @param workerTaskId + * Worker's client method that informs it that the results and resources for the given stage are no longer required + * and that they can be cleaned up */ - ListenableFuture getCounters(String workerTaskId); + ListenableFuture postCleanupStage(String workerId, StageId stageId); /** - * Worker's client method that informs it that the results and resources for the given stage are no longer required - * and that they can be cleaned up + * Worker's client method to inform that the work has been done, and it can initiate cleanup and shutdown. */ - ListenableFuture postCleanupStage(String workerTaskId, StageId stageId); + ListenableFuture postFinish(String workerId); /** * Fetch some data from a worker and add it to the provided channel. The exact amount of data is determined @@ -96,13 +94,16 @@ ListenableFuture postResultPartitionBoundaries( * kind of unrecoverable exception). */ ListenableFuture fetchChannelData( - String workerTaskId, + String workerId, StageId stageId, int partitionNumber, long offset, ReadableByteChunksFrameChannel channel ); + /** + * Close this client and release resources. 
+ */ @Override void close() throws IOException; }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/RetryTask.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerFailureListener.java similarity index 72% rename from extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/RetryTask.java rename to extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerFailureListener.java index 39fb1e688ecf..9bc4ed56cde7 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/RetryTask.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerFailureListener.java @@ -17,17 +17,18 @@ * under the License. */ -package org.apache.druid.msq.indexing; +package org.apache.druid.msq.exec; +import org.apache.druid.msq.indexing.MSQWorkerTask; import org.apache.druid.msq.indexing.error.MSQFault; -public interface RetryTask +/** + * Notifies users of {@link WorkerManager} when a worker fails. + */ +public interface WorkerFailureListener { /** - * Retry task when {@link MSQFault} is encountered. - * - * @param workerTask - * @param msqFault + * Fires when a worker launched or monitored by {@link WorkerManager} fails. */ - void retry(MSQWorkerTask workerTask, MSQFault msqFault); + void onFailure(MSQWorkerTask workerTask, MSQFault msqFault); }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java new file mode 100644 index 000000000000..ebce4821d591 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.msq.indexing.WorkerCount; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Used by {@link ControllerImpl} to discover and manage workers. + * + * Worker managers capable of retrying should extend {@link RetryCapableWorkerManager} (an extension of this interface). + */ +public interface WorkerManager +{ + int UNKNOWN_WORKER_NUMBER = -1; + + /** + * Starts this manager. + * + * Returns a future that resolves successfully when all workers end successfully or are canceled. The returned future + * resolves to an exception if one of the workers fails without being explicitly canceled, or if something else + * goes wrong.
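+ * + * A typical lifecycle, sketched for illustration only (error handling omitted): + * + * <pre>{@code + * ListenableFuture<?> workersDoneFuture = workerManager.start(); + * workerManager.launchWorkersIfNeeded(workerCount); // scale up to the desired count + * workerManager.waitForWorkers(workerNumbers); // wait until they can be contacted + * // ... assign work using WorkerClient ... + * workerManager.stop(false); // block until all workers exit + * }</pre>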
+ */ + ListenableFuture start(); + + /** + * Launch additional workers, if needed, to bring the number of running workers up to {@code workerCount}. + * Blocks until the requested workers are launched. If enough workers are already running, this method does nothing. + */ + void launchWorkersIfNeeded(int workerCount) throws InterruptedException; + + /** + * Blocks until workers with the provided worker numbers (indexes into {@link #getWorkerIds()} are ready to be + * contacted for work. + */ + void waitForWorkers(Set workerNumbers) throws InterruptedException; + + /** + * List of currently-active workers. + */ + List getWorkerIds(); + + /** + * Number of currently-active and currently-pending workers. + */ + WorkerCount getWorkerCount(); + + /** + * Worker number of a worker with the provided ID, or {@link #UNKNOWN_WORKER_NUMBER} if none exists. + */ + int getWorkerNumber(String workerId); + + /** + * Whether an active worker exists with the provided ID. + */ + boolean isWorkerActive(String workerId); + + /** + * Map of worker number to list of all workers currently running with that number. More than one worker per number + * only occurs when fault tolerance is enabled. + */ + Map> getWorkerStats(); + + /** + * Blocks until all workers exit. Returns quietly, no matter whether there was an exception associated with the + * future from {@link #start()} or not. + * + * @param interrupt whether to interrupt currently-running work + */ + void stop(boolean interrupt); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java index a09d0508485d..b36b1b4155a8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java @@ -473,6 +473,16 @@ static int computeMaxWorkers( ); } + /** + * Computes the amount of memory needed to read a single partition from a given number of workers. + */ + static long memoryNeededForInputChannels(final int numInputWorkers) + { + // Workers that read sorted inputs must open all channels at once to do an N-way merge. Calculate memory needs. + // Requirement: one input frame per worker, one buffered output frame. + return (long) STANDARD_FRAME_SIZE * (numInputWorkers + 1); + } + /** * Maximum number of workers that may exist in the current JVM. */ @@ -563,13 +573,6 @@ private static long estimateUsableMemory(final int numWorkersInJvm, final long e return estimatedTotalBundleMemory + (estimateStatOverHeadPerWorker * numWorkersInJvm); } - private static long memoryNeededForInputChannels(final int numInputWorkers) - { - // Workers that read sorted inputs must open all channels at once to do an N-way merge. Calculate memory needs. - // Requirement: one input frame per worker, one buffered output frame. - return (long) STANDARD_FRAME_SIZE * (numInputWorkers + 1); - } - private static long memoryNeededForHashPartitioning(final int numOutputPartitions) { // One standard frame for each processor output. 
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java index 271ce8ff0709..73f151fcdaa9 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java @@ -30,7 +30,6 @@ import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.function.TriConsumer; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; import org.apache.druid.msq.indexing.error.MSQFault; import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault; import org.apache.druid.msq.kernel.StageId; @@ -58,23 +57,23 @@ public class WorkerSketchFetcher implements AutoCloseable private static final int DEFAULT_THREAD_COUNT = 4; private final WorkerClient workerClient; - private final MSQWorkerTaskLauncher workerTaskLauncher; + private final WorkerManager workerManager; private final boolean retryEnabled; - private AtomicReference isError = new AtomicReference<>(); + private final AtomicReference isError = new AtomicReference<>(); final ExecutorService executorService; public WorkerSketchFetcher( WorkerClient workerClient, - MSQWorkerTaskLauncher workerTaskLauncher, + WorkerManager workerManager, boolean retryEnabled ) { this.workerClient = workerClient; this.executorService = Execs.multiThreaded(DEFAULT_THREAD_COUNT, "SketchFetcherThreadPool-%d"); - this.workerTaskLauncher = workerTaskLauncher; + this.workerManager = workerManager; this.retryEnabled = retryEnabled; } @@ -93,21 +92,14 @@ public void inMemoryFullSketchMerging( for (String taskId : taskIds) { try { - int workerNumber = MSQTasks.workerFromTaskId(taskId); + int workerNumber = workerManager.getWorkerNumber(taskId); executorService.submit(() -> { fetchStatsFromWorker( kernelActions, - () -> workerClient.fetchClusterByStatisticsSnapshot( - taskId, - stageId.getQueryId(), - stageId.getStageNumber() - ), + () -> workerClient.fetchClusterByStatisticsSnapshot(taskId, stageId), taskId, - (kernel, snapshot) -> kernel.mergeClusterByStatisticsCollectorForAllTimeChunks( - stageId, - workerNumber, - snapshot - ), + (kernel, snapshot) -> + kernel.mergeClusterByStatisticsCollectorForAllTimeChunks(stageId, workerNumber, snapshot), retryOperation ); }); @@ -135,9 +127,14 @@ private void fetchStatsFromWorker( executorService.shutdownNow(); return; } - int worker = MSQTasks.workerFromTaskId(taskId); + int worker = workerManager.getWorkerNumber(taskId); + if (worker == WorkerManager.UNKNOWN_WORKER_NUMBER) { + log.info("Task[%s] is no longer the latest task for worker[%d]. Skipping fetch.", taskId, worker); + return; + } + try { - workerTaskLauncher.waitUntilWorkersReady(ImmutableSet.of(worker)); + workerManager.waitForWorkers(ImmutableSet.of(worker)); } catch (InterruptedException interruptedException) { isError.compareAndSet(null, interruptedException); @@ -146,12 +143,8 @@ private void fetchStatsFromWorker( } // if task is not the latest task. It must have retried. - if (!workerTaskLauncher.isTaskLatest(taskId)) { - log.info( - "Task[%s] is no longer the latest task for worker[%d], hence ignoring fetching stats from this worker", - taskId, - worker - ); + if (!workerManager.isWorkerActive(taskId)) { + log.info("Task[%s] is no longer the latest task for worker[%d]. 
Skipping fetch.", taskId, worker); return; } @@ -250,7 +243,7 @@ public void sequentialTimeChunkMerging( completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().forEach((timeChunk, wks) -> { for (String taskId : tasks) { - int workerNumber = MSQTasks.workerFromTaskId(taskId); + int workerNumber = workerManager.getWorkerNumber(taskId); if (wks.contains(workerNumber)) { noBoundaries.remove(taskId); executorService.submit(() -> { @@ -258,8 +251,7 @@ public void sequentialTimeChunkMerging( kernelActions, () -> workerClient.fetchClusterByStatisticsSnapshotForTimeChunk( taskId, - stageId.getQueryId(), - stageId.getStageNumber(), + new StageId(stageId.getQueryId(), stageId.getStageNumber()), timeChunk ), taskId, @@ -281,7 +273,7 @@ public void sequentialTimeChunkMerging( for (String taskId : noBoundaries) { kernelActions.accept( kernel -> { - final int workerNumber = MSQTasks.workerFromTaskId(taskId); + final int workerNumber = workerManager.getWorkerNumber(taskId); kernel.mergeClusterByStatisticsCollectorForAllTimeChunks( stageId, workerNumber, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStats.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStats.java new file mode 100644 index 000000000000..831ea645e40a --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStats.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.exec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.druid.indexer.TaskState; + +import java.util.Objects; + +public class WorkerStats +{ + private final String workerId; + private final TaskState state; + private final long durationMs; + private final long pendingMs; + + @JsonCreator + public WorkerStats( + @JsonProperty("workerId") String workerId, + @JsonProperty("state") TaskState state, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("pendingMs") long pendingMs + ) + { + this.workerId = workerId; + this.state = state; + this.durationMs = durationMs; + this.pendingMs = pendingMs; + } + + @JsonProperty + public String getWorkerId() + { + return workerId; + } + + @JsonProperty + public TaskState getState() + { + return state; + } + + @JsonProperty("durationMs") + public long getDuration() + { + return durationMs; + } + + @JsonProperty("pendingMs") + public long getPendingTimeInMs() + { + return pendingMs; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WorkerStats that = (WorkerStats) o; + return durationMs == that.durationMs + && pendingMs == that.pendingMs + && Objects.equals(workerId, that.workerId) + && state == that.state; + } + + @Override + public int hashCode() + { + return Objects.hash(workerId, state, durationMs, pendingMs); + } + + @Override + public String toString() + { + return "WorkerStats{" + + "workerId='" + workerId + '\'' + + ", state=" + state + + ", durationMs=" + durationMs + + ", pendingMs=" + pendingMs + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/IndexerMemoryManagementModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/IndexerMemoryManagementModule.java new file mode 100644 index 000000000000..92f16a631d9f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/IndexerMemoryManagementModule.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.guice; + +import com.google.inject.Binder; +import com.google.inject.Provides; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.indexing.worker.config.WorkerConfig; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.MemoryIntrospectorImpl; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; +import org.apache.druid.utils.JvmUtils; + +/** + * Provides {@link MemoryIntrospector} for multi-task-per-JVM model. + * + * @see PeonMemoryManagementModule for single-task-per-JVM model used on {@link org.apache.druid.cli.CliPeon} + */ +@LoadScope(roles = NodeRole.INDEXER_JSON_NAME) +public class IndexerMemoryManagementModule implements DruidModule +{ + /** + * Allocate up to 75% of memory for MSQ-related stuff (if all running tasks are MSQ tasks). + */ + private static final double USABLE_MEMORY_FRACTION = 0.75; + + @Override + public void configure(Binder binder) + { + // Nothing to do. + } + + @Provides + @LazySingleton + public Bouncer makeProcessorBouncer(final DruidProcessingConfig processingConfig) + { + return new Bouncer(processingConfig.getNumThreads()); + } + + @Provides + @LazySingleton + public MemoryIntrospector createMemoryIntrospector( + final LookupExtractorFactoryContainerProvider lookupProvider, + final DruidProcessingConfig processingConfig, + final WorkerConfig workerConfig + ) + { + return new MemoryIntrospectorImpl( + lookupProvider, + JvmUtils.getRuntimeInfo().getMaxHeapSizeBytes(), + USABLE_MEMORY_FRACTION, + workerConfig.getCapacity(), + processingConfig.getNumThreads() + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java index 125a66331e60..f4d24cfc5c4c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQIndexingModule.java @@ -23,11 +23,6 @@ import com.fasterxml.jackson.databind.module.SimpleModule; import com.google.common.collect.ImmutableList; import com.google.inject.Binder; -import com.google.inject.Provides; -import org.apache.druid.discovery.NodeRole; -import org.apache.druid.frame.processor.Bouncer; -import org.apache.druid.guice.LazySingleton; -import org.apache.druid.guice.annotations.Self; import org.apache.druid.initialization.DruidModule; import org.apache.druid.msq.counters.ChannelCounters; import org.apache.druid.msq.counters.CounterSnapshotsSerializer; @@ -94,11 +89,9 @@ import org.apache.druid.msq.querykit.results.QueryResultFrameProcessorFactory; import org.apache.druid.msq.querykit.scan.ScanQueryFrameProcessorFactory; import org.apache.druid.msq.util.PassthroughAggregatorFactory; -import org.apache.druid.query.DruidProcessingConfig; import java.util.Collections; import java.util.List; -import java.util.Set; /** * Module that adds {@link MSQControllerTask}, {@link MSQWorkerTask}, and dependencies. 
@@ -206,17 +199,4 @@ public List getJacksonModules() public void configure(Binder binder) { } - - @Provides - @LazySingleton - public Bouncer makeBouncer(final DruidProcessingConfig processingConfig, @Self Set<NodeRole> nodeRoles) - { - if (nodeRoles.contains(NodeRole.PEON) && !nodeRoles.contains(NodeRole.INDEXER)) { - // CliPeon -> use only one thread regardless of configured # of processing threads. This matches the expected - // resource usage pattern for CliPeon-based tasks (one task / one working thread per JVM). - return new Bouncer(1); - } else { - return new Bouncer(processingConfig.getNumThreads()); - } - } }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQSqlModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQSqlModule.java index 8e381e50bd01..ea6eb364cece 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQSqlModule.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MSQSqlModule.java @@ -25,7 +25,6 @@ import org.apache.druid.discovery.NodeRole; import org.apache.druid.guice.LazySingleton; import org.apache.druid.guice.annotations.LoadScope; -import org.apache.druid.guice.annotations.MSQ; import org.apache.druid.initialization.DruidModule; import org.apache.druid.metadata.input.InputSourceModule; import org.apache.druid.msq.sql.MSQTaskSqlEngine; @@ -62,11 +61,11 @@ public void configure(Binder binder) } @Provides - @MSQ + @MultiStageQuery @LazySingleton public SqlStatementFactory makeMSQSqlStatementFactory( final MSQTaskSqlEngine engine, - SqlToolbox toolbox + final SqlToolbox toolbox ) { return new SqlStatementFactory(toolbox.withEngine(engine));
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MultiStageQuery.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MultiStageQuery.java index 986b1cfb2545..ba017a2c7863 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MultiStageQuery.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/MultiStageQuery.java @@ -26,6 +26,9 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +/** + * Binding annotation for implementations of interfaces that are MSQ (MultiStageQuery) focused. + */ @Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.METHOD}) @Retention(RetentionPolicy.RUNTIME) @BindingAnnotation
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/PeonMemoryManagementModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/PeonMemoryManagementModule.java new file mode 100644 index 000000000000..9e814c082781 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/PeonMemoryManagementModule.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.guice; + +import com.google.inject.Binder; +import com.google.inject.Provides; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.MemoryIntrospectorImpl; +import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; +import org.apache.druid.utils.JvmUtils; + +/** + * Provides {@link MemoryIntrospector} for single-task-per-JVM model. + * + * @see IndexerMemoryManagementModule for multi-task-per-JVM model used on {@link org.apache.druid.cli.CliIndexer} + */ +@LoadScope(roles = NodeRole.PEON_JSON_NAME) +public class PeonMemoryManagementModule implements DruidModule +{ + /** + * Peons have a single worker per JVM. + */ + private static final int NUM_WORKERS_IN_JVM = 1; + + /** + * Peons may have more than one processing thread, but we currently only use one of them. + */ + private static final int NUM_PROCESSING_THREADS = 1; + + /** + * Allocate 75% of memory for MSQ-related stuff. + */ + private static final double USABLE_MEMORY_FRACTION = 0.75; + + @Override + public void configure(Binder binder) + { + // Nothing to do. + } + + @Provides + @LazySingleton + public Bouncer makeProcessorBouncer() + { + return new Bouncer(NUM_PROCESSING_THREADS); + } + + @Provides + @LazySingleton + public MemoryIntrospector createMemoryIntrospector( + final LookupExtractorFactoryContainerProvider lookupProvider, + final Bouncer bouncer + ) + { + return new MemoryIntrospectorImpl( + lookupProvider, + JvmUtils.getRuntimeInfo().getMaxHeapSizeBytes(), + USABLE_MEMORY_FRACTION, + NUM_WORKERS_IN_JVM, + bouncer.getMaxCount() + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/SqlTaskModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/SqlTaskModule.java index 52531294f341..d09f8613fa7e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/SqlTaskModule.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/guice/SqlTaskModule.java @@ -19,7 +19,6 @@ package org.apache.druid.msq.guice; -import com.fasterxml.jackson.databind.Module; import com.google.inject.Binder; import org.apache.druid.discovery.NodeRole; import org.apache.druid.guice.Jerseys; @@ -29,9 +28,6 @@ import org.apache.druid.msq.sql.resources.SqlStatementResource; import org.apache.druid.msq.sql.resources.SqlTaskResource; -import java.util.Collections; -import java.util.List; - /** * Module for adding the {@link SqlTaskResource} endpoint to the Broker. 
*/ @@ -47,10 +43,4 @@ public void configure(Binder binder) LifecycleModule.register(binder, SqlStatementResource.class); Jerseys.addResource(binder, SqlStatementResource.class); } - - @Override - public List getJacksonModules() - { - return Collections.emptyList(); - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java index aeee05e75067..3ff71c3e1b77 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java @@ -20,56 +20,110 @@ package org.apache.druid.msq.indexing; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; import com.google.inject.Injector; import com.google.inject.Key; -import org.apache.druid.client.coordinator.CoordinatorClient; import org.apache.druid.guice.annotations.Self; -import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.indexing.common.TaskToolbox; import org.apache.druid.indexing.common.actions.TaskActionClient; +import org.apache.druid.indexing.common.task.IndexTaskUtils; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.io.Closer; -import org.apache.druid.java.util.emitter.service.ServiceEmitter; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.java.util.emitter.service.ServiceMetricEvent; import org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.ControllerMemoryParameters; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.SegmentSource; import org.apache.druid.msq.exec.WorkerClient; -import org.apache.druid.msq.exec.WorkerManagerClient; +import org.apache.druid.msq.exec.WorkerFailureListener; +import org.apache.druid.msq.exec.WorkerManager; import org.apache.druid.msq.indexing.client.ControllerChatHandler; import org.apache.druid.msq.indexing.client.IndexerWorkerClient; -import org.apache.druid.msq.indexing.client.IndexerWorkerManagerClient; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.MSQWarnings; +import org.apache.druid.msq.indexing.error.UnknownFault; +import org.apache.druid.msq.input.InputSpecSlicer; +import org.apache.druid.msq.input.table.TableInputSpecSlicer; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.DruidMetrics; +import org.apache.druid.query.QueryContext; import org.apache.druid.rpc.ServiceClientFactory; import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.segment.realtime.firehose.ChatHandler; import org.apache.druid.server.DruidNode; +import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + /** * Implementation for {@link ControllerContext} required to run multi-stage queries as indexing tasks. 
*/ public class IndexerControllerContext implements ControllerContext { + private static final Logger log = new Logger(IndexerControllerContext.class); + + private final MSQControllerTask task; private final TaskToolbox toolbox; private final Injector injector; private final ServiceClientFactory clientFactory; private final OverlordClient overlordClient; - private final WorkerManagerClient workerManager; + private final ServiceMetricEvent.Builder metricBuilder; public IndexerControllerContext( + final MSQControllerTask task, final TaskToolbox toolbox, final Injector injector, final ServiceClientFactory clientFactory, final OverlordClient overlordClient ) { + this.task = task; this.toolbox = toolbox; this.injector = injector; this.clientFactory = clientFactory; this.overlordClient = overlordClient; - this.workerManager = new IndexerWorkerManagerClient(overlordClient); + this.metricBuilder = new ServiceMetricEvent.Builder(); + IndexTaskUtils.setTaskDimensions(metricBuilder, task); + } + + @Override + public ControllerQueryKernelConfig queryKernelConfig( + final MSQSpec querySpec, + final QueryDefinition queryDef + ) + { + final MemoryIntrospector memoryIntrospector = injector.getInstance(MemoryIntrospector.class); + final ControllerMemoryParameters memoryParameters = + ControllerMemoryParameters.createProductionInstance( + memoryIntrospector, + queryDef.getFinalStageDefinition().getMaxWorkerCount() + ); + + final ControllerQueryKernelConfig config = makeQueryKernelConfig(querySpec, memoryParameters); + + log.debug( + "Query[%s] using %s[%s], %s[%s], %s[%s].", + queryDef.getQueryId(), + MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, + config.isDurableStorage(), + MultiStageQueryContext.CTX_FAULT_TOLERANCE, + config.isFaultTolerant(), + MultiStageQueryContext.CTX_MAX_CONCURRENT_STAGES, + config.getMaxConcurrentStages() + ); + + return config; } @Override - public ServiceEmitter emitter() + public void emitMetric(String metric, Number value) { - return toolbox.getEmitter(); + toolbox.getEmitter().emit(metricBuilder.setMetric(metric, value)); } @Override @@ -91,9 +145,15 @@ public DruidNode selfNode() } @Override - public CoordinatorClient coordinatorClient() + public InputSpecSlicer newTableInputSpecSlicer() { - return toolbox.getCoordinatorClient(); + final SegmentSource includeSegmentSource = + MultiStageQueryContext.getSegmentSources(task.getQuerySpec().getQuery().context()); + return new TableInputSpecSlicer( + toolbox.getCoordinatorClient(), + toolbox.getTaskActionClient(), + includeSegmentSource + ); } @Override @@ -103,29 +163,126 @@ public TaskActionClient taskActionClient() } @Override - public WorkerClient taskClientFor(Controller controller) + public WorkerClient newWorkerClient() { - // Ignore controller parameter. 
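+ // Worker clients are scoped to a single query, so build a fresh one for each controller run.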
return new IndexerWorkerClient(clientFactory, overlordClient, jsonMapper()); } @Override public void registerController(Controller controller, final Closer closer) { - ChatHandler chatHandler = new ControllerChatHandler(toolbox, controller); - toolbox.getChatHandlerProvider().register(controller.id(), chatHandler, false); - closer.register(() -> toolbox.getChatHandlerProvider().unregister(controller.id())); + ChatHandler chatHandler = new ControllerChatHandler( + controller, + task.getDataSource(), + toolbox.getAuthorizerMapper() + ); + toolbox.getChatHandlerProvider().register(controller.queryId(), chatHandler, false); + closer.register(() -> toolbox.getChatHandlerProvider().unregister(controller.queryId())); } @Override - public WorkerManagerClient workerManager() + public WorkerManager newWorkerManager( + final String queryId, + final MSQSpec querySpec, + final ControllerQueryKernelConfig queryKernelConfig, + final WorkerFailureListener workerFailureListener + ) { - return workerManager; + return new MSQWorkerTaskLauncher( + queryId, + task.getDataSource(), + overlordClient, + workerFailureListener, + makeTaskContext(querySpec, queryKernelConfig, task.getContext()), + // 10 minutes +- 2 minutes jitter + TimeUnit.SECONDS.toMillis(600 + ThreadLocalRandom.current().nextInt(-4, 5) * 30L) + ); } - @Override - public void writeReports(String controllerTaskId, TaskReport.ReportMap reports) + /** + * Helper method for {@link #queryKernelConfig(MSQSpec, QueryDefinition)}. Also used in tests. + */ + public static ControllerQueryKernelConfig makeQueryKernelConfig( + final MSQSpec querySpec, + final ControllerMemoryParameters memoryParameters + ) { - toolbox.getTaskReportFileWriter().write(controllerTaskId, reports); + final QueryContext queryContext = querySpec.getQuery().context(); + final int maxConcurrentStages = MultiStageQueryContext.getMaxConcurrentStages(queryContext); + final boolean isFaultToleranceEnabled = MultiStageQueryContext.isFaultToleranceEnabled(queryContext); + final boolean isDurableStorageEnabled; + + if (isFaultToleranceEnabled) { + if (!queryContext.containsKey(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE)) { + // if context key not set, enable durableStorage automatically. + isDurableStorageEnabled = true; + } else { + // if context key is set, and durableStorage is turned on. + if (MultiStageQueryContext.isDurableStorageEnabled(queryContext)) { + isDurableStorageEnabled = true; + } else { + throw new MSQException( + UnknownFault.forMessage( + StringUtils.format( + "Context param[%s] cannot be explicitly set to false when context param[%s] is" + + " set to true. Either remove the context param[%s] or explicitly set it to true.", + MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, + MultiStageQueryContext.CTX_FAULT_TOLERANCE, + MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE + ) + ) + ); + } + } + } else { + isDurableStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(queryContext); + } + + return ControllerQueryKernelConfig + .builder() + .pipeline(maxConcurrentStages > 1) + .durableStorage(isDurableStorageEnabled) + .faultTolerance(isFaultToleranceEnabled) + .destination(querySpec.getDestination()) + .maxConcurrentStages(maxConcurrentStages) + .maxRetainedPartitionSketchBytes(memoryParameters.getPartitionStatisticsMaxRetainedBytes()) + .build(); + } + + /** + * Helper method for {@link #newWorkerManager}, split out to be used in tests. 
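The faultTolerance/durableShuffleStorage interplay in makeQueryKernelConfig above boils down to a small decision table. A compact sketch of just that rule (simplified signature; IllegalArgumentException stands in for the MSQException/UnknownFault plumbing):

```java
// faultTolerance=true,  flag unset -> durable storage auto-enabled
// faultTolerance=true,  flag=true  -> durable storage enabled
// faultTolerance=true,  flag=false -> error: invalid combination
// faultTolerance=false             -> flag used as-is (default false)
static boolean resolveDurableStorage(final boolean faultTolerant, final Boolean durableStorageFlag)
{
  if (!faultTolerant) {
    return Boolean.TRUE.equals(durableStorageFlag);
  } else if (durableStorageFlag == null || durableStorageFlag) {
    return true;
  } else {
    throw new IllegalArgumentException("durableShuffleStorage cannot be false when faultTolerance is set");
  }
}
```

The resolved flags are then mirrored into every worker's task context by makeTaskContext below, keeping the controller and its workers in agreement on the shuffle mode.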
+ * + * @param querySpec MSQ query spec; used for + */ + public static Map makeTaskContext( + final MSQSpec querySpec, + final ControllerQueryKernelConfig queryKernelConfig, + final Map controllerTaskContext + ) + { + final ImmutableMap.Builder taskContextOverridesBuilder = ImmutableMap.builder(); + final long maxParseExceptions = MultiStageQueryContext.getMaxParseExceptions(querySpec.getQuery().context()); + + taskContextOverridesBuilder + .put(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, queryKernelConfig.isDurableStorage()) + .put(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, maxParseExceptions) + .put(MultiStageQueryContext.CTX_IS_REINDEX, MSQControllerTask.isReplaceInputDataSourceTask(querySpec)) + .put(MultiStageQueryContext.CTX_MAX_CONCURRENT_STAGES, queryKernelConfig.getMaxConcurrentStages()); + + if (querySpec.getDestination().toSelectDestination() != null) { + taskContextOverridesBuilder.put( + MultiStageQueryContext.CTX_SELECT_DESTINATION, + querySpec.getDestination().toSelectDestination().getName() + ); + } + + // propagate the controller's tags to the worker task for enhanced metrics reporting + @SuppressWarnings("unchecked") + Map tags = (Map) controllerTaskContext.get(DruidMetrics.TAGS); + if (tags != null) { + taskContextOverridesBuilder.put(DruidMetrics.TAGS, tags); + } + + return taskContextOverridesBuilder.build(); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java new file mode 100644 index 000000000000..30bc75282fa4 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.indexing; + +import org.apache.druid.cli.CliIndexer; +import org.apache.druid.cli.CliPeon; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.server.security.Action; +import org.apache.druid.server.security.Resource; +import org.apache.druid.server.security.ResourceAction; +import org.apache.druid.server.security.ResourceType; + +import java.util.Collections; +import java.util.List; + +/** + * Production implementation of {@link ResourcePermissionMapper} for tasks: {@link CliIndexer} and {@link CliPeon}. 
+ */ +public class IndexerResourcePermissionMapper implements ResourcePermissionMapper +{ + private final String dataSource; + + public IndexerResourcePermissionMapper(String dataSource) + { + this.dataSource = dataSource; + } + + @Override + public List<ResourceAction> getAdminPermissions() + { + return Collections.singletonList( + new ResourceAction( + new Resource(dataSource, ResourceType.DATASOURCE), + Action.WRITE + ) + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java index 9cc9e4dae745..e645e0e62cd8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java @@ -47,6 +47,7 @@ import org.apache.druid.msq.exec.ControllerContext; import org.apache.druid.msq.exec.ControllerImpl; import org.apache.druid.msq.exec.MSQTasks; +import org.apache.druid.msq.exec.ResultsContext; import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination; import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; import org.apache.druid.msq.indexing.destination.ExportMSQDestination; @@ -246,20 +247,37 @@ public TaskStatus runTask(final TaskToolbox toolbox) throws Exception final OverlordClient overlordClient = injector.getInstance(OverlordClient.class) .withRetryPolicy(StandardRetryPolicy.unlimited()); final ControllerContext context = new IndexerControllerContext( + this, toolbox, injector, clientFactory, overlordClient ); - controller = new ControllerImpl(this, context); - return controller.run(); + + controller = new ControllerImpl( + this.getId(), + querySpec, + new ResultsContext(getSqlTypeNames(), getSqlResultsContext()), + context + ); + + final TaskReportQueryListener queryListener = new TaskReportQueryListener( + querySpec.getDestination(), + () -> toolbox.getTaskReportFileWriter().openReportOutputStream(getId()), + toolbox.getJsonMapper(), + getId(), + getContext() + ); + + controller.run(queryListener); + return queryListener.getStatusReport().toTaskStatus(getId()); } @Override public void stopGracefully(final TaskConfig taskConfig) { if (controller != null) { - controller.stopGracefully(); + controller.stop(); } } @@ -300,14 +318,15 @@ public static boolean isExport(final MSQSpec querySpec) * Returns true if the task reads from the same table as the destination. In this case, we would prefer to fail * instead of reading any unused segments to ensure that old data is not read.
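A concrete illustration of the rewritten check (hypothetical query):

```java
// Reads from and writes to the same datasource, so isReplaceInputDataSourceTask(querySpec)
// returns true, and makeTaskContext sets CTX_IS_REINDEX on the worker tasks.
final String reindexSql =
    "REPLACE INTO \"wikipedia\" OVERWRITE ALL\n"
    + "SELECT * FROM \"wikipedia\" WHERE \"channel\" = '#en.wikipedia'\n"
    + "PARTITIONED BY DAY";
// A plain SELECT, or an INSERT targeting a different datasource, yields false.
```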
*/ - public static boolean isReplaceInputDataSourceTask(MSQControllerTask task) + public static boolean isReplaceInputDataSourceTask(MSQSpec querySpec) { - return task.getQuerySpec() - .getQuery() - .getDataSource() - .getTableNames() - .stream() - .anyMatch(datasouce -> task.getDataSource().equals(datasouce)); + if (isIngestion(querySpec)) { + final String targetDataSource = ((DataSourceMSQDestination) querySpec.getDestination()).getDataSource(); + final Set sourceTableNames = querySpec.getQuery().getDataSource().getTableNames(); + return sourceTableNames.contains(targetDataSource); + } else { + return false; + } } public static boolean writeResultsToDurableStorage(final MSQSpec querySpec) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java index 55ff6a3876d8..ed32b81f44ef 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java @@ -19,13 +19,15 @@ package org.apache.druid.msq.indexing; -import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.errorprone.annotations.concurrent.GuardedBy; +import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.client.indexing.TaskStatusResponse; import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.indexer.TaskLocation; import org.apache.druid.indexer.TaskState; @@ -34,17 +36,18 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.msq.exec.ControllerContext; -import org.apache.druid.msq.exec.ControllerImpl; import org.apache.druid.msq.exec.Limits; import org.apache.druid.msq.exec.MSQTasks; -import org.apache.druid.msq.exec.WorkerManagerClient; +import org.apache.druid.msq.exec.RetryCapableWorkerManager; +import org.apache.druid.msq.exec.WorkerFailureListener; +import org.apache.druid.msq.exec.WorkerStats; import org.apache.druid.msq.indexing.error.MSQException; import org.apache.druid.msq.indexing.error.TaskStartTimeoutFault; import org.apache.druid.msq.indexing.error.TooManyAttemptsForJob; import org.apache.druid.msq.indexing.error.TooManyAttemptsForWorker; import org.apache.druid.msq.indexing.error.UnknownFault; import org.apache.druid.msq.indexing.error.WorkerFailedFault; +import org.apache.druid.rpc.indexing.OverlordClient; import java.time.Duration; import java.util.ArrayList; @@ -56,19 +59,17 @@ import java.util.Map; import java.util.OptionalLong; import java.util.Set; -import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.ConcurrentSkipListMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; /** * Like {@link 
org.apache.druid.indexing.common.task.batch.parallel.TaskMonitor}, but different. */ -public class MSQWorkerTaskLauncher +public class MSQWorkerTaskLauncher implements RetryCapableWorkerManager { private static final Logger log = new Logger(MSQWorkerTaskLauncher.class); private static final long HIGH_FREQUENCY_CHECK_MILLIS = 100; @@ -87,7 +88,7 @@ private enum State private final String controllerTaskId; private final String dataSource; - private final ControllerContext context; + private final OverlordClient overlordClient; private final ExecutorService exec; private final long maxTaskStartDelayMillis; @@ -108,15 +109,19 @@ private enum State @GuardedBy("taskIds") private final List taskIds = new ArrayList<>(); + // Task ID -> worker number. Only set for active workers. + @GuardedBy("taskIds") + private final Map taskIdToWorkerNumber = new HashMap<>(); + // Worker number -> whether the task has fully started up or not. @GuardedBy("taskIds") private final IntSet fullyStartedTasks = new IntOpenHashSet(); - // Mutable state accessed by mainLoop, ControllerImpl, and jetty (/liveReports) threads. + // Mutable state written only by the mainLoop() thread. // Tasks are added here once they are submitted for running, but before they are fully started up. - // taskId -> taskTracker - private final ConcurrentMap taskTrackers = new ConcurrentSkipListMap<>(Comparator.comparingInt( - MSQTasks::workerFromTaskId)); + // Uses a concurrent map because getWorkerStats() reads this map too, and getWorkerStats() can be called by various + // other threads. + private final ConcurrentHashMap taskTrackers = new ConcurrentHashMap<>(); // Set of tasks which are issued a cancel request by the controller. private final Set canceledWorkerTasks = ConcurrentHashMap.newKeySet(); @@ -135,35 +140,31 @@ private enum State private final Set failedInactiveWorkers = ConcurrentHashMap.newKeySet(); private final ConcurrentHashMap> workerToTaskIds = new ConcurrentHashMap<>(); - private final RetryTask retryTask; + private final WorkerFailureListener workerFailureListener; private final AtomicLong recentFullyStartedWorkerTimeInMillis = new AtomicLong(System.currentTimeMillis()); public MSQWorkerTaskLauncher( final String controllerTaskId, final String dataSource, - final ControllerContext context, - final RetryTask retryTask, + final OverlordClient overlordClient, + final WorkerFailureListener workerFailureListener, final Map taskContextOverrides, final long maxTaskStartDelayMillis ) { this.controllerTaskId = controllerTaskId; this.dataSource = dataSource; - this.context = context; + this.overlordClient = overlordClient; this.taskContextOverrides = taskContextOverrides; this.exec = Execs.singleThreaded( "multi-stage-query-task-launcher[" + StringUtils.encodeForFormat(controllerTaskId) + "]-%s" ); - this.retryTask = retryTask; + this.workerFailureListener = workerFailureListener; this.maxTaskStartDelayMillis = maxTaskStartDelayMillis; } - /** - * Launches tasks, blocking until they are all in RUNNING state. Returns a future that resolves successfully when - * all tasks end successfully or are canceled. The returned future resolves to an exception if one of the tasks fails - * without being explicitly canceled, or if something else goes wrong. - */ + @Override public ListenableFuture start() { if (state.compareAndSet(State.NEW, State.STARTED)) { @@ -181,10 +182,7 @@ public ListenableFuture start() return stopFuture; } - /** - * Stops all tasks, blocking until they exit. 
Returns quietly, no matter whether there was an exception - * associated with the future from {@link #start()} or not. - */ + @Override public void stop(final boolean interrupt) { if (state.compareAndSet(State.NEW, State.STOPPED)) { @@ -221,24 +219,24 @@ public void stop(final boolean interrupt) } // Block until stopped. - waitForWorkerShutdown(); + try { + FutureUtils.getUnchecked(stopFuture, false); + } + catch (Throwable ignored) { + // Suppress. + } } - /** - * Get the list of currently-active tasks. - */ - public List getActiveTasks() + @Override + public List getWorkerIds() { synchronized (taskIds) { return ImmutableList.copyOf(taskIds); } } - /** - * Launch additional tasks, if needed, to bring the size of {@link #taskIds} up to {@code taskCount}. If enough - * tasks are already running, this method does nothing. - */ - public void launchTasksIfNeeded(final int taskCount) throws InterruptedException + @Override + public void launchWorkersIfNeeded(final int taskCount) throws InterruptedException { synchronized (taskIds) { retryInactiveTasksIfNeeded(taskCount); @@ -280,21 +278,13 @@ Set getWorkersToRelaunch() return workersToRelaunch; } - /** - * Queues worker for relaunch. A noop if the worker is already in the queue. - * - * @param workerNumber worker number - */ + @Override public void submitForRelaunch(int workerNumber) { workersToRelaunch.add(workerNumber); } - /** - * Report a worker that failed without active orders. To be retried if it is requried for future stages only. - * - * @param workerNumber worker number - */ + @Override public void reportFailedInactiveWorker(int workerNumber) { synchronized (taskIds) { @@ -302,16 +292,11 @@ public void reportFailedInactiveWorker(int workerNumber) } } - /** - * Blocks the call untill the worker tasks are ready to be contacted for work. - * - * @param workerSet - * @throws InterruptedException - */ - public void waitUntilWorkersReady(Set workerSet) throws InterruptedException + @Override + public void waitForWorkers(Set workerNumbers) throws InterruptedException { synchronized (taskIds) { - while (!fullyStartedTasks.containsAll(workerSet)) { + while (!fullyStartedTasks.containsAll(workerNumbers)) { if (stopFuture.isDone() || stopFuture.isCancelled()) { FutureUtils.getUnchecked(stopFuture, false); throw new ISE("Stopped"); @@ -321,40 +306,30 @@ public void waitUntilWorkersReady(Set workerSet) throws InterruptedExce } } - public void waitForWorkerShutdown() - { - try { - FutureUtils.getUnchecked(stopFuture, false); - } - catch (Throwable ignored) { - // Suppress. - } - } - - /** - * Checks if the controller has canceled the input taskId. This method is used in {@link ControllerImpl} - * to figure out if the worker taskId is canceled by the controller. If yes, the errors from that worker taskId - * are ignored for the error reports. 
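Taken together, the renamed methods form the WorkerManager lifecycle that the controller drives. A rough sketch of the call sequence (hypothetical worker counts; error handling and the actual work-order distribution omitted):

```java
final WorkerManager workerManager =
    context.newWorkerManager(queryId, querySpec, queryKernelConfig, workerFailureListener);

final ListenableFuture<?> doneFuture = workerManager.start(); // submit tasks; resolves when all finish
workerManager.launchWorkersIfNeeded(4);                       // ensure workers 0..3 exist
workerManager.waitForWorkers(ImmutableSet.of(0, 1, 2, 3));    // block until they have fully started
// ... post WorkOrders through the WorkerClient, track completion ...
workerManager.stop(false);                                    // graceful stop; blocks until tasks exit
```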
- * - * @return true if task is canceled by the controller, else false - */ + @Override public boolean isTaskCanceledByController(String taskId) { return canceledWorkerTasks.contains(taskId); } + @Override + public int getWorkerNumber(String taskId) + { + return MSQTasks.workerFromTaskId(taskId); + } - public boolean isTaskLatest(String taskId) + @Override + public boolean isWorkerActive(String taskId) { - int worker = MSQTasks.workerFromTaskId(taskId); synchronized (taskIds) { - return taskId.equals(taskIds.get(worker)); + return taskIdToWorkerNumber.get(taskId) != null; } } + @Override public Map> getWorkerStats() { - final Map> workerStats = new TreeMap<>(); + final Int2ObjectMap> workerStats = new Int2ObjectAVLTreeMap<>(); for (Map.Entry taskEntry : taskTrackers.entrySet()) { final TaskTracker taskTracker = taskEntry.getValue(); @@ -393,6 +368,7 @@ private void mainLoop() cleanFailedTasksWhichAreRelaunched(); } catch (Throwable e) { + log.warn(e, "Stopped due to exception in task management loop."); state.set(State.STOPPED); cancelTasksOnStop.set(true); caught = e; @@ -491,9 +467,10 @@ private void runNewTasks() return taskIds; }); - context.workerManager().run(task.getId(), task); + FutureUtils.getUnchecked(overlordClient.runTask(task.getId(), task), true); synchronized (taskIds) { + taskIdToWorkerNumber.put(task.getId(), taskIds.size()); taskIds.add(task.getId()); taskIds.notifyAll(); } @@ -504,7 +481,8 @@ private void runNewTasks() * Returns a pair which contains the number of currently running worker tasks and the number of worker tasks that are * not yet fully started as left and right respectively. */ - public WorkerCount getWorkerTaskCount() + @Override + public WorkerCount getWorkerCount() { synchronized (taskIds) { if (stopFuture.isDone()) { @@ -530,8 +508,8 @@ private void updateTaskTrackersAndTaskIds() } if (!taskStatusesNeeded.isEmpty()) { - final WorkerManagerClient workerManager = context.workerManager(); - final Map statuses = workerManager.statuses(taskStatusesNeeded); + final Map statuses = + FutureUtils.getUnchecked(overlordClient.taskStatuses(taskStatusesNeeded), true); for (Map.Entry statusEntry : statuses.entrySet()) { final String taskId = statusEntry.getKey(); @@ -542,7 +520,13 @@ private void updateTaskTrackersAndTaskIds() if (!status.getStatusCode().isComplete() && tracker.unknownLocation()) { // Look up location if not known. Note: this location is not used to actually contact the task. For that, // we have SpecificTaskServiceLocator. This location is only used to determine if a task has started up. 
- tracker.setLocation(workerManager.location(taskId)); + final TaskStatusResponse taskStatusResponse = + FutureUtils.getUnchecked(overlordClient.taskStatus(taskId), true); + if (taskStatusResponse.getStatus() != null) { + tracker.setLocation(taskStatusResponse.getStatus().getLocation()); + } else { + tracker.setLocation(TaskLocation.unknown()); + } } if (status.getStatusCode() == TaskState.RUNNING && !tracker.unknownLocation()) { @@ -568,10 +552,7 @@ private void checkForErroneousTasks() { final int numTasks = taskTrackers.size(); - Iterator> taskTrackerIterator = taskTrackers.entrySet().iterator(); - - while (taskTrackerIterator.hasNext()) { - final Map.Entry taskEntry = taskTrackerIterator.next(); + for (Map.Entry taskEntry : taskTrackersByWorkerNumber()) { final String taskId = taskEntry.getKey(); final TaskTracker tracker = taskEntry.getValue(); if (tracker.isRetrying()) { @@ -583,7 +564,7 @@ private void checkForErroneousTasks() final String errorMessage = StringUtils.format("Task [%s] status missing", taskId); log.info(errorMessage + ". Trying to relaunch the worker"); tracker.enableRetrying(); - retryTask.retry( + workerFailureListener.onFailure( tracker.msqWorkerTask, UnknownFault.forMessage(errorMessage) ); @@ -591,7 +572,7 @@ private void checkForErroneousTasks() } else if (tracker.didRunTimeOut(maxTaskStartDelayMillis) && !canceledWorkerTasks.contains(taskId)) { removeWorkerFromFullyStartedWorkers(tracker); throw new MSQException(new TaskStartTimeoutFault( - this.getWorkerTaskCount().getPendingWorkerCount(), + this.getWorkerCount().getPendingWorkerCount(), numTasks + 1, maxTaskStartDelayMillis )); @@ -600,7 +581,10 @@ private void checkForErroneousTasks() TaskStatus taskStatus = tracker.statusRef.get(); log.info("Task[%s] failed because %s. 
Trying to relaunch the worker", taskId, taskStatus.getErrorMsg()); tracker.enableRetrying(); - retryTask.retry(tracker.msqWorkerTask, new WorkerFailedFault(taskId, taskStatus.getErrorMsg())); + workerFailureListener.onFailure( + tracker.msqWorkerTask, + new WorkerFailedFault(taskId, taskStatus.getErrorMsg()) + ); } } } @@ -658,16 +642,18 @@ private void relaunchTasks() taskIds.notifyAll(); } - context.workerManager().run(relaunchedTask.getId(), relaunchedTask); + FutureUtils.getUnchecked(overlordClient.runTask(relaunchedTask.getId(), relaunchedTask), true); taskHistory.add(relaunchedTask.getId()); synchronized (taskIds) { // replace taskId with the retry taskID for the same worker number + taskIdToWorkerNumber.remove(taskIds.get(toRelaunch.getWorkerNumber())); taskIds.set(toRelaunch.getWorkerNumber(), relaunchedTask.getId()); + taskIdToWorkerNumber.put(relaunchedTask.getId(), toRelaunch.getWorkerNumber()); taskIds.notifyAll(); } - return taskHistory; + return taskHistory; }); iterator.remove(); } @@ -697,14 +683,14 @@ private void shutDownTasks() { cleanFailedTasksWhichAreRelaunched(); - for (final Map.Entry taskEntry : taskTrackers.entrySet()) { + for (final Map.Entry taskEntry : taskTrackersByWorkerNumber()) { final String taskId = taskEntry.getKey(); final TaskTracker tracker = taskEntry.getValue(); if ((!canceledWorkerTasks.contains(taskId)) && (!tracker.isComplete())) { canceledWorkerTasks.add(taskId); - context.workerManager().cancel(taskId); + FutureUtils.getUnchecked(overlordClient.cancelTask(taskId), true); } } } @@ -720,7 +706,7 @@ private void cleanFailedTasksWhichAreRelaunched() try { if (canceledWorkerTasks.add(taskId)) { try { - context.workerManager().cancel(taskId); + FutureUtils.getUnchecked(overlordClient.cancelTask(taskId), true); } catch (Exception ignore) { //ignoring cancellation exception @@ -730,7 +716,6 @@ private void cleanFailedTasksWhichAreRelaunched() finally { tasksToCancel.remove(); } - } } @@ -749,6 +734,17 @@ private boolean allTasksStarted(final int taskCount) return true; } + /** + * Returns entries of {@link #taskTrackers} sorted by worker number. + */ + private List> taskTrackersByWorkerNumber() + { + return taskTrackers.entrySet() + .stream() + .sorted(Comparator.comparing(entry -> entry.getValue().workerNumber)) + .collect(Collectors.toList()); + } + /** * Used by the main loop to decide how often to check task status. 
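The new taskIdToWorkerNumber map lets isWorkerActive() answer in constant time instead of scanning taskIds. The bookkeeping in relaunchTasks() above maintains one invariant, condensed here:

```java
// Invariant: taskIds.get(w) is the latest attempt for worker w, and taskIdToWorkerNumber
// contains exactly the currently-active task IDs.
synchronized (taskIds) {
  taskIdToWorkerNumber.remove(taskIds.get(workerNumber));   // prior attempt is no longer active
  taskIds.set(workerNumber, relaunchedTaskId);              // same worker slot, new task ID
  taskIdToWorkerNumber.put(relaunchedTaskId, workerNumber);
  taskIds.notifyAll();                                      // wake waiters in waitForWorkers()
}
```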
*/ @@ -885,51 +881,4 @@ public long taskPendingTimeInMs() } } } - - public static class WorkerStats - { - String workerId; - TaskState state; - long duration; - long pendingTimeInMs; - - /** - * For JSON deserialization only - */ - public WorkerStats() - { - } - - public WorkerStats(String workerId, TaskState state, long duration, long pendingTimeInMs) - { - this.workerId = workerId; - this.state = state; - this.duration = duration; - this.pendingTimeInMs = pendingTimeInMs; - } - - @JsonProperty - public String getWorkerId() - { - return workerId; - } - - @JsonProperty - public TaskState getState() - { - return state; - } - - @JsonProperty("durationMs") - public long getDuration() - { - return duration; - } - - @JsonProperty("pendingMs") - public long getPendingTimeInMs() - { - return pendingTimeInMs; - } - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java new file mode 100644 index 000000000000..4cc4678a58a7 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.indexing; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.indexer.report.TaskContextReport; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.jackson.JacksonUtils; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.exec.QueryListener; +import org.apache.druid.msq.indexing.destination.MSQDestination; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.indexing.report.MSQStatusReport; +import org.apache.druid.msq.indexing.report.MSQTaskReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.io.OutputStream; +import java.util.List; +import java.util.Map; + +/** + * Query listener that writes {@link MSQTaskReport} to an {@link OutputStream}. + * + * This is used so the report can be written one row at a time, as results are being read, as part of the main + * query loop. This allows reports to scale to row counts that cannot be materialized in memory, and allows + * report-writing to be interleaved with query execution when using {@link OutputChannelMode#MEMORY}. 
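The listener contract, as exercised by MSQControllerTask.runTask earlier in this patch, is roughly the following sketch (destination, rows, and payload are stand-ins; in production, ControllerImpl makes these calls while the query is still running):

```java
final TaskReportQueryListener listener = new TaskReportQueryListener(
    destination,              // row limit comes from destination.getRowsInTaskReport()
    () -> reportOutputStream, // OutputStreamSupplier; the generator opens it lazily
    jsonMapper,
    taskId,
    taskContext
);

if (listener.readResults()) {                        // false when the destination wants no rows
  listener.onResultsStart(signature, sqlTypeNames);  // opens the "results" object and array
  boolean wantMore = true;
  for (Object[] row : rows) {
    wantMore = listener.onResultRow(row);            // streams one row into the report
    if (!wantMore) {
      break;                                         // limit hit; listener wrote resultsTruncated=true
    }
  }
  if (wantMore) {
    listener.onResultsComplete();                    // closes the array; resultsTruncated=false
  }
}

listener.onQueryComplete(reportPayload);             // status, stages, counters; closes the JSON
final TaskStatus status = listener.getStatusReport().toTaskStatus(taskId);
```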
+ */ +public class TaskReportQueryListener implements QueryListener +{ + private static final String FIELD_TYPE = "type"; + private static final String FIELD_TASK_ID = "taskId"; + private static final String FIELD_PAYLOAD = "payload"; + private static final String FIELD_STATUS = "status"; + private static final String FIELD_STAGES = "stages"; + private static final String FIELD_COUNTERS = "counters"; + private static final String FIELD_RESULTS = "results"; + private static final String FIELD_RESULTS_SIGNATURE = "signature"; + private static final String FIELD_RESULTS_SQL_TYPE_NAMES = "sqlTypeNames"; + private static final String FIELD_RESULTS_RESULTS = "results"; + private static final String FIELD_RESULTS_TRUNCATED = "resultsTruncated"; + + private final long rowsInTaskReport; + private final OutputStreamSupplier reportSink; + private final ObjectMapper jsonMapper; + private final SerializerProvider serializers; + private final String taskId; + private final Map taskContext; + + private JsonGenerator jg; + private long numResults; + private MSQStatusReport statusReport; + + public TaskReportQueryListener( + final MSQDestination destination, + final OutputStreamSupplier reportSink, + final ObjectMapper jsonMapper, + final String taskId, + final Map taskContext + ) + { + this.rowsInTaskReport = destination.getRowsInTaskReport(); + this.reportSink = reportSink; + this.jsonMapper = jsonMapper; + this.serializers = jsonMapper.getSerializerProviderInstance(); + this.taskId = taskId; + this.taskContext = taskContext; + } + + @Override + public boolean readResults() + { + return rowsInTaskReport == MSQDestination.UNLIMITED || rowsInTaskReport > 0; + } + + @Override + public void onResultsStart(List signature, @Nullable List sqlTypeNames) + { + try { + openGenerator(); + + jg.writeObjectFieldStart(FIELD_RESULTS); + writeObjectField(FIELD_RESULTS_SIGNATURE, signature); + if (sqlTypeNames != null) { + writeObjectField(FIELD_RESULTS_SQL_TYPE_NAMES, sqlTypeNames); + } + jg.writeArrayFieldStart(FIELD_RESULTS_RESULTS); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public boolean onResultRow(Object[] row) + { + try { + JacksonUtils.writeObjectUsingSerializerProvider(jg, serializers, row); + numResults++; + + if (rowsInTaskReport == MSQDestination.UNLIMITED || numResults < rowsInTaskReport) { + return true; + } else { + jg.writeEndArray(); + jg.writeBooleanField(FIELD_RESULTS_TRUNCATED, true); + jg.writeEndObject(); + return false; + } + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onResultsComplete() + { + try { + jg.writeEndArray(); + jg.writeBooleanField(FIELD_RESULTS_TRUNCATED, false); + jg.writeEndObject(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onQueryComplete(MSQTaskReportPayload report) + { + try { + openGenerator(); + statusReport = report.getStatus(); + writeObjectField(FIELD_STATUS, report.getStatus()); + + if (report.getStages() != null) { + writeObjectField(FIELD_STAGES, report.getStages()); + } + + if (report.getCounters() != null) { + writeObjectField(FIELD_COUNTERS, report.getCounters()); + } + + jg.writeEndObject(); // End MSQTaskReportPayload + jg.writeEndObject(); // End MSQTaskReport + jg.writeObjectField(TaskContextReport.REPORT_KEY, new TaskContextReport(taskId, taskContext)); + jg.writeEndObject(); // End report + jg.close(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + public MSQStatusReport 
getStatusReport() + { + if (statusReport == null) { + throw new ISE("Status report not available"); + } + + return statusReport; + } + + /** + * Initialize {@link #jg}, if it wasn't already set up. Writes the object start marker, too. + */ + private void openGenerator() throws IOException + { + if (jg == null) { + jg = jsonMapper.createGenerator(reportSink.get()); + jg.writeStartObject(); // Start report + jg.writeObjectFieldStart(MSQTaskReport.REPORT_KEY); // Start MSQTaskReport + jg.writeStringField(FIELD_TYPE, MSQTaskReport.REPORT_KEY); + jg.writeStringField(FIELD_TASK_ID, taskId); + jg.writeObjectFieldStart(FIELD_PAYLOAD); // Start MSQTaskReportPayload + } + } + + /** + * Write a field name followed by an object. Unlike {@link JsonGenerator#writeObjectField(String, Object)}, + * this approach avoids the re-creation of a {@link SerializerProvider} for each call. + */ + private void writeObjectField(final String fieldName, final Object value) throws IOException + { + jg.writeFieldName(fieldName); + JacksonUtils.writeObjectUsingSerializerProvider(jg, serializers, value); + } + + public interface OutputStreamSupplier + { + OutputStream get() throws IOException; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/ControllerChatHandler.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/ControllerChatHandler.java index 4be026ac34c2..bf3dd4a6bf14 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/ControllerChatHandler.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/ControllerChatHandler.java @@ -19,179 +19,16 @@ package org.apache.druid.msq.indexing.client; -import org.apache.druid.indexer.report.TaskReport; -import org.apache.druid.indexing.common.TaskToolbox; -import org.apache.druid.msq.counters.CounterSnapshots; -import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.exec.Controller; -import org.apache.druid.msq.exec.ControllerClient; -import org.apache.druid.msq.indexing.MSQControllerTask; -import org.apache.druid.msq.indexing.MSQTaskList; -import org.apache.druid.msq.indexing.error.MSQErrorReport; -import org.apache.druid.msq.kernel.StageId; -import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; +import org.apache.druid.msq.indexing.IndexerResourcePermissionMapper; +import org.apache.druid.msq.rpc.ControllerResource; import org.apache.druid.segment.realtime.firehose.ChatHandler; -import org.apache.druid.segment.realtime.firehose.ChatHandlers; -import org.apache.druid.server.security.Action; +import org.apache.druid.server.security.AuthorizerMapper; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.Consumes; -import javax.ws.rs.GET; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import java.util.List; - -public class ControllerChatHandler implements ChatHandler +public class ControllerChatHandler extends ControllerResource implements ChatHandler { - private final Controller controller; - private final MSQControllerTask task; - private final TaskToolbox toolbox; - - public ControllerChatHandler(TaskToolbox toolbox, Controller controller) - { - this.controller = controller; - this.task = controller.task(); - this.toolbox = toolbox; - } - - /** - * Used by subtasks to 
post {@link PartialKeyStatisticsInformation} for shuffling stages. - * - * See {@link ControllerClient#postPartialKeyStatistics(StageId, int, PartialKeyStatisticsInformation)} - * for the client-side code that calls this API. - */ - @POST - @Path("/partialKeyStatisticsInformation/{queryId}/{stageNumber}/{workerNumber}") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpPostPartialKeyStatistics( - final Object partialKeyStatisticsObject, - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @PathParam("workerNumber") final int workerNumber, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - controller.updatePartialKeyStatisticsInformation(stageNumber, workerNumber, partialKeyStatisticsObject); - return Response.status(Response.Status.ACCEPTED).build(); - } - - /** - * Used by subtasks to post system errors. Note that the errors are organized by taskId, not by query/stage/worker, - * because system errors are associated with a task rather than a specific query/stage/worker execution context. - * - * See {@link ControllerClient#postWorkerError} for the client-side code that calls this API. - */ - @POST - @Path("/workerError/{taskId}") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpPostWorkerError( - final MSQErrorReport errorReport, - @PathParam("taskId") final String taskId, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - controller.workerError(errorReport); - return Response.status(Response.Status.ACCEPTED).build(); - } - - /** - * Used by subtasks to post system warnings. - * - * See {@link ControllerClient#postWorkerWarning} for the client-side code that calls this API. - */ - @POST - @Path("/workerWarning") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpPostWorkerWarning( - final List errorReport, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - controller.workerWarning(errorReport); - return Response.status(Response.Status.ACCEPTED).build(); - } - - - /** - * Used by subtasks to post {@link CounterSnapshots} periodically. - * - * See {@link ControllerClient#postCounters} for the client-side code that calls this API. - */ - @POST - @Path("/counters/{taskId}") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpPostCounters( - @PathParam("taskId") final String taskId, - final CounterSnapshotsTree snapshotsTree, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - controller.updateCounters(taskId, snapshotsTree); - return Response.status(Response.Status.OK).build(); - } - - /** - * Used by subtasks to post notifications that their results are ready. - * - * See {@link ControllerClient#postResultsComplete} for the client-side code that calls this API. 
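Every endpoint deleted in this hunk now lives in the reusable ControllerResource, so the chat handler below shrinks to a constructor. A process type other than the Indexer or Peon could expose the same API by pairing the resource with its own permission mapper; a hypothetical sketch:

```java
// Hypothetical subclass: same endpoints, different authorization model.
public class ServerControllerResource extends ControllerResource
{
  public ServerControllerResource(
      final Controller controller,
      final ResourcePermissionMapper permissionMapper,
      final AuthorizerMapper authorizerMapper
  )
  {
    super(controller, permissionMapper, authorizerMapper); // constructor shape as used by ControllerChatHandler
  }
}
```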
- */ - @POST - @Path("/resultsComplete/{queryId}/{stageNumber}/{workerNumber}") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpPostResultsComplete( - final Object resultObject, - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @PathParam("workerNumber") final int workerNumber, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - controller.resultsComplete(queryId, stageNumber, workerNumber, resultObject); - return Response.status(Response.Status.ACCEPTED).build(); - } - - /** - * See {@link ControllerClient#getTaskList} for the client-side code that calls this API. - */ - @GET - @Path("/taskList") - @Produces(MediaType.APPLICATION_JSON) - public Response httpGetTaskList(@Context final HttpServletRequest req) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - - return Response.ok(new MSQTaskList(controller.getTaskIds())).build(); - } - - /** - * See {@link org.apache.druid.indexing.overlord.RemoteTaskRunner#streamTaskReports} for the client-side code that - * calls this API. - */ - @GET - @Path("/liveReports") - @Produces(MediaType.APPLICATION_JSON) - public Response httpGetLiveReports(@Context final HttpServletRequest req) + public ControllerChatHandler(Controller controller, String dataSource, AuthorizerMapper authorizerMapper) { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - final TaskReport.ReportMap reports = controller.liveReports(); - if (reports == null) { - return Response.status(Response.Status.NOT_FOUND).build(); - } - return Response.ok(reports).build(); + super(controller, new IndexerResourcePermissionMapper(dataSource), authorizerMapper); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java index 493cbeb62424..81303eb43848 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java @@ -80,6 +80,22 @@ public void postPartialKeyStatistics( ); } + @Override + public void postDoneReadingInput(StageId stageId, int workerNumber) throws IOException + { + final String path = StringUtils.format( + "/doneReadingInput/%s/%d/%d", + StringUtils.urlEncode(stageId.getQueryId()), + stageId.getStageNumber(), + workerNumber + ); + + doRequest( + new RequestBuilder(HttpMethod.POST, path), + IgnoreHttpResponseHandler.INSTANCE + ); + } + @Override public void postCounters(String workerId, CounterSnapshotsTree snapshotsTree) throws IOException { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerClient.java index af089a296006..e9b4a370b241 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerClient.java @@ -19,32 +19,11 @@ package 
org.apache.druid.msq.indexing.client; -import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.MoreExecutors; -import com.google.common.util.concurrent.SettableFuture; import com.google.errorprone.annotations.concurrent.GuardedBy; -import org.apache.druid.common.guava.FutureUtils; -import org.apache.druid.frame.channel.ReadableByteChunksFrameChannel; -import org.apache.druid.frame.file.FrameFileHttpResponseHandler; -import org.apache.druid.frame.file.FrameFilePartialFetch; -import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.java.util.common.Pair; -import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.java.util.common.concurrent.Execs; -import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.java.util.http.client.response.BytesFullResponseHandler; -import org.apache.druid.java.util.http.client.response.BytesFullResponseHolder; -import org.apache.druid.msq.counters.CounterSnapshotsTree; -import org.apache.druid.msq.exec.WorkerClient; -import org.apache.druid.msq.kernel.StageId; -import org.apache.druid.msq.kernel.WorkOrder; -import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; -import org.apache.druid.rpc.IgnoreHttpResponseHandler; -import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.msq.indexing.MSQWorkerTask; +import org.apache.druid.msq.rpc.BaseWorkerClientImpl; import org.apache.druid.rpc.ServiceClient; import org.apache.druid.rpc.ServiceClientFactory; import org.apache.druid.rpc.StandardRetryPolicy; @@ -52,10 +31,8 @@ import org.apache.druid.rpc.indexing.SpecificTaskRetryPolicy; import org.apache.druid.rpc.indexing.SpecificTaskServiceLocator; import org.apache.druid.utils.CloseableUtils; -import org.jboss.netty.handler.codec.http.HttpMethod; -import javax.annotation.Nonnull; -import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.MediaType; import java.io.Closeable; import java.io.IOException; import java.util.HashMap; @@ -63,11 +40,13 @@ import java.util.Map; import java.util.stream.Collectors; -public class IndexerWorkerClient implements WorkerClient +/** + * Worker client for {@link MSQWorkerTask}. 
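With the transport logic hoisted into BaseWorkerClientImpl, a worker client only has to map worker IDs to service clients. A minimal hypothetical subclass showing the extension point (the base class is assumed abstract over getClient, as the indexer implementation below suggests):

```java
// Hypothetical client backed by a fixed set of pre-built service clients.
public class StaticWorkerClient extends BaseWorkerClientImpl
{
  private final Map<String, ServiceClient> clients;

  public StaticWorkerClient(final ObjectMapper jsonMapper, final Map<String, ServiceClient> clients)
  {
    super(jsonMapper, MediaType.APPLICATION_JSON); // content type for POSTed payloads
    this.clients = clients;
  }

  @Override
  protected ServiceClient getClient(final String workerId)
  {
    final ServiceClient client = clients.get(workerId);
    if (client == null) {
      throw new IllegalStateException("Unknown worker [" + workerId + "]");
    }
    return client;
  }

  @Override
  public void close()
  {
    // Nothing to release here; IndexerWorkerClient closes its task locators instead.
  }
}
```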
+ */ +public class IndexerWorkerClient extends BaseWorkerClientImpl { private final ServiceClientFactory clientFactory; private final OverlordClient overlordClient; - private final ObjectMapper jsonMapper; @GuardedBy("clientMap") private final Map> clientMap = new HashMap<>(); @@ -78,202 +57,9 @@ public IndexerWorkerClient( final ObjectMapper jsonMapper ) { + super(jsonMapper, MediaType.APPLICATION_JSON); this.clientFactory = clientFactory; this.overlordClient = overlordClient; - this.jsonMapper = jsonMapper; - } - - - @Nonnull - public static String getStagePartitionPath(StageId stageId, int partitionNumber) - { - return StringUtils.format( - "/channels/%s/%d/%d", - StringUtils.urlEncode(stageId.getQueryId()), - stageId.getStageNumber(), - partitionNumber - ); - } - - @Override - public ListenableFuture postWorkOrder(String workerTaskId, WorkOrder workOrder) - { - return getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, "/workOrder") - .jsonContent(jsonMapper, workOrder), - IgnoreHttpResponseHandler.INSTANCE - ); - } - - @Override - public ListenableFuture fetchClusterByStatisticsSnapshot( - String workerTaskId, - String queryId, - int stageNumber - ) - { - String path = StringUtils.format( - "/keyStatistics/%s/%d", - StringUtils.urlEncode(queryId), - stageNumber - ); - - return FutureUtils.transform( - getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, path), - new BytesFullResponseHandler() - ), - holder -> deserialize(holder, new TypeReference() - { - }) - ); - } - - @Override - public ListenableFuture fetchClusterByStatisticsSnapshotForTimeChunk( - String workerTaskId, - String queryId, - int stageNumber, - long timeChunk - ) - { - String path = StringUtils.format( - "/keyStatisticsForTimeChunk/%s/%d/%d", - StringUtils.urlEncode(queryId), - stageNumber, - timeChunk - ); - - return FutureUtils.transform( - getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, path), - new BytesFullResponseHandler() - ), - holder -> deserialize(holder, new TypeReference() - { - }) - ); - } - - @Override - public ListenableFuture postResultPartitionBoundaries( - String workerTaskId, - StageId stageId, - ClusterByPartitions partitionBoundaries - ) - { - final String path = StringUtils.format( - "/resultPartitionBoundaries/%s/%d", - StringUtils.urlEncode(stageId.getQueryId()), - stageId.getStageNumber() - ); - - return getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, path) - .jsonContent(jsonMapper, partitionBoundaries), - IgnoreHttpResponseHandler.INSTANCE - ); - } - - /** - * Client-side method for {@link WorkerChatHandler#httpPostCleanupStage}. 
- */ - @Override - public ListenableFuture postCleanupStage( - final String workerTaskId, - final StageId stageId - ) - { - final String path = StringUtils.format( - "/cleanupStage/%s/%d", - StringUtils.urlEncode(stageId.getQueryId()), - stageId.getStageNumber() - ); - - return getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, path), - IgnoreHttpResponseHandler.INSTANCE - ); - } - - @Override - public ListenableFuture postFinish(String workerTaskId) - { - return getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.POST, "/finish"), - IgnoreHttpResponseHandler.INSTANCE - ); - } - - @Override - public ListenableFuture getCounters(String workerTaskId) - { - return FutureUtils.transform( - getClient(workerTaskId).asyncRequest( - new RequestBuilder(HttpMethod.GET, "/counters"), - new BytesFullResponseHandler() - ), - holder -> deserialize(holder, new TypeReference() - { - }) - ); - } - - private static final Logger log = new Logger(IndexerWorkerClient.class); - - @Override - public ListenableFuture fetchChannelData( - String workerTaskId, - StageId stageId, - int partitionNumber, - long offset, - ReadableByteChunksFrameChannel channel - ) - { - final ServiceClient client = getClient(workerTaskId); - final String path = getStagePartitionPath(stageId, partitionNumber); - - final SettableFuture retVal = SettableFuture.create(); - final ListenableFuture clientFuture = - client.asyncRequest( - new RequestBuilder(HttpMethod.GET, StringUtils.format("%s?offset=%d", path, offset)) - .header(HttpHeaders.ACCEPT_ENCODING, "identity"), // Data is compressed at app level - new FrameFileHttpResponseHandler(channel) - ); - - Futures.addCallback( - clientFuture, - new FutureCallback() - { - @Override - public void onSuccess(FrameFilePartialFetch partialFetch) - { - if (partialFetch.isExceptionCaught()) { - // Exception while reading channel. Recoverable. - log.noStackTrace().info( - partialFetch.getExceptionCaught(), - "Encountered exception while reading channel [%s]", - channel.getId() - ); - } - - // Empty fetch means this is the last fetch for the channel. - partialFetch.backpressureFuture().addListener( - () -> retVal.set(partialFetch.isLastFetch()), - Execs.directExecutor() - ); - } - - @Override - public void onFailure(Throwable t) - { - retVal.setException(t); - } - }, - MoreExecutors.directExecutor() - ); - - return retVal; } @Override @@ -291,36 +77,22 @@ public void close() throws IOException } } - private ServiceClient getClient(final String workerTaskId) + @Override + protected ServiceClient getClient(final String workerId) { synchronized (clientMap) { return clientMap.computeIfAbsent( - workerTaskId, + workerId, id -> { final SpecificTaskServiceLocator locator = new SpecificTaskServiceLocator(id, overlordClient); final ServiceClient client = clientFactory.makeClient( id, locator, - new SpecificTaskRetryPolicy(workerTaskId, StandardRetryPolicy.unlimitedWithoutRetryLogging()) + new SpecificTaskRetryPolicy(workerId, StandardRetryPolicy.unlimitedWithoutRetryLogging()) ); return Pair.of(client, locator); } ).lhs; } } - - /** - * Deserialize a {@link BytesFullResponseHolder} as JSON. - *

- * It would be reasonable to move this to {@link BytesFullResponseHolder} itself, or some shared utility class. - */ - private T deserialize(final BytesFullResponseHolder bytesHolder, final TypeReference typeReference) - { - try { - return jsonMapper.readValue(bytesHolder.getContent(), typeReference); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DataSourceMSQDestination.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DataSourceMSQDestination.java index 0854582a733c..ea3072bfe45a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DataSourceMSQDestination.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DataSourceMSQDestination.java @@ -133,6 +133,18 @@ public boolean isReplaceTimeChunks() return replaceTimeChunks != null; } + @Override + public long getRowsInTaskReport() + { + return 0; + } + + @Override + public MSQSelectDestination toSelectDestination() + { + return null; + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DurableStorageMSQDestination.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DurableStorageMSQDestination.java index e522243b60d2..88fe5f58e5a1 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DurableStorageMSQDestination.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/DurableStorageMSQDestination.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.indexing.destination; import com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.druid.msq.exec.Limits; import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.querykit.ShuffleSpecFactories; import org.apache.druid.msq.querykit.ShuffleSpecFactory; @@ -63,4 +64,16 @@ public Optional getDestinationResource() { return Optional.of(new Resource(MSQControllerTask.DUMMY_DATASOURCE_FOR_SELECT, ResourceType.DATASOURCE)); } + + @Override + public long getRowsInTaskReport() + { + return Limits.MAX_SELECT_RESULT_ROWS; + } + + @Override + public MSQSelectDestination toSelectDestination() + { + return MSQSelectDestination.DURABLESTORAGE; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/ExportMSQDestination.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/ExportMSQDestination.java index 14ac0ce4c2e8..d6a78def63a8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/ExportMSQDestination.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/ExportMSQDestination.java @@ -64,6 +64,18 @@ public ResultFormat getResultFormat() return resultFormat; } + @Override + public long getRowsInTaskReport() + { + return 0; + } + + @Override + public MSQSelectDestination toSelectDestination() + { + return MSQSelectDestination.EXPORT; + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQDestination.java 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQDestination.java index 39460b15194c..ad7878f049a1 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQDestination.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQDestination.java @@ -21,9 +21,11 @@ import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; +import org.apache.druid.msq.indexing.TaskReportQueryListener; import org.apache.druid.msq.querykit.ShuffleSpecFactory; import org.apache.druid.server.security.Resource; +import javax.annotation.Nullable; import java.util.Optional; @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") @@ -35,7 +37,31 @@ }) public interface MSQDestination { + /** + * Returned by {@link #getRowsInTaskReport()} when an unlimited number of rows should be included in the task report. + */ + long UNLIMITED = -1; + + /** + * Shuffle spec for the final stage, which writes results to the destination. + */ ShuffleSpecFactory getShuffleSpecFactory(int targetSize); + /** + * Return the resource for this destination. Used for security checks. + */ Optional getDestinationResource(); + + /** + * Number of rows to include in the task report when using {@link TaskReportQueryListener}. Zero means do not + * include results in the report at all. {@link #UNLIMITED} means include an unlimited number of rows. + */ + long getRowsInTaskReport(); + + /** + * Return the {@link MSQSelectDestination} that corresponds to this destination. Returns null if this is not a + * SELECT destination (for example, returns null for {@link DataSourceMSQDestination}). + */ + @Nullable + MSQSelectDestination toSelectDestination(); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQSelectDestination.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQSelectDestination.java index e32705462470..0d21bdbe0c2f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQSelectDestination.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/MSQSelectDestination.java @@ -22,35 +22,31 @@ import com.fasterxml.jackson.annotation.JsonValue; /** - * Determines the destination for results of select queries. + * Determines the destination for results of select queries. Convertible to and from {@link MSQDestination} in a limited + * way, without as many options. Provided directly by end users in query context. */ public enum MSQSelectDestination { /** * Writes all the results directly to the report. */ - TASKREPORT("taskReport", false), + TASKREPORT("taskReport"), + /** - * Writes all the results as files in a specified format to an external location outside druid. + * Writes all the results as files in a specified format to an external location outside Druid. */ - EXPORT("export", false), + EXPORT("export"), + /** * Writes the results as frame files to durable storage. Task report can be truncated to a preview. 
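The enum no longer carries shouldTruncateResultsInTaskReport; the single knob is now MSQDestination#getRowsInTaskReport. Per this patch the destinations map as follows, and report writers gate on the value the same way TaskReportQueryListener#readResults does (sketch):

```java
// DataSourceMSQDestination, ExportMSQDestination -> 0 (no result rows in the report)
// TaskReportMSQDestination                       -> UNLIMITED (-1, all rows)
// DurableStorageMSQDestination                   -> Limits.MAX_SELECT_RESULT_ROWS (preview)
static boolean includeResultsInReport(final MSQDestination destination)
{
  final long limit = destination.getRowsInTaskReport();
  return limit == MSQDestination.UNLIMITED || limit > 0;
}
```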
*/ - DURABLESTORAGE("durableStorage", true); + DURABLESTORAGE("durableStorage"); private final String name; - private final boolean shouldTruncateResultsInTaskReport; - - public boolean shouldTruncateResultsInTaskReport() - { - return shouldTruncateResultsInTaskReport; - } - MSQSelectDestination(String name, boolean shouldTruncateResultsInTaskReport) + MSQSelectDestination(String name) { this.name = name; - this.shouldTruncateResultsInTaskReport = shouldTruncateResultsInTaskReport; } @JsonValue @@ -58,13 +54,4 @@ public String getName() { return name; } - - @Override - public String toString() - { - return "MSQSelectDestination{" + - "name='" + name + '\'' + - ", shouldTruncateResultsInTaskReport=" + shouldTruncateResultsInTaskReport + - '}'; - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/TaskReportMSQDestination.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/TaskReportMSQDestination.java index 3f199255ac76..dadc40048b66 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/TaskReportMSQDestination.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/destination/TaskReportMSQDestination.java @@ -61,4 +61,16 @@ public Optional getDestinationResource() { return Optional.of(new Resource(MSQControllerTask.DUMMY_DATASOURCE_FOR_SELECT, ResourceType.DATASOURCE)); } + + @Override + public long getRowsInTaskReport() + { + return UNLIMITED; + } + + @Override + public MSQSelectDestination toSelectDestination() + { + return MSQSelectDestination.TASKREPORT; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java index ed30179306ad..5c80f065eef3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java @@ -47,8 +47,9 @@ public NotEnoughMemoryFault( { super( CODE, - "Not enough memory. Required at least %,d bytes. (total = %,d bytes; usable = %,d bytes; server workers = %,d; server threads = %,d). Increase JVM memory with the -xmx option" - + (serverWorkers > 1 ? " or reduce number of server workers" : ""), + "Not enough memory. Required at least %,d bytes. (total = %,d bytes; usable = %,d bytes; " + + "worker capacity = %,d; processing threads = %,d). Increase JVM memory with the -Xmx option" + + (serverWorkers > 1 ? 
" or reduce worker capacity on this server" : ""), suggestedServerMemory, serverMemory, usableMemory, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java index b96ce469145e..0479b2959554 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java @@ -25,23 +25,14 @@ import com.google.common.base.Preconditions; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.common.config.Configs; -import org.apache.druid.java.util.common.guava.Sequences; -import org.apache.druid.java.util.common.guava.Yielder; -import org.apache.druid.java.util.common.guava.Yielders; -import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.msq.exec.Limits; -import org.apache.druid.msq.indexing.destination.MSQSelectDestination; -import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.segment.column.ColumnType; import javax.annotation.Nullable; -import java.util.ArrayList; import java.util.List; import java.util.Objects; public class MSQResultsReport { - private static final Logger log = new Logger(MSQResultsReport.class); /** * Like {@link org.apache.druid.segment.column.RowSignature}, but allows duplicate column names for compatibility * with SQL (which also allows duplicate column names in query results). @@ -49,72 +40,21 @@ public class MSQResultsReport private final List signature; @Nullable private final List sqlTypeNames; - private final Yielder resultYielder; + private final List results; private final boolean resultsTruncated; - public MSQResultsReport( - final List signature, - @Nullable final List sqlTypeNames, - final Yielder resultYielder, - @Nullable Boolean resultsTruncated - ) - { - this.signature = Preconditions.checkNotNull(signature, "signature"); - this.sqlTypeNames = sqlTypeNames; - this.resultYielder = Preconditions.checkNotNull(resultYielder, "resultYielder"); - this.resultsTruncated = Configs.valueOrDefault(resultsTruncated, false); - } - - /** - * Method that enables Jackson deserialization. - */ @JsonCreator - static MSQResultsReport fromJson( + public MSQResultsReport( @JsonProperty("signature") final List signature, @JsonProperty("sqlTypeNames") @Nullable final List sqlTypeNames, @JsonProperty("results") final List results, @JsonProperty("resultsTruncated") final Boolean resultsTruncated ) { - return new MSQResultsReport(signature, sqlTypeNames, Yielders.each(Sequences.simple(results)), resultsTruncated); - } - - public static MSQResultsReport createReportAndLimitRowsIfNeeded( - final List signature, - @Nullable final List sqlTypeNames, - Yielder resultYielder, - MSQSelectDestination selectDestination - ) - { - List results = new ArrayList<>(); - long rowCount = 0; - int factor = 1; - while (!resultYielder.isDone()) { - results.add(resultYielder.get()); - resultYielder = resultYielder.next(null); - ++rowCount; - if (selectDestination.shouldTruncateResultsInTaskReport() && rowCount >= Limits.MAX_SELECT_RESULT_ROWS) { - break; - } - if (rowCount % (factor * Limits.MAX_SELECT_RESULT_ROWS) == 0) { - log.warn( - "Task report is getting too large with %d rows. Large task reports can cause the controller to go out of memory. 
" - + "Consider using the 'limit %d' clause in your query to reduce the number of rows in the result. " - + "If you require all the results, consider setting [%s=%s] in the query context which will allow you to fetch large result sets.", - rowCount, - Limits.MAX_SELECT_RESULT_ROWS, - MultiStageQueryContext.CTX_SELECT_DESTINATION, - MSQSelectDestination.DURABLESTORAGE.getName() - ); - factor = factor < 32 ? factor * 2 : 32; - } - } - return new MSQResultsReport( - signature, - sqlTypeNames, - Yielders.each(Sequences.simple(results)), - !resultYielder.isDone() - ); + this.signature = Preconditions.checkNotNull(signature, "signature"); + this.sqlTypeNames = sqlTypeNames; + this.results = Preconditions.checkNotNull(results, "results"); + this.resultsTruncated = Configs.valueOrDefault(resultsTruncated, false); } @JsonProperty("signature") @@ -132,9 +72,9 @@ public List getSqlTypeNames() } @JsonProperty("results") - public Yielder getResultYielder() + public List getResults() { - return resultYielder; + return results; } @JsonProperty("resultsTruncated") diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStagesReport.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStagesReport.java index 422d8235fe20..76a077a3ebe0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStagesReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStagesReport.java @@ -24,7 +24,9 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.ShuffleKind; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.controller.ControllerStagePhase; import org.joda.time.DateTime; @@ -52,7 +54,8 @@ public static MSQStagesReport create( final Map stagePhaseMap, final Map stageRuntimeMap, final Map stageWorkerCountMap, - final Map stagePartitionCountMap + final Map stagePartitionCountMap, + final Map stageOutputChannelModeMap ) { final List stages = new ArrayList<>(); @@ -76,6 +79,8 @@ public static MSQStagesReport create( stagePhaseMap.get(stageNumber), workerCount, partitionCount, + stageDef.doesShuffle() ? 
stageDef.getShuffleSpec().kind() : null, + stageOutputChannelModeMap.get(stageNumber), stageStartTime, stageDuration ); @@ -126,6 +131,8 @@ public static class Stage private final ControllerStagePhase phase; private final int workerCount; private final int partitionCount; + private final ShuffleKind shuffleKind; + private final OutputChannelMode outputChannelMode; private final DateTime startTime; private final long duration; @@ -136,7 +143,9 @@ private Stage( @JsonProperty("phase") @Nullable final ControllerStagePhase phase, @JsonProperty("workerCount") final int workerCount, @JsonProperty("partitionCount") final int partitionCount, - @JsonProperty("startTime") @Nullable final DateTime startTime, + @JsonProperty("shuffle") final ShuffleKind shuffleKind, + @JsonProperty("output") final OutputChannelMode outputChannelMode, + @JsonProperty("startTime") @Nullable final DateTime startTime, @JsonProperty("duration") final long duration ) { @@ -145,6 +154,8 @@ private Stage( this.phase = phase; this.workerCount = workerCount; this.partitionCount = partitionCount; + this.shuffleKind = shuffleKind; + this.outputChannelMode = outputChannelMode; this.startTime = startTime; this.duration = duration; } @@ -184,6 +195,20 @@ public int getPartitionCount() return partitionCount; } + @JsonProperty("shuffle") + @JsonInclude(JsonInclude.Include.NON_NULL) + public ShuffleKind getShuffleKind() + { + return shuffleKind; + } + + @JsonProperty("output") + @JsonInclude(JsonInclude.Include.NON_NULL) + public OutputChannelMode getOutputChannelMode() + { + return outputChannelMode; + } + @JsonProperty("sort") @JsonInclude(JsonInclude.Include.NON_DEFAULT) public boolean isSorting() diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java index eca8998f865c..8bab9e0832bd 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java @@ -24,9 +24,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import org.apache.druid.indexer.TaskState; +import org.apache.druid.indexer.TaskStatus; import org.apache.druid.msq.exec.SegmentLoadStatusFetcher; -import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; +import org.apache.druid.msq.exec.WorkerStats; import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.error.MSQFaultUtils; import org.joda.time.DateTime; import javax.annotation.Nullable; @@ -50,7 +52,7 @@ public class MSQStatusReport private final long durationMs; - private final Map<Integer, List<MSQWorkerTaskLauncher.WorkerStats>> workerStats; + private final Map<Integer, List<WorkerStats>> workerStats; private final int pendingTasks; @@ -69,10 +71,11 @@ public MSQStatusReport( @JsonProperty("warnings") Collection<MSQErrorReport> warningReports, @JsonProperty("startTime") @Nullable DateTime startTime, @JsonProperty("durationMs") long durationMs, - @JsonProperty("workers") Map<Integer, List<MSQWorkerTaskLauncher.WorkerStats>> workerStats, + @JsonProperty("workers") Map<Integer, List<WorkerStats>> workerStats, @JsonProperty("pendingTasks") int pendingTasks, @JsonProperty("runningTasks") int runningTasks, - @JsonProperty("segmentLoadWaiterStatus") @Nullable SegmentLoadStatusFetcher.SegmentLoadWaiterStatus segmentLoadWaiterStatus, + @JsonProperty("segmentLoadWaiterStatus") @Nullable
@JsonProperty("segmentReport") @Nullable MSQSegmentReport segmentReport ) { @@ -136,7 +139,7 @@ public long getDurationMs() } @JsonProperty("workers") - public Map> getWorkerStats() + public Map> getWorkerStats() { return workerStats; } @@ -157,6 +160,22 @@ public MSQSegmentReport getSegmentReport() return segmentReport; } + /** + * Returns a {@link TaskStatus} appropriate for this status report. + */ + public TaskStatus toTaskStatus(final String taskId) + { + if (status == TaskState.SUCCESS) { + return TaskStatus.success(taskId); + } else { + // Error report is nonnull when status code != SUCCESS. Use that message. + return TaskStatus.failure( + taskId, + MSQFaultUtils.generateMessageWithErrorCode(errorReport.getFault()) + ); + } + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQTaskReportPayload.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQTaskReportPayload.java index 111cb5aa83a3..bf00c9434df2 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQTaskReportPayload.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQTaskReportPayload.java @@ -28,6 +28,11 @@ public class MSQTaskReportPayload { + public static final String FIELD_STATUS = "status"; + public static final String FIELD_STAGES = "stages"; + public static final String FIELD_COUNTERS = "counters"; + public static final String FIELD_RESULTS = "results"; + private final MSQStatusReport status; @Nullable @@ -41,10 +46,10 @@ public class MSQTaskReportPayload @JsonCreator public MSQTaskReportPayload( - @JsonProperty("status") MSQStatusReport status, - @JsonProperty("stages") @Nullable MSQStagesReport stages, - @JsonProperty("counters") @Nullable CounterSnapshotsTree counters, - @JsonProperty("results") @Nullable MSQResultsReport results + @JsonProperty(FIELD_STATUS) MSQStatusReport status, + @JsonProperty(FIELD_STAGES) @Nullable MSQStagesReport stages, + @JsonProperty(FIELD_COUNTERS) @Nullable CounterSnapshotsTree counters, + @JsonProperty(FIELD_RESULTS) @Nullable MSQResultsReport results ) { this.status = status; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicer.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicer.java index 074d1a1c0489..ff1808e463c8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicer.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicer.java @@ -22,7 +22,8 @@ import java.util.List; /** - * Slices {@link InputSpec} into {@link InputSlice} on the controller. + * Slices {@link InputSpec} into {@link InputSlice} on the controller. Each slice is assigned to a single worker, and + * the slice number equals the worker number. 
*/ public interface InputSpecSlicer { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicerFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicerFactory.java index 24b5cc1c5259..8accf1ec1e90 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicerFactory.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSpecSlicerFactory.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.input; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.stage.ReadablePartitions; import org.apache.druid.msq.input.stage.StageInputSpecSlicer; @@ -32,5 +33,8 @@ */ public interface InputSpecSlicerFactory { - InputSpecSlicer makeSlicer(Int2ObjectMap stagePartitionsMap); + InputSpecSlicer makeSlicer( + Int2ObjectMap stagePartitionsMap, + Int2ObjectMap stageOutputChannelModeMap + ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSlice.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSlice.java index eaf47a5df0dd..2c32d0e9ec0e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSlice.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSlice.java @@ -23,8 +23,10 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.InputSlice; +import javax.annotation.Nullable; import java.util.Objects; /** @@ -38,14 +40,19 @@ public class StageInputSlice implements InputSlice private final int stage; private final ReadablePartitions partitions; + @Nullable // May be null when created by older controllers + private final OutputChannelMode outputChannelMode; + @JsonCreator public StageInputSlice( @JsonProperty("stage") int stageNumber, - @JsonProperty("partitions") ReadablePartitions partitions + @JsonProperty("partitions") ReadablePartitions partitions, + @JsonProperty("output") OutputChannelMode outputChannelMode ) { this.stage = stageNumber; this.partitions = Preconditions.checkNotNull(partitions, "partitions"); + this.outputChannelMode = outputChannelMode; } @JsonProperty("stage") @@ -60,6 +67,13 @@ public ReadablePartitions getPartitions() return partitions; } + @JsonProperty("output") + @Nullable // May be null when created by older controllers + public OutputChannelMode getOutputChannelMode() + { + return outputChannelMode; + } + @Override public int fileCount() { @@ -76,21 +90,24 @@ public boolean equals(Object o) return false; } StageInputSlice that = (StageInputSlice) o; - return stage == that.stage && Objects.equals(partitions, that.partitions); + return stage == that.stage + && Objects.equals(partitions, that.partitions) + && outputChannelMode == that.outputChannelMode; } @Override public int hashCode() { - return Objects.hash(stage, partitions); + return Objects.hash(stage, partitions, outputChannelMode); } @Override public String toString() { - return "StageInputSpec{" + + return "StageInputSlice{" + "stage=" + stage + ", partitions=" + partitions + + ", outputChannelMode=" + outputChannelMode + '}'; } } diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSpecSlicer.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSpecSlicer.java index ad41b4234e85..f3b5d23ae4f0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSpecSlicer.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/StageInputSpecSlicer.java @@ -21,6 +21,7 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.msq.input.InputSpec; import org.apache.druid.msq.input.InputSpecSlicer; @@ -36,9 +37,16 @@ public class StageInputSpecSlicer implements InputSpecSlicer // Stage number -> partitions for that stage private final Int2ObjectMap stagePartitionsMap; - public StageInputSpecSlicer(final Int2ObjectMap stagePartitionsMap) + // Stage number -> output mode for that stage + private final Int2ObjectMap stageOutputChannelModeMap; + + public StageInputSpecSlicer( + final Int2ObjectMap stagePartitionsMap, + final Int2ObjectMap stageOutputChannelModeMap + ) { this.stagePartitionsMap = stagePartitionsMap; + this.stageOutputChannelModeMap = stageOutputChannelModeMap; } @Override @@ -53,9 +61,14 @@ public List sliceStatic(InputSpec inputSpec, int maxNumSlices) final StageInputSpec stageInputSpec = (StageInputSpec) inputSpec; final ReadablePartitions stagePartitions = stagePartitionsMap.get(stageInputSpec.getStageNumber()); + final OutputChannelMode outputChannelMode = stageOutputChannelModeMap.get(stageInputSpec.getStageNumber()); if (stagePartitions == null) { - throw new ISE("Stage [%d] not available", stageInputSpec.getStageNumber()); + throw new ISE("Stage[%d] output partitions not available", stageInputSpec.getStageNumber()); + } + + if (outputChannelMode == null) { + throw new ISE("Stage[%d] output mode not available", stageInputSpec.getStageNumber()); } // Decide how many workers to use, and assign inputs. 
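The slicer above now fails fast unless both the upstream stage's partitions and its output channel mode are known. A toy sketch of that precondition pattern, using plain JDK maps in place of Int2ObjectMap and ReadablePartitions (all names here are hypothetical):

```java
import java.util.Map;

// Toy model of StageInputSpecSlicer's lookups: slicing a stage input requires
// both the upstream stage's result partitions and its output channel mode.
public class StageLookupSketch
{
  enum OutputMode { MEMORY, LOCAL_STORAGE, DURABLE_STORAGE }

  static String slice(int stageNumber, Map<Integer, String> partitions, Map<Integer, OutputMode> modes)
  {
    final String stagePartitions = partitions.get(stageNumber);
    final OutputMode mode = modes.get(stageNumber);
    if (stagePartitions == null) {
      throw new IllegalStateException("Stage[" + stageNumber + "] output partitions not available");
    }
    if (mode == null) {
      throw new IllegalStateException("Stage[" + stageNumber + "] output mode not available");
    }
    return "slice of " + stagePartitions + " via " + mode;
  }

  public static void main(String[] args)
  {
    System.out.println(slice(0, Map.of(0, "3 partitions"), Map.of(0, OutputMode.MEMORY)));
  }
}
```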
@@ -63,7 +76,13 @@ public List sliceStatic(InputSpec inputSpec, int maxNumSlices) final List retVal = new ArrayList<>(); for (final ReadablePartitions partitions : workerPartitions) { - retVal.add(new StageInputSlice(stageInputSpec.getStageNumber(), partitions)); + retVal.add( + new StageInputSlice( + stageInputSpec.getStageNumber(), + partitions, + outputChannelMode + ) + ); } return retVal; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java index 7e93324ce68d..916dd3c1db38 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/table/TableInputSpecSlicer.java @@ -22,19 +22,31 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterators; +import org.apache.druid.client.ImmutableSegmentLoadInfo; +import org.apache.druid.client.coordinator.CoordinatorClient; +import org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; +import org.apache.druid.indexing.common.actions.TaskActionClient; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.exec.SegmentSource; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.UnknownFault; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.msq.input.InputSpec; import org.apache.druid.msq.input.InputSpecSlicer; import org.apache.druid.msq.input.NilInputSlice; import org.apache.druid.msq.input.SlicerUtils; -import org.apache.druid.msq.querykit.DataSegmentTimelineView; import org.apache.druid.query.filter.DimFilterUtils; import org.apache.druid.server.coordination.DruidServerMetadata; import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.SegmentTimeline; import org.apache.druid.timeline.TimelineLookup; +import org.apache.druid.timeline.VersionedIntervalTimeline; import org.joda.time.Interval; +import javax.annotation.Nullable; +import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -46,15 +58,25 @@ import java.util.stream.StreamSupport; /** - * Slices {@link TableInputSpec} into {@link SegmentsInputSlice}. + * Slices {@link TableInputSpec} into {@link SegmentsInputSlice} in tasks. 
*/ public class TableInputSpecSlicer implements InputSpecSlicer { - private final DataSegmentTimelineView timelineView; + private static final Logger log = new Logger(TableInputSpecSlicer.class); - public TableInputSpecSlicer(DataSegmentTimelineView timelineView) + private final CoordinatorClient coordinatorClient; + private final TaskActionClient taskActionClient; + private final SegmentSource includeSegmentSource; + + public TableInputSpecSlicer( + CoordinatorClient coordinatorClient, + TaskActionClient taskActionClient, + SegmentSource includeSegmentSource + ) { - this.timelineView = timelineView; + this.coordinatorClient = coordinatorClient; + this.taskActionClient = taskActionClient; + this.includeSegmentSource = includeSegmentSource; } @Override @@ -128,7 +150,7 @@ public List sliceDynamic( private Set getPrunedSegmentSet(final TableInputSpec tableInputSpec) { final TimelineLookup timeline = - timelineView.getTimeline(tableInputSpec.getDataSource(), tableInputSpec.getIntervals()).orElse(null); + getTimeline(tableInputSpec.getDataSource(), tableInputSpec.getIntervals()); if (timeline == null) { return Collections.emptySet(); @@ -159,6 +181,87 @@ private Set getPrunedSegmentSet(final TableInputSpec ta } } + @Nullable + private VersionedIntervalTimeline getTimeline( + final String dataSource, + final List intervals + ) + { + final boolean includeRealtime = SegmentSource.shouldQueryRealtimeServers(includeSegmentSource); + final Iterable realtimeAndHistoricalSegments; + + // Fetch the realtime segments and segments loaded on the historical. Do this first so that we don't miss any + // segment if they get handed off between the two calls. Segments loaded on historicals are deduplicated below, + // since we are only interested in realtime segments for now. + if (includeRealtime) { + realtimeAndHistoricalSegments = coordinatorClient.fetchServerViewSegments(dataSource, intervals); + } else { + realtimeAndHistoricalSegments = ImmutableList.of(); + } + + // Fetch all published, used segments (all non-realtime segments) from the metadata store. + // If the task is operating with a REPLACE lock, + // any segment created after the lock was acquired for its interval will not be considered. + final Collection publishedUsedSegments; + try { + // Additional check as the task action does not accept empty intervals + if (intervals.isEmpty()) { + publishedUsedSegments = Collections.emptySet(); + } else { + publishedUsedSegments = + taskActionClient.submit(new RetrieveUsedSegmentsAction(dataSource, intervals)); + } + } + catch (IOException e) { + throw new MSQException(e, UnknownFault.forException(e)); + } + + int realtimeCount = 0; + + // Deduplicate segments, giving preference to published used segments. + // We do this so that if any segments have been handed off in between the two metadata calls above, + // we directly fetch it from deep storage. + Set unifiedSegmentView = new HashSet<>(publishedUsedSegments); + + // Iterate over the realtime segments and segments loaded on the historical + for (ImmutableSegmentLoadInfo segmentLoadInfo : realtimeAndHistoricalSegments) { + Set servers = segmentLoadInfo.getServers(); + // Filter out only realtime servers. We don't want to query historicals for now, but we can in the future. + // This check can be modified then. 
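The deduplication comments above rely on set semantics: published segments are inserted first, so a segment handed off between the two metadata fetches keeps its published identity and is read from deep storage rather than a realtime server. A pure-JDK sketch of the idea with a toy segment type (identity by segment id; not the real DataSegment):

```java
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;

// Toy model of the unified segment view: published segments win over realtime
// copies of the same segment because they are added first and equality is by id.
public class UnifiedViewSketch
{
  static final class Seg
  {
    final String id;
    final boolean realtime;

    Seg(String id, boolean realtime)
    {
      this.id = id;
      this.realtime = realtime;
    }

    @Override
    public boolean equals(Object o)
    {
      return o instanceof Seg && ((Seg) o).id.equals(id);
    }

    @Override
    public int hashCode()
    {
      return Objects.hash(id);
    }

    @Override
    public String toString()
    {
      return id + (realtime ? "(realtime)" : "(published)");
    }
  }

  public static void main(String[] args)
  {
    final Set<Seg> unified = new LinkedHashSet<>();
    unified.add(new Seg("seg1", false)); // published
    unified.add(new Seg("seg1", true));  // handed-off duplicate: add() is a no-op
    unified.add(new Seg("seg2", true));  // still realtime-only: kept
    System.out.println(unified);         // [seg1(published), seg2(realtime)]
  }
}
```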
+ Set<DruidServerMetadata> realtimeServerMetadata + = servers.stream() + .filter(druidServerMetadata -> includeSegmentSource.getUsedServerTypes() + .contains(druidServerMetadata.getType()) + ) + .collect(Collectors.toSet()); + if (!realtimeServerMetadata.isEmpty()) { + realtimeCount += 1; + DataSegmentWithLocation dataSegmentWithLocation = new DataSegmentWithLocation( + segmentLoadInfo.getSegment(), + realtimeServerMetadata + ); + unifiedSegmentView.add(dataSegmentWithLocation); + } else { + // We don't have any segments of the required segment source, ignore the segment + } + } + + if (includeRealtime) { + log.info( + "Fetched total [%d] segments from coordinator: [%d] from metadata store, [%d] from server view", + unifiedSegmentView.size(), + publishedUsedSegments.size(), + realtimeCount + ); + } + + if (unifiedSegmentView.isEmpty()) { + return null; + } else { + return SegmentTimeline.forSegments(unifiedSegmentView); + } + } + private static List<InputSlice> makeSlices( final TableInputSpec tableInputSpec, final List<List<DataSegmentWithInterval>> assignments @@ -206,7 +309,8 @@ private static List createWeightedSegmentSet(List new HashSet<>()); serverVsSegmentsMap.get(druidServerMetadata).add(dataSegmentWithInterval); @@ -286,7 +390,8 @@ public DataServerRequestDescriptor toDataServerRequestDescriptor() { return new DataServerRequestDescriptor( serverMetadata, - segments.stream().map(DataSegmentWithInterval::toRichSegmentDescriptor).collect(Collectors.toList())); + segments.stream().map(DataSegmentWithInterval::toRichSegmentDescriptor).collect(Collectors.toList()) + ); } } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java index e773fcb87a97..8be2108a57a4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortMaxCountShuffleSpec.java @@ -119,7 +119,11 @@ public ClusterBy clusterBy() @Override public int partitionCount() { - throw new ISE("Number of partitions not known for [%s].", kind()); + if (maxPartitions == 1) { + return 1; + } else { + throw new ISE("Number of partitions not known for [%s] with maxPartitions[%d].", kind(), maxPartitions); + } } @JsonProperty("partitions") diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java index 4d608c4fbe43..7f0878da0393 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/GlobalSortShuffleSpec.java @@ -35,6 +35,12 @@ public interface GlobalSortShuffleSpec extends ShuffleSpec */ boolean mustGatherResultKeyStatistics(); + /** + * Whether the {@link ClusterByStatisticsCollector} for this stage collects keys in aggregating mode or + * non-aggregating mode. + */ + boolean doesAggregate(); + /** * Generates a set of partitions based on the provided statistics.
* diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java index fc453d76635b..fbc39fc672c3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/HashShuffleSpec.java @@ -59,16 +59,11 @@ public ClusterBy clusterBy() return clusterBy; } - @Override - public boolean doesAggregate() - { - return false; - } - @Override @JsonProperty("partitions") public int partitionCount() { return numPartitions; } + } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java index 6fbe16b6740f..b29d41c336ae 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/MixShuffleSpec.java @@ -53,12 +53,6 @@ public ClusterBy clusterBy() return ClusterBy.none(); } - @Override - public boolean doesAggregate() - { - return false; - } - @Override public int partitionCount() { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinition.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinition.java index 553e119131d5..64f27f2fddc0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinition.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/QueryDefinition.java @@ -96,14 +96,14 @@ static QueryDefinition create(@JsonProperty("stages") final List stageBuilders = new ArrayList<>(); /** - * Package-private: callers should use {@link QueryDefinition#builder()}. + * Package-private: callers should use {@link QueryDefinition#builder(String)}. */ - QueryDefinitionBuilder() + QueryDefinitionBuilder(final String queryId) { - } - - public QueryDefinitionBuilder queryId(final String queryId) - { - this.queryId = Preconditions.checkNotNull(queryId, "queryId"); - return this; + this.queryId = queryId; } public QueryDefinitionBuilder add(final StageDefinitionBuilder stageBuilder) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java index ac3bb99273e7..b1ae27e87166 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleKind.java @@ -19,19 +19,23 @@ package org.apache.druid.msq.kernel; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; +import org.apache.druid.java.util.common.IAE; + public enum ShuffleKind { /** * Put all data in a single partition, with no sorting and no statistics gathering. */ - MIX(false, false), + MIX("mix", false, false), /** * Partition using hash codes, with no sorting. * * This kind of shuffle supports pipelining: producer and consumer stages can run at the same time. */ - HASH(true, false), + HASH("hash", true, false), /** * Partition using hash codes, with each partition internally sorted. 
@@ -42,7 +46,7 @@ public enum ShuffleKind * Due to the need to sort outputs, this shuffle mechanism cannot be pipelined. Producer stages must finish before * consumer stages can run. */ - HASH_LOCAL_SORT(true, true), + HASH_LOCAL_SORT("hashLocalSort", true, true), /** * Partition using a distributed global sort. @@ -58,17 +62,31 @@ public enum ShuffleKind * Due to the need to sort outputs, this shuffle mechanism cannot be pipelined. Producer stages must finish before * consumer stages can run. */ - GLOBAL_SORT(false, true); + GLOBAL_SORT("globalSort", false, true); + private final String name; private final boolean hash; private final boolean sort; - ShuffleKind(boolean hash, boolean sort) + ShuffleKind(String name, boolean hash, boolean sort) { + this.name = name; this.hash = hash; this.sort = sort; } + @JsonCreator + public static ShuffleKind fromString(final String s) + { + for (final ShuffleKind kind : values()) { + if (kind.toString().equals(s)) { + return kind; + } + } + + throw new IAE("No such shuffleKind[%s]", s); + } + /** * Whether this shuffle does hash-partitioning. */ @@ -84,4 +102,11 @@ public boolean isSort() { return sort; } + + @Override + @JsonValue + public String toString() + { + return name; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java index 37f53fca199d..4b7971a7f783 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/ShuffleSpec.java @@ -59,17 +59,13 @@ public interface ShuffleSpec ClusterBy clusterBy(); /** - * Whether this stage aggregates by the {@link #clusterBy()} key. - */ - boolean doesAggregate(); - - /** - * Number of partitions, if known. + * Number of partitions, if known in advance. * * Partition count is always known if {@link #kind()} is {@link ShuffleKind#MIX}, {@link ShuffleKind#HASH}, or - * {@link ShuffleKind#HASH_LOCAL_SORT}. It is not known if {@link #kind()} is {@link ShuffleKind#GLOBAL_SORT}. + * {@link ShuffleKind#HASH_LOCAL_SORT}. For {@link ShuffleKind#GLOBAL_SORT}, it is known if we have a single + * output partition. 
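ShuffleKind now carries an explicit wire name, serialized via @JsonValue and parsed back through the @JsonCreator shown above. A self-contained sketch of the same round-trip pattern without the Jackson wiring (hypothetical enum, pure JDK):

```java
// Minimal version of the ShuffleKind naming pattern: an explicit wire name,
// toString() as the serialized form, and a reverse lookup that rejects unknowns.
public class ShuffleKindSketch
{
  enum Kind
  {
    MIX("mix"), HASH("hash"), HASH_LOCAL_SORT("hashLocalSort"), GLOBAL_SORT("globalSort");

    private final String name;

    Kind(String name)
    {
      this.name = name;
    }

    static Kind fromString(String s)
    {
      for (Kind kind : values()) {
        if (kind.toString().equals(s)) {
          return kind;
        }
      }
      throw new IllegalArgumentException("No such shuffleKind[" + s + "]");
    }

    @Override
    public String toString()
    {
      return name;
    }
  }

  public static void main(String[] args)
  {
    // Round-trips: constant -> wire name -> constant.
    System.out.println(Kind.fromString(Kind.HASH_LOCAL_SORT.toString())); // hashLocalSort
  }
}
```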
* - * @throws IllegalStateException if kind is {@link ShuffleKind#GLOBAL_SORT} + * @throws IllegalStateException if kind is {@link ShuffleKind#GLOBAL_SORT} with more than one target partition */ int partitionCount(); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java index 4e212949d5ee..80b912faa8da 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java @@ -345,7 +345,7 @@ public ClusterByStatisticsCollector createResultKeyStatisticsCollector(final int signature, maxRetainedBytes, Limits.MAX_PARTITION_BUCKETS, - shuffleSpec.doesAggregate(), + ((GlobalSortShuffleSpec) shuffleSpec).doesAggregate(), shuffleCheckHasMultipleValues ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageId.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageId.java index 35c8dc43665c..5b98eed0da95 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageId.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageId.java @@ -21,8 +21,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; -import com.google.common.base.Strings; import org.apache.druid.common.guava.GuavaUtils; +import org.apache.druid.common.utils.IdUtils; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; @@ -43,15 +43,11 @@ public class StageId implements Comparable public StageId(final String queryId, final int stageNumber) { - if (Strings.isNullOrEmpty(queryId)) { - throw new IAE("Null or empty queryId"); - } - if (stageNumber < 0) { throw new IAE("Invalid stageNumber [%s]", stageNumber); } - this.queryId = queryId; + this.queryId = IdUtils.validateId("queryId", queryId); this.stageNumber = stageNumber; } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java index b9a3024048b0..201a1783c05f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java @@ -23,6 +23,8 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.InputSlice; import javax.annotation.Nullable; @@ -46,6 +48,12 @@ public class WorkOrder private final List workerInputs; private final ExtraInfoHolder extraInfoHolder; + @Nullable + private final List workerIds; + + @Nullable + private final OutputChannelMode outputChannelMode; + @JsonCreator @SuppressWarnings("rawtypes") public WorkOrder( @@ -53,7 +61,9 @@ public WorkOrder( @JsonProperty("stage") final int stageNumber, @JsonProperty("worker") final int workerNumber, @JsonProperty("input") final List workerInputs, - @JsonProperty("extra") @Nullable final ExtraInfoHolder extraInfoHolder + @JsonProperty("extra") @Nullable final ExtraInfoHolder 
extraInfoHolder, + @JsonProperty("workers") @Nullable final List workerIds, + @JsonProperty("output") @Nullable final OutputChannelMode outputChannelMode ) { this.queryDefinition = Preconditions.checkNotNull(queryDefinition, "queryDefinition"); @@ -61,6 +71,8 @@ public WorkOrder( this.workerNumber = workerNumber; this.workerInputs = Preconditions.checkNotNull(workerInputs, "workerInputs"); this.extraInfoHolder = extraInfoHolder; + this.workerIds = workerIds; + this.outputChannelMode = outputChannelMode; } @JsonProperty("query") @@ -95,6 +107,31 @@ ExtraInfoHolder getExtraInfoHolder() return extraInfoHolder; } + /** + * Worker IDs for this query, if known in advance (at the time the work order is created). May be null, in which + * case workers use {@link ControllerClient#getTaskList()} to find worker IDs. + */ + @Nullable + @JsonProperty("workers") + @JsonInclude(JsonInclude.Include.NON_NULL) + public List getWorkerIds() + { + return workerIds; + } + + public boolean hasOutputChannelMode() + { + return outputChannelMode != null; + } + + @Nullable + @JsonProperty("output") + @JsonInclude(JsonInclude.Include.NON_NULL) + public OutputChannelMode getOutputChannelMode() + { + return outputChannelMode; + } + @Nullable public Object getExtraInfo() { @@ -106,6 +143,23 @@ public StageDefinition getStageDefinition() return queryDefinition.getStageDefinition(stageNumber); } + public WorkOrder withOutputChannelMode(final OutputChannelMode newOutputChannelMode) + { + if (newOutputChannelMode == outputChannelMode) { + return this; + } else { + return new WorkOrder( + queryDefinition, + stageNumber, + workerNumber, + workerInputs, + extraInfoHolder, + workerIds, + newOutputChannelMode + ); + } + } + @Override public boolean equals(Object o) { @@ -120,13 +174,23 @@ public boolean equals(Object o) && workerNumber == workOrder.workerNumber && Objects.equals(queryDefinition, workOrder.queryDefinition) && Objects.equals(workerInputs, workOrder.workerInputs) - && Objects.equals(extraInfoHolder, workOrder.extraInfoHolder); + && Objects.equals(extraInfoHolder, workOrder.extraInfoHolder) + && Objects.equals(workerIds, workOrder.workerIds) + && Objects.equals(outputChannelMode, workOrder.outputChannelMode); } @Override public int hashCode() { - return Objects.hash(queryDefinition, stageNumber, workerInputs, workerNumber, extraInfoHolder); + return Objects.hash( + queryDefinition, + stageNumber, + workerNumber, + workerInputs, + extraInfoHolder, + workerIds, + outputChannelMode + ); } @Override @@ -138,6 +202,8 @@ public String toString() ", workerNumber=" + workerNumber + ", workerInputs=" + workerInputs + ", extraInfoHolder=" + extraInfoHolder + + ", workerIds=" + workerIds + + ", outputChannelMode=" + outputChannelMode + '}'; } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java index 18f1f821d9a6..cdf4e2e20b0b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkerAssignmentStrategy.java @@ -33,12 +33,12 @@ import java.util.OptionalInt; /** - * Strategy for assigning input slices to tasks. Influences how {@link InputSpecSlicer} is used. + * Strategy for assigning input slices to workers. Influences how {@link InputSpecSlicer} is used. 
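The withOutputChannelMode method above is a standard copy-on-write "wither" for an immutable value object: it returns this when the field already matches, otherwise a copy with one field swapped. A toy illustration (hypothetical class, not the real WorkOrder):

```java
// Toy immutable value type showing the copy-on-write "wither" used by WorkOrder.
public class WitherSketch
{
  enum Mode { MEMORY, LOCAL_STORAGE }

  static final class Order
  {
    final int worker;
    final Mode mode;

    Order(int worker, Mode mode)
    {
      this.worker = worker;
      this.mode = mode;
    }

    Order withMode(Mode newMode)
    {
      // Avoid allocation when the field is already set to the requested value.
      return newMode == mode ? this : new Order(worker, newMode);
    }
  }

  public static void main(String[] args)
  {
    final Order order = new Order(0, Mode.MEMORY);
    System.out.println(order.withMode(Mode.MEMORY) == order);        // true: same instance
    System.out.println(order.withMode(Mode.LOCAL_STORAGE) == order); // false: copied
  }
}
```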
*/ public enum WorkerAssignmentStrategy { /** - * Use the highest possible number of tasks, while staying within {@link StageDefinition#getMaxWorkerCount()}. + * Use the highest possible number of workers, while staying within {@link StageDefinition#getMaxWorkerCount()}. * * Implemented using {@link InputSpecSlicer#sliceStatic}. */ @@ -57,7 +57,7 @@ public List assign( }, /** - * Use the lowest possible number of tasks, while keeping each task's workload under + * Use the lowest possible number of workers, while keeping each worker's workload under * {@link Limits#MAX_INPUT_FILES_PER_WORKER} files and {@code maxInputBytesPerWorker} bytes. * * Implemented using {@link InputSpecSlicer#sliceDynamic} whenever possible. diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java index c7805f04a9f3..05e0f722ccd4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernel.java @@ -20,7 +20,6 @@ package org.apache.druid.msq.kernel.controller; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -33,6 +32,7 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.exec.QueryValidator; import org.apache.druid.msq.indexing.error.CanceledFault; import org.apache.druid.msq.indexing.error.MSQException; @@ -41,7 +41,6 @@ import org.apache.druid.msq.indexing.error.UnknownFault; import org.apache.druid.msq.indexing.error.WorkerFailedFault; import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault; -import org.apache.druid.msq.input.InputSpecSlicer; import org.apache.druid.msq.input.InputSpecSlicerFactory; import org.apache.druid.msq.input.stage.ReadablePartitions; import org.apache.druid.msq.kernel.ExtraInfoHolder; @@ -55,14 +54,17 @@ import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; import javax.annotation.Nullable; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Optional; +import java.util.Queue; import java.util.Set; +import java.util.SortedSet; +import java.util.function.Consumer; import java.util.stream.Collectors; /** @@ -77,39 +79,49 @@ public class ControllerQueryKernel { private static final Logger log = new Logger(ControllerQueryKernel.class); + private final QueryDefinition queryDef; + private final ControllerQueryKernelConfig config; /** * Stage ID -> tracker for that stage. An extension of the state of this kernel. */ - private final Map stageTracker = new HashMap<>(); + private final Map stageTrackers = new HashMap<>(); /** - * Stage ID -> stages that flow *into* that stage. Computed by {@link #computeStageInflowMap}. + * Stage ID -> stages that flow *into* that stage. Computed by {@link ControllerQueryKernelUtils#computeStageInflowMap}. 
*/ private final ImmutableMap<StageId, Set<StageId>> inflowMap; /** - * Stage ID -> stages that *depend on* that stage. Computed by {@link #computeStageOutflowMap}. + * Stage ID -> stages that *depend on* that stage. Computed by {@link ControllerQueryKernelUtils#computeStageOutflowMap}. */ private final ImmutableMap<StageId, Set<StageId>> outflowMap; /** * Maintains a running map of (stageId -> pending inflow stages) which need to be completed to provision the stage * corresponding to the stageId. After initializing, if the value of the entry becomes an empty set, it is removed - * from the map, and the removed entry is added to {@link #readyToRunStages}. + * from the map, and the removed entry is added to {@link #stageGroupQueue}. */ - private final Map<StageId, Set<StageId>> pendingInflowMap; + private final Map<StageId, SortedSet<StageId>> pendingInflowMap; /** * Maintains a running count of (stageId -> outflow stages pending on its results). After initializing, if * the value of the entry becomes an empty set, it is removed from the map and the removed entry is added to * {@link #effectivelyFinishedStages}. */ - private final Map<StageId, Set<StageId>> pendingOutflowMap; + private final Map<StageId, SortedSet<StageId>> pendingOutflowMap; /** - * Tracks those stages which can be initialized safely. + * Stage groups, in the order that we will run them. Each group is a set of stages that internally uses + * {@link OutputChannelMode#MEMORY} for communication. (The final stage may use a different + * {@link OutputChannelMode}. In particular, if a stage group has a single stage, it may use any + * {@link OutputChannelMode}.) + */ + private final Queue<StageGroup> stageGroupQueue; + + /** + * Tracks those stages which are ready to begin executing. Populated by {@link #registerStagePhaseChange}. */ private final Set<StageId> readyToRunStages = new HashSet<>(); @@ -123,7 +135,12 @@ public class ControllerQueryKernel * Map<StageId, Map<WorkerNumber, WorkOrder>> * Stores the work order per worker per stage so that we can retrieve that in case of worker retry */ - private final Map<StageId, Int2ObjectMap<WorkOrder>> stageWorkOrders; + private final Map<StageId, Int2ObjectMap<WorkOrder>> stageWorkOrders = new HashMap<>(); + + /** + * Tracks the output channel mode for each stage. + */ + private final Map<StageId, OutputChannelMode> stageOutputChannelModes = new HashMap<>(); /** * {@link MSQFault#getErrorCode()} which are retried. @@ -133,27 +150,22 @@ UnknownFault.CODE, WorkerRpcFailedFault.CODE ); - private final int maxRetainedPartitionSketchBytes; - private final boolean faultToleranceEnabled; public ControllerQueryKernel( final QueryDefinition queryDef, - int maxRetainedPartitionSketchBytes, - boolean faultToleranceEnabled + final ControllerQueryKernelConfig config ) { this.queryDef = queryDef; - this.maxRetainedPartitionSketchBytes = maxRetainedPartitionSketchBytes; - this.faultToleranceEnabled = faultToleranceEnabled; - this.inflowMap = ImmutableMap.copyOf(computeStageInflowMap(queryDef)); - this.outflowMap = ImmutableMap.copyOf(computeStageOutflowMap(queryDef)); + this.config = config; + this.inflowMap = ImmutableMap.copyOf(ControllerQueryKernelUtils.computeStageInflowMap(queryDef)); + this.outflowMap = ImmutableMap.copyOf(ControllerQueryKernelUtils.computeStageOutflowMap(queryDef)); // pendingInflowMap and pendingOutflowMap are wholly separate from inflowMap, so we can edit the Sets.
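The inflow and outflow maps documented above are the two adjacency views of the stage DAG: which stages feed me, and which stages consume me. A pure-JDK sketch computing both from (producer, consumer) edges, with toy integer stage ids standing in for StageId:

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Toy version of computeStageInflowMap / computeStageOutflowMap: build both
// adjacency views of a stage DAG from (producer -> consumer) edges.
public class StageDagSketch
{
  static Map<Integer, Set<Integer>> inflow(List<int[]> edges, int stageCount)
  {
    final Map<Integer, Set<Integer>> map = new HashMap<>();
    for (int stage = 0; stage < stageCount; stage++) {
      map.put(stage, new HashSet<>());
    }
    for (int[] edge : edges) {
      map.get(edge[1]).add(edge[0]); // consumer <- producer
    }
    return map;
  }

  static Map<Integer, Set<Integer>> outflow(List<int[]> edges, int stageCount)
  {
    final Map<Integer, Set<Integer>> map = new HashMap<>();
    for (int stage = 0; stage < stageCount; stage++) {
      map.put(stage, new HashSet<>());
    }
    for (int[] edge : edges) {
      map.get(edge[0]).add(edge[1]); // producer -> consumer
    }
    return map;
  }

  public static void main(String[] args)
  {
    // Linear chain 0 -> 1 -> 2, as in the leapfrogging example.
    final List<int[]> edges = List.of(new int[]{0, 1}, new int[]{1, 2});
    System.out.println(inflow(edges, 3));  // {0=[], 1=[0], 2=[1]}
    System.out.println(outflow(edges, 3)); // {0=[1], 1=[2], 2=[]}
  }
}
```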
- this.pendingInflowMap = computeStageInflowMap(queryDef); - this.pendingOutflowMap = computeStageOutflowMap(queryDef); - - stageWorkOrders = new HashMap<>(); + this.pendingInflowMap = ControllerQueryKernelUtils.computeStageInflowMap(queryDef); + this.pendingOutflowMap = ControllerQueryKernelUtils.computeStageOutflowMap(queryDef); + this.stageGroupQueue = new ArrayDeque<>(ControllerQueryKernelUtils.computeStageGroups(queryDef, config)); initializeReadyToRunStages(); } @@ -166,31 +178,24 @@ public List createAndGetNewStageIds( final long maxInputBytesPerWorker ) { - final Int2IntMap stageWorkerCountMap = new Int2IntAVLTreeMap(); - final Int2ObjectMap stagePartitionsMap = new Int2ObjectAVLTreeMap<>(); - - for (final ControllerStageTracker stageKernel : stageTracker.values()) { - final int stageNumber = stageKernel.getStageDefinition().getStageNumber(); - stageWorkerCountMap.put(stageNumber, stageKernel.getWorkerInputs().workerCount()); - - if (stageKernel.hasResultPartitions()) { - stagePartitionsMap.put(stageNumber, stageKernel.getResultPartitions()); - } - } + createNewKernels( + slicerFactory, + assignmentStrategy, + maxInputBytesPerWorker + ); - createNewKernels(stageWorkerCountMap, slicerFactory.makeSlicer(stagePartitionsMap), assignmentStrategy, maxInputBytesPerWorker); - return stageTracker.values() - .stream() - .filter(controllerStageTracker -> controllerStageTracker.getPhase() == ControllerStagePhase.NEW) - .map(stageKernel -> stageKernel.getStageDefinition().getId()) - .collect(Collectors.toList()); + return stageTrackers.values() + .stream() + .filter(controllerStageTracker -> controllerStageTracker.getPhase() == ControllerStagePhase.NEW) + .map(stageTracker -> stageTracker.getStageDefinition().getId()) + .collect(Collectors.toList()); } /** * @return Stage kernels in this query kernel which can be safely cleaned up and marked as FINISHED. This returns the * kernel corresponding to a particular stage only once, to reduce the number of stages to iterate through. * It is expectant of the caller to eventually mark the stage as {@link ControllerStagePhase#FINISHED} after fetching - * the stage kernel + * the stage tracker */ public List getEffectivelyFinishedStageIds() { @@ -202,7 +207,23 @@ public List getEffectivelyFinishedStageIds() */ public List getActiveStages() { - return ImmutableList.copyOf(stageTracker.keySet()); + return ImmutableList.copyOf(stageTrackers.keySet()); + } + + /** + * Returns the number of stages that are active and in non-terminal phases. + */ + public int getNonTerminalActiveStageCount() + { + int n = 0; + + for (final ControllerStageTracker tracker : stageTrackers.values()) { + if (!tracker.getPhase().isTerminal() && tracker.getPhase() != ControllerStagePhase.RESULTS_READY) { + n++; + } + } + + return n; } /** @@ -219,10 +240,8 @@ public StageId getStageId(final int stageNumber) */ public boolean isDone() { - return Optional.ofNullable(stageTracker.get(queryDef.getFinalStageDefinition().getId())) - .filter(tracker -> ControllerStagePhase.isSuccessfulTerminalPhase(tracker.getPhase())) - .isPresent() - || stageTracker.values().stream().anyMatch(tracker -> tracker.getPhase() == ControllerStagePhase.FAILED); + return isSuccess() + || stageTrackers.values().stream().anyMatch(tracker -> tracker.getPhase() == ControllerStagePhase.FAILED); } /** @@ -237,7 +256,7 @@ public void markSuccessfulTerminalStagesAsFinished() // terminal phases" to FINISHED at the end, hence the if clause. 
Inside the conditional, depending on the // terminal phase it resides in, we synthetically mark it to completion (and therefore we need to check which // stage it is precisely in) - if (ControllerStagePhase.isSuccessfulTerminalPhase(phase)) { + if (phase.isSuccess()) { if (phase == ControllerStagePhase.RESULTS_READY) { finishStage(stageId, false); } @@ -246,14 +265,14 @@ public void markSuccessfulTerminalStagesAsFinished() } /** - * Returns true if all the stages comprising the query definition have been successful in producing their results + * Returns true if all the stages comprising the query definition have been successful in producing their results. */ public boolean isSuccess() { - return stageTracker.size() == queryDef.getStageDefinitions().size() - && stageTracker.values() - .stream() - .allMatch(tracker -> ControllerStagePhase.isSuccessfulTerminalPhase(tracker.getPhase())); + return stageTrackers.size() == queryDef.getStageDefinitions().size() + && stageTrackers.values() + .stream() + .allMatch(tracker -> tracker.getPhase() == ControllerStagePhase.FINISHED); } /** @@ -265,9 +284,10 @@ public Int2ObjectMap createWorkOrders( ) { final Int2ObjectMap workerToWorkOrder = new Int2ObjectAVLTreeMap<>(); - final ControllerStageTracker stageKernel = getStageKernelOrThrow(getStageId(stageNumber)); - + final ControllerStageTracker stageKernel = getStageTrackerOrThrow(getStageId(stageNumber)); final WorkerInputs workerInputs = stageKernel.getWorkerInputs(); + final OutputChannelMode outputChannelMode = stageOutputChannelModes.get(stageKernel.getStageDefinition().getId()); + for (int workerNumber : workerInputs.workers()) { final Object extraInfo = extraInfos != null ? extraInfos.get(workerNumber) : null; @@ -280,7 +300,9 @@ public Int2ObjectMap createWorkOrders( stageNumber, workerNumber, workerInputs.inputsForWorker(workerNumber), - extraInfoHolder + extraInfoHolder, + config.getWorkerIds(), + outputChannelMode ); QueryValidator.validateWorkOrder(workOrder); @@ -291,27 +313,80 @@ public Int2ObjectMap createWorkOrders( } private void createNewKernels( - final Int2IntMap stageWorkerCountMap, - final InputSpecSlicer slicer, + final InputSpecSlicerFactory slicerFactory, final WorkerAssignmentStrategy assignmentStrategy, final long maxInputBytesPerWorker ) { - for (final StageId nextStage : readyToRunStages) { - // Create a tracker. - final StageDefinition stageDef = queryDef.getStageDefinition(nextStage); - final ControllerStageTracker stageKernel = ControllerStageTracker.create( - stageDef, - stageWorkerCountMap, - slicer, - assignmentStrategy, - maxRetainedPartitionSketchBytes, - maxInputBytesPerWorker - ); - stageTracker.put(nextStage, stageKernel); + StageGroup stageGroup; + + while ((stageGroup = stageGroupQueue.peek()) != null) { + if (readyToRunStages.contains(stageGroup.first()) + && getNonTerminalActiveStageCount() + stageGroup.size() <= config.getMaxConcurrentStages()) { + // There is room to launch this stage group. + stageGroupQueue.poll(); + + for (final StageId stageId : stageGroup.stageIds()) { + // Create a tracker for this stage. + stageTrackers.put( + stageId, + createStageTracker( + stageId, + slicerFactory, + assignmentStrategy, + maxInputBytesPerWorker + ) + ); + + // Store output channel mode. 
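The loop above gates each stage group on two conditions: the group's first stage must be ready to run, and launching the entire group must keep the number of non-terminal active stages within maxConcurrentStages. A simplified pure-JDK model of that gating (toy types; capacity accounting reduced to a counter):

```java
import java.util.ArrayDeque;
import java.util.List;
import java.util.Queue;
import java.util.Set;

// Toy model of createNewKernels' gating: launch the next stage group only if
// its first stage is ready and the whole group fits under the concurrency cap.
public class StageGroupGateSketch
{
  static int active = 0;                 // non-terminal active stages
  static final int MAX_CONCURRENT = 2;

  static void launchReadyGroups(Queue<List<Integer>> groups, Set<Integer> ready)
  {
    List<Integer> group;
    while ((group = groups.peek()) != null) {
      if (ready.contains(group.get(0)) && active + group.size() <= MAX_CONCURRENT) {
        groups.poll();
        active += group.size();
        System.out.println("Launching group " + group);
      } else {
        break;                           // queue is ordered; stop at the first blocked group
      }
    }
  }

  public static void main(String[] args)
  {
    final Queue<List<Integer>> groups = new ArrayDeque<>(List.of(List.of(0, 1), List.of(2)));
    launchReadyGroups(groups, Set.of(0)); // launches [0, 1]; [2] stays queued
  }
}
```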
+ stageOutputChannelModes.put( + stageId, + stageGroup.stageOutputChannelMode(stageId) + ); + } + + stageGroup.stageIds().forEach(readyToRunStages::remove); + } else { + break; + } + } + } + + private ControllerStageTracker createStageTracker( + final StageId stageId, + final InputSpecSlicerFactory slicerFactory, + final WorkerAssignmentStrategy assignmentStrategy, + final long maxInputBytesPerWorker + ) + { + final Int2IntMap stageWorkerCountMap = new Int2IntAVLTreeMap(); + final Int2ObjectMap<ReadablePartitions> stagePartitionsMap = new Int2ObjectAVLTreeMap<>(); + final Int2ObjectMap<OutputChannelMode> stageOutputChannelModeMap = new Int2ObjectAVLTreeMap<>(); + + for (final ControllerStageTracker stageTracker : stageTrackers.values()) { + final int stageNumber = stageTracker.getStageDefinition().getStageNumber(); + stageWorkerCountMap.put(stageNumber, stageTracker.getWorkerInputs().workerCount()); + + if (stageTracker.hasResultPartitions()) { + stagePartitionsMap.put(stageNumber, stageTracker.getResultPartitions()); + } + + final OutputChannelMode outputChannelMode = + stageOutputChannelModes.get(stageTracker.getStageDefinition().getId()); + + if (outputChannelMode != null) { + stageOutputChannelModeMap.put(stageNumber, outputChannelMode); + } } - readyToRunStages.clear(); + return ControllerStageTracker.create( + getStageDefinition(stageId), + stageWorkerCountMap, + slicerFactory.makeSlicer(stagePartitionsMap, stageOutputChannelModeMap), + assignmentStrategy, + config.getMaxRetainedPartitionSketchBytes(), + maxInputBytesPerWorker + ); } /** @@ -320,33 +395,81 @@ private void createNewKernels( */ private void initializeReadyToRunStages() { - final Iterator<Map.Entry<StageId, Set<StageId>>> pendingInflowIterator = pendingInflowMap.entrySet().iterator(); + final List<StageId> readyStages = new ArrayList<>(); + final Iterator<Map.Entry<StageId, SortedSet<StageId>>> pendingInflowIterator = + pendingInflowMap.entrySet().iterator(); while (pendingInflowIterator.hasNext()) { - Map.Entry<StageId, Set<StageId>> stageToInflowStages = pendingInflowIterator.next(); - if (stageToInflowStages.getValue().size() == 0) { - readyToRunStages.add(stageToInflowStages.getKey()); + final Map.Entry<StageId, SortedSet<StageId>> stageToInflowStages = pendingInflowIterator.next(); + if (stageToInflowStages.getValue().isEmpty()) { + readyStages.add(stageToInflowStages.getKey()); pendingInflowIterator.remove(); } } - } + readyToRunStages.addAll(readyStages); + } /** - * Delegates call to {@link ControllerStageTracker#getStageDefinition()} + * Returns the definition of a given stage. + * + * @throws NullPointerException if there is no stage with the given ID */ public StageDefinition getStageDefinition(final StageId stageId) { - return getStageKernelOrThrow(stageId).getStageDefinition(); + return queryDef.getStageDefinition(stageId); + } + + /** + * Returns the {@link OutputChannelMode} for a given stage. + * + * @throws IllegalStateException if there is no stage with the given ID + */ + public OutputChannelMode getStageOutputChannelMode(final StageId stageId) + { + final OutputChannelMode outputChannelMode = stageOutputChannelModes.get(stageId); + if (outputChannelMode == null) { + throw new ISE("No such stage[%s]", stageId); + } + + return outputChannelMode; } + /** + * Whether query results are readable.
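The readability rule this javadoc introduces differs by output mode: a final stage writing to memory can be consumed while it is still running (which is what makes leapfrogging possible), while a buffered final stage must first reach RESULTS_READY. A toy sketch of the branch, with simplified phase and mode enums standing in for the real ones:

```java
// Toy version of canReadQueryResults: MEMORY output is readable while the
// final stage runs; buffered output only once results are fully ready.
public class ResultReadinessSketch
{
  enum Phase { NEW, READING_INPUT, POST_READING, RESULTS_READY, FINISHED }
  enum OutputMode { MEMORY, LOCAL_STORAGE, DURABLE_STORAGE }

  static boolean canReadResults(Phase finalStagePhase, OutputMode mode)
  {
    if (mode == OutputMode.MEMORY) {
      // Running phases: the consumer reads concurrently from in-memory channels.
      return finalStagePhase == Phase.READING_INPUT || finalStagePhase == Phase.POST_READING;
    } else {
      return finalStagePhase == Phase.RESULTS_READY;
    }
  }

  public static void main(String[] args)
  {
    System.out.println(canReadResults(Phase.READING_INPUT, OutputMode.MEMORY));        // true
    System.out.println(canReadResults(Phase.READING_INPUT, OutputMode.LOCAL_STORAGE)); // false
    System.out.println(canReadResults(Phase.RESULTS_READY, OutputMode.LOCAL_STORAGE)); // true
  }
}
```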
+ */ + public boolean canReadQueryResults() + { + final StageId finalStageId = queryDef.getFinalStageDefinition().getId(); + final ControllerStageTracker stageTracker = stageTrackers.get(finalStageId); + if (stageTracker == null) { + return false; + } else { + final OutputChannelMode outputChannelMode = stageOutputChannelModes.get(finalStageId); + if (outputChannelMode == OutputChannelMode.MEMORY) { + return stageTracker.getPhase().isRunning(); + } else { + return stageTracker.getPhase() == ControllerStagePhase.RESULTS_READY; + } + } + } + + // Following section contains the methods which delegate to appropriate stage kernel + /** * Delegates call to {@link ControllerStageTracker#getPhase()} */ public ControllerStagePhase getStagePhase(final StageId stageId) { - return getStageKernelOrThrow(stageId).getPhase(); + return getStageTrackerOrThrow(stageId).getPhase(); + } + + /** + * Returns whether a particular stage is finished. Stages can finish early if their outputs are no longer needed. + */ + public boolean isStageFinished(final StageId stageId) + { + return getStagePhase(stageId) == ControllerStagePhase.FINISHED; } /** @@ -354,7 +477,7 @@ public ControllerStagePhase getStagePhase(final StageId stageId) */ public boolean doesStageHaveResultPartitions(final StageId stageId) { - return getStageKernelOrThrow(stageId).hasResultPartitions(); + return getStageTrackerOrThrow(stageId).hasResultPartitions(); } /** @@ -362,7 +485,7 @@ public boolean doesStageHaveResultPartitions(final StageId stageId) */ public ReadablePartitions getResultPartitionsForStage(final StageId stageId) { - return getStageKernelOrThrow(stageId).getResultPartitions(); + return getStageTrackerOrThrow(stageId).getResultPartitions(); } /** @@ -370,7 +493,7 @@ public ReadablePartitions getResultPartitionsForStage(final StageId stageId) */ public IntSet getWorkersToSendPartitionBoundaries(final StageId stageId) { - return getStageKernelOrThrow(stageId).getWorkersToSendPartitionBoundaries(); + return getStageTrackerOrThrow(stageId).getWorkersToSendPartitionBoundaries(); } /** @@ -378,7 +501,7 @@ public IntSet getWorkersToSendPartitionBoundaries(final StageId stageId) */ public void workOrdersSentForWorker(final StageId stageId, int worker) { - getStageKernelOrThrow(stageId).workOrderSentForWorker(worker); + doWithStageTracker(stageId, stageTracker -> stageTracker.workOrderSentForWorker(worker)); } /** @@ -386,7 +509,7 @@ public void workOrdersSentForWorker(final StageId stageId, int worker) */ public void partitionBoundariesSentForWorker(final StageId stageId, int worker) { - getStageKernelOrThrow(stageId).partitionBoundariesSentForWorker(worker); + doWithStageTracker(stageId, stageTracker -> stageTracker.partitionBoundariesSentForWorker(worker)); } /** @@ -394,7 +517,7 @@ public void partitionBoundariesSentForWorker(final StageId stageId, int worker) */ public ClusterByPartitions getResultPartitionBoundariesForStage(final StageId stageId) { - return getStageKernelOrThrow(stageId).getResultPartitionBoundaries(); + return getStageTrackerOrThrow(stageId).getResultPartitionBoundaries(); } /** @@ -402,7 +525,7 @@ public ClusterByPartitions getResultPartitionBoundariesForStage(final StageId st */ public CompleteKeyStatisticsInformation getCompleteKeyStatisticsInformation(final StageId stageId) { - return getStageKernelOrThrow(stageId).getCompleteKeyStatisticsInformation(); + return getStageTrackerOrThrow(stageId).getCompleteKeyStatisticsInformation(); } /** @@ -410,7 +533,7 @@ public CompleteKeyStatisticsInformation 
getCompleteKeyStatisticsInformation(fina */ public boolean hasStageCollectorEncounteredAnyMultiValueField(final StageId stageId) { - return getStageKernelOrThrow(stageId).collectorEncounteredAnyMultiValueField(); + return getStageTrackerOrThrow(stageId).collectorEncounteredAnyMultiValueField(); } /** @@ -418,7 +541,7 @@ public boolean hasStageCollectorEncounteredAnyMultiValueField(final StageId stag */ public Object getResultObjectForStage(final StageId stageId) { - return getStageKernelOrThrow(stageId).getResultObject(); + return getStageTrackerOrThrow(stageId).getResultObject(); } /** @@ -427,15 +550,17 @@ public Object getResultObjectForStage(final StageId stageId) */ public void startStage(final StageId stageId) { - final ControllerStageTracker stageKernel = getStageKernelOrThrow(stageId); - if (stageKernel.getPhase() != ControllerStagePhase.NEW) { - throw new ISE("Cannot start the stage: [%s]", stageId); - } if (stageWorkOrders.get(stageId) == null) { - throw new ISE("Work orders not present for stage %s", stageId); + throw new ISE("Work order not present for stage[%s]", stageId); } - stageKernel.start(); - transitionStageKernel(stageId, ControllerStagePhase.READING_INPUT); + + doWithStageTracker(stageId, stageTracker -> { + if (stageTracker.getPhase() != ControllerStagePhase.NEW) { + throw new ISE("Cannot start the stage: [%s]", stageId); + } + + stageTracker.start(); + }); } /** @@ -450,9 +575,10 @@ public void finishStage(final StageId stageId, final boolean strict) if (strict && !effectivelyFinishedStages.contains(stageId)) { throw new IAE("Cannot mark the stage: [%s] finished", stageId); } - getStageKernelOrThrow(stageId).finish(); - effectivelyFinishedStages.remove(stageId); - transitionStageKernel(stageId, ControllerStagePhase.FINISHED); + doWithStageTracker(stageId, stageTracker -> { + stageTracker.finish(); + effectivelyFinishedStages.remove(stageId); + }); stageWorkOrders.remove(stageId); } @@ -461,7 +587,7 @@ public void finishStage(final StageId stageId, final boolean strict) */ public WorkerInputs getWorkerInputsForStage(final StageId stageId) { - return getStageKernelOrThrow(stageId).getWorkerInputs(); + return getStageTrackerOrThrow(stageId).getWorkerInputs(); } /** @@ -474,20 +600,17 @@ public void addPartialKeyStatisticsForStageAndWorker( final PartialKeyStatisticsInformation partialKeyStatisticsInformation ) { - ControllerStageTracker stageKernel = getStageKernelOrThrow(stageId); - ControllerStagePhase newPhase = stageKernel.addPartialKeyInformationForWorker( - workerNumber, - partialKeyStatisticsInformation - ); + doWithStageTracker(stageId, stageTracker -> + stageTracker.addPartialKeyInformationForWorker(workerNumber, partialKeyStatisticsInformation)); + } - // If the kernel phase has transitioned, we need to account for that. - switch (newPhase) { - case MERGING_STATISTICS: - case POST_READING: - case FAILED: - transitionStageKernel(stageId, newPhase); - break; - } + /** + * Delegates call to {@link ControllerStageTracker#addPartialKeyInformationForWorker(int, PartialKeyStatisticsInformation)}. 
+ * If calling this causes transition for the stage kernel, then this gets registered in this query kernel + */ + public void setDoneReadingInputForStageAndWorker(final StageId stageId, final int workerNumber) + { + doWithStageTracker(stageId, stageTracker -> stageTracker.setDoneReadingInputForWorker(workerNumber)); } /** @@ -500,9 +623,7 @@ public void setResultsCompleteForStageAndWorker( final Object resultObject ) { - if (getStageKernelOrThrow(stageId).setResultsCompleteForWorker(workerNumber, resultObject)) { - transitionStageKernel(stageId, ControllerStagePhase.RESULTS_READY); - } + doWithStageTracker(stageId, stageTracker -> stageTracker.setResultsCompleteForWorker(workerNumber, resultObject)); } /** @@ -510,13 +631,7 @@ public void setResultsCompleteForStageAndWorker( */ public MSQFault getFailureReasonForStage(final StageId stageId) { - return getStageKernelOrThrow(stageId).getFailureReason(); - } - - public void failStageForReason(final StageId stageId, MSQFault fault) - { - getStageKernelOrThrow(stageId).failForReason(fault); - transitionStageKernel(stageId, ControllerStagePhase.FAILED); + return getStageTrackerOrThrow(stageId).getFailureReason(); } /** @@ -524,20 +639,19 @@ public void failStageForReason(final StageId stageId, MSQFault fault) */ public void failStage(final StageId stageId) { - getStageKernelOrThrow(stageId).fail(); - transitionStageKernel(stageId, ControllerStagePhase.FAILED); + doWithStageTracker(stageId, ControllerStageTracker::fail); } /** * Fetches and returns the stage kernel corresponding to the provided stage id, else throws {@link IAE} */ - private ControllerStageTracker getStageKernelOrThrow(StageId stageId) + private ControllerStageTracker getStageTrackerOrThrow(StageId stageId) { - ControllerStageTracker stageKernel = stageTracker.get(stageId); - if (stageKernel == null) { + ControllerStageTracker stageTracker = stageTrackers.get(stageId); + if (stageTracker == null) { throw new IAE("Cannot find kernel corresponding to stage [%s] in query [%s]", stageId, queryDef.getQueryId()); } - return stageKernel; + return stageTracker; } private WorkOrder getWorkOrder(int workerNumber, StageId stageId) @@ -556,99 +670,99 @@ private WorkOrder getWorkOrder(int workerNumber, StageId stageId) } /** - * Whenever a stage kernel changes its phase, the change must be "registered" by calling this method with the stageId - * and the new phase + * Whether a given stage is ready to stream results to consumer stages upon transition to "newPhase". */ - public void transitionStageKernel(StageId stageId, ControllerStagePhase newPhase) + private boolean readyToReadResults(final StageId stageId, final ControllerStagePhase newPhase) { - Preconditions.checkArgument( - stageTracker.containsKey(stageId), - "Attempting to modify an unknown stageKernel" - ); + if (stageOutputChannelModes.get(stageId) == OutputChannelMode.MEMORY) { + if (getStageDefinition(stageId).doesSortDuringShuffle()) { + // Stages that sort during shuffle go through a READING_INPUT phase followed by a POST_READING phase + // (once all input is read). These stages start producing output once POST_READING starts. + return newPhase == ControllerStagePhase.POST_READING; + } else { + // Can read results immediately. 
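        // (Editorial aside, not part of the patch: the full decision table for this method, combining
        // the MEMORY branches above with the return just below and the non-MEMORY branch.
        //
        //   output mode of stage   | sorts during shuffle | results readable once newPhase is
        //   -----------------------+----------------------+---------------------------------------------
        //   MEMORY                 | yes                  | POST_READING: sort is a pipeline break, so
        //                          |                      | output exists only after all input is read
        //   MEMORY                 | no                   | NEW: rows stream out as soon as the stage
        //                          |                      | starts, so the consumer must already be up
        //   LOCAL_STORAGE or       | either               | RESULTS_READY: output is fully buffered and
        //   DURABLE_STORAGE_*      |                      | can be read at any later time.)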
+ return newPhase == ControllerStagePhase.NEW; + } + } else { + return newPhase == ControllerStagePhase.RESULTS_READY; + } + } + + private void doWithStageTracker(final StageId stageId, final Consumer fn) + { + final ControllerStageTracker stageTracker = getStageTrackerOrThrow(stageId); + final ControllerStagePhase phase = stageTracker.getPhase(); + fn.accept(stageTracker); + + if (phase != stageTracker.getPhase()) { + registerStagePhaseChange(stageId, stageTracker.getPhase()); + } + } - if (newPhase == ControllerStagePhase.RESULTS_READY) { - // Once the stage has produced its results, we remove it from all the stages depending on this stage (for its - // output). + /** + * Whenever a stage kernel changes its phase, the change must be "registered" by calling this method with the stageId + * and the new phase. + */ + private void registerStagePhaseChange(final StageId stageId, final ControllerStagePhase newPhase) + { + if (readyToReadResults(stageId, newPhase)) { + // Once results from a stage are readable, remove this stage from pendingInflowMap and potentially mark + // dependent stages as ready to run. for (StageId dependentStageId : outflowMap.get(stageId)) { if (!pendingInflowMap.containsKey(dependentStageId)) { continue; } pendingInflowMap.get(dependentStageId).remove(stageId); // Check the dependent stage. If it has no dependencies left, it can be marked as to be initialized - if (pendingInflowMap.get(dependentStageId).size() == 0) { + if (pendingInflowMap.get(dependentStageId).isEmpty()) { readyToRunStages.add(dependentStageId); pendingInflowMap.remove(dependentStageId); } } } - if (ControllerStagePhase.isPostReadingPhase(newPhase)) { - - // when fault tolerance is enabled, we cannot delete the input data eagerly as we need the input stage for retry until - // results for the current stage are ready. - if (faultToleranceEnabled && newPhase == ControllerStagePhase.POST_READING) { - return; - } - // Once the stage has consumed all the data/input from its dependent stages, we remove it from all the stages - // whose input it was dependent on + if (newPhase.isSuccess() || (!config.isFaultTolerant() && newPhase.isDoneReadingInput())) { + // Once a stage no longer needs its input, we consider marking input stages as finished. for (StageId inputStage : inflowMap.get(stageId)) { if (!pendingOutflowMap.containsKey(inputStage)) { continue; } pendingOutflowMap.get(inputStage).remove(stageId); - // If no more stage is dependent on the "inputStage's" results, it can be safely transitioned to FINISHED - if (pendingOutflowMap.get(inputStage).size() == 0) { - effectivelyFinishedStages.add(inputStage); + // If no more stage is dependent on the inputStage's results, it can be safely transitioned to FINISHED + if (pendingOutflowMap.get(inputStage).isEmpty()) { pendingOutflowMap.remove(inputStage); + + // Mark input stage as effectively finished, if it's ready to finish. + // This leads to a later transition to FINISHED. + if (ControllerStagePhase.FINISHED.canTransitionFrom(stageTrackers.get(inputStage).getPhase())) { + effectivelyFinishedStages.add(inputStage); + } } } } - } - @VisibleForTesting - ControllerStageTracker getControllerStageKernel(int stageNumber) - { - return stageTracker.get(new StageId(queryDef.getQueryId(), stageNumber)); - } - - /** - * Returns a mapping of stage -> stages that flow *into* that stage. 
- */ - private static Map> computeStageInflowMap(final QueryDefinition queryDefinition) - { - final Map> retVal = new HashMap<>(); + // Mark stage as effectively finished, if it has no dependencies waiting for it. + // This leads to a later transition to FINISHED. + final boolean hasDependentStages = + pendingOutflowMap.containsKey(stageId) && !pendingOutflowMap.get(stageId).isEmpty(); - for (final StageDefinition stageDef : queryDefinition.getStageDefinitions()) { - final StageId stageId = stageDef.getId(); - retVal.computeIfAbsent(stageId, ignored -> new HashSet<>()); + if (!hasDependentStages) { + final boolean isFinalStage = queryDef.getFinalStageDefinition().getId().equals(stageId); - for (final int inputStageNumber : queryDefinition.getStageDefinition(stageId).getInputStageNumbers()) { - final StageId inputStageId = new StageId(queryDefinition.getQueryId(), inputStageNumber); - retVal.computeIfAbsent(stageId, ignored -> new HashSet<>()).add(inputStageId); + if (isFinalStage && newPhase == ControllerStagePhase.RESULTS_READY) { + // Final stage must run to completion (RESULTS_READY). + effectivelyFinishedStages.add(stageId); + } else if (!isFinalStage && ControllerStagePhase.FINISHED.canTransitionFrom(newPhase)) { + // Other stages can exit early (e.g. if there is a LIMIT). + effectivelyFinishedStages.add(stageId); } } - - return retVal; } - /** - * Returns a mapping of stage -> stages that depend on that stage. - */ - private static Map> computeStageOutflowMap(final QueryDefinition queryDefinition) + @VisibleForTesting + ControllerStageTracker getControllerStageTracker(int stageNumber) { - final Map> retVal = new HashMap<>(); - - for (final StageDefinition stageDef : queryDefinition.getStageDefinitions()) { - final StageId stageId = stageDef.getId(); - retVal.computeIfAbsent(stageId, ignored -> new HashSet<>()); - - for (final int inputStageNumber : queryDefinition.getStageDefinition(stageId).getInputStageNumbers()) { - final StageId inputStageId = new StageId(queryDefinition.getQueryId(), inputStageNumber); - retVal.computeIfAbsent(inputStageId, ignored -> new HashSet<>()).add(stageId); - } - } - - return retVal; + return stageTrackers.get(new StageId(queryDef.getQueryId(), stageNumber)); } /** @@ -660,6 +774,7 @@ private static Map> computeStageOutflowMap(final QueryDefi * * @param workerNumber * @param msqFault + * * @return List of {@link WorkOrder} that needs to be retried. */ public List getWorkInCaseWorkerEligibleForRetryElseThrow(int workerNumber, MSQFault msqFault) @@ -691,23 +806,23 @@ public static boolean isRetriableFault(MSQFault msqFault) * If yes adds the workOrder for that stage to the return list and transitions the stage kernel to {@link ControllerStagePhase#RETRYING} * * @param worker + * * @return List of {@link WorkOrder} that needs to be retried. */ private List getWorkInCaseWorkerEligibleForRetry(int worker) { List trackedSet = new ArrayList<>(getActiveStages()); - trackedSet.removeAll(getEffectivelyFinishedStageIds()); + trackedSet.removeAll(effectivelyFinishedStages); List workOrders = new ArrayList<>(); for (StageId stageId : trackedSet) { - ControllerStageTracker controllerStageTracker = getStageKernelOrThrow(stageId); - if (ControllerStagePhase.RETRYING.canTransitionFrom(controllerStageTracker.getPhase()) - && controllerStageTracker.retryIfNeeded(worker)) { - workOrders.add(getWorkOrder(worker, stageId)); - // should be a no-op. 
- transitionStageKernel(stageId, ControllerStagePhase.RETRYING); - } + doWithStageTracker(stageId, stageTracker -> { + if (ControllerStagePhase.RETRYING.canTransitionFrom(stageTracker.getPhase()) + && stageTracker.retryIfNeeded(worker)) { + workOrders.add(getWorkOrder(worker, stageId)); + } + }); } return workOrders; } @@ -723,7 +838,7 @@ public Map> getStagesAndWorkersToFetchClusterStats() Map> stageToWorkers = new HashMap<>(); for (StageId stageId : trackedSet) { - ControllerStageTracker controllerStageTracker = getStageKernelOrThrow(stageId); + ControllerStageTracker controllerStageTracker = getStageTrackerOrThrow(stageId); if (controllerStageTracker.getStageDefinition().mustGatherResultKeyStatistics()) { stageToWorkers.put(stageId, controllerStageTracker.getWorkersToFetchClusterStatisticsFrom()); } @@ -737,11 +852,11 @@ public Map> getStagesAndWorkersToFetchClusterStats() */ public void startFetchingStatsFromWorker(StageId stageId, Set workers) { - ControllerStageTracker controllerStageTracker = getStageKernelOrThrow(stageId); - - for (int worker : workers) { - controllerStageTracker.startFetchingStatsFromWorker(worker); - } + doWithStageTracker(stageId, stageTracker -> { + for (int worker : workers) { + stageTracker.startFetchingStatsFromWorker(worker); + } + }); } /** @@ -753,10 +868,8 @@ public void mergeClusterByStatisticsCollectorForAllTimeChunks( ClusterByStatisticsSnapshot clusterByStatsSnapshot ) { - getStageKernelOrThrow(stageId).mergeClusterByStatisticsCollectorForAllTimeChunks( - workerNumber, - clusterByStatsSnapshot - ); + doWithStageTracker(stageId, stageTracker -> + stageTracker.mergeClusterByStatisticsCollectorForAllTimeChunks(workerNumber, clusterByStatsSnapshot)); } /** @@ -770,11 +883,8 @@ public void mergeClusterByStatisticsCollectorForTimeChunk( ClusterByStatisticsSnapshot clusterByStatsSnapshot ) { - getStageKernelOrThrow(stageId).mergeClusterByStatisticsCollectorForTimeChunk( - workerNumber, - timeChunk, - clusterByStatsSnapshot - ); + doWithStageTracker(stageId, stageTracker -> + stageTracker.mergeClusterByStatisticsCollectorForTimeChunk(workerNumber, timeChunk, clusterByStatsSnapshot)); } /** @@ -782,7 +892,7 @@ public void mergeClusterByStatisticsCollectorForTimeChunk( */ public boolean allPartialKeyInformationPresent(StageId stageId) { - return getStageKernelOrThrow(stageId).allPartialKeyInformationFetched(); + return getStageTrackerOrThrow(stageId).allPartialKeyInformationFetched(); } /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelConfig.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelConfig.java new file mode 100644 index 000000000000..5c754aedd4f4 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelConfig.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.kernel.controller;
+
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.msq.indexing.destination.MSQDestination;
+
+import javax.annotation.Nullable;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Configuration for {@link ControllerQueryKernel}.
+ */
+public class ControllerQueryKernelConfig
+{
+  private final int maxRetainedPartitionSketchBytes;
+  private final int maxConcurrentStages;
+  private final boolean pipeline;
+  private final boolean durableStorage;
+  private final boolean faultTolerance;
+  private final MSQDestination destination;
+
+  @Nullable
+  private final String controllerId;
+
+  @Nullable
+  private final List<String> workerIds;
+
+  private ControllerQueryKernelConfig(
+      int maxRetainedPartitionSketchBytes,
+      int maxConcurrentStages,
+      boolean pipeline,
+      boolean durableStorage,
+      boolean faultTolerance,
+      MSQDestination destination,
+      @Nullable String controllerId,
+      @Nullable List<String> workerIds
+  )
+  {
+    if (maxRetainedPartitionSketchBytes <= 0) {
+      throw new IAE("maxRetainedPartitionSketchBytes must be positive");
+    }
+
+    if (pipeline && maxConcurrentStages < 2) {
+      throw new IAE("maxConcurrentStages must be >= 2 when pipelining");
+    }
+
+    if (maxConcurrentStages <= 0) {
+      throw new IAE("maxConcurrentStages must be positive");
+    }
+
+    if (pipeline && faultTolerance) {
+      throw new IAE("Cannot pipeline with fault tolerance");
+    }
+
+    if (pipeline && durableStorage) {
+      throw new IAE("Cannot pipeline with durable storage");
+    }
+
+    if (faultTolerance && !durableStorage) {
+      throw new IAE("Cannot have fault tolerance without durable storage");
+    }
+
+    this.maxRetainedPartitionSketchBytes = maxRetainedPartitionSketchBytes;
+    this.maxConcurrentStages = maxConcurrentStages;
+    this.pipeline = pipeline;
+    this.durableStorage = durableStorage;
+    this.faultTolerance = faultTolerance;
+    this.destination = destination;
+    this.controllerId = controllerId;
+    this.workerIds = workerIds;
+  }
+
+  public static Builder builder()
+  {
+    return new Builder();
+  }
+
+  public int getMaxRetainedPartitionSketchBytes()
+  {
+    return maxRetainedPartitionSketchBytes;
+  }
+
+  public int getMaxConcurrentStages()
+  {
+    return maxConcurrentStages;
+  }
+
+  public boolean isPipeline()
+  {
+    return pipeline;
+  }
+
+  public boolean isDurableStorage()
+  {
+    return durableStorage;
+  }
+
+  public boolean isFaultTolerant()
+  {
+    return faultTolerance;
+  }
+
+  public MSQDestination getDestination()
+  {
+    return destination;
+  }
+
+  @Nullable
+  public List<String> getWorkerIds()
+  {
+    return workerIds;
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    ControllerQueryKernelConfig that = (ControllerQueryKernelConfig) o;
+    return maxRetainedPartitionSketchBytes == that.maxRetainedPartitionSketchBytes
+        && maxConcurrentStages == that.maxConcurrentStages
+        && pipeline == that.pipeline
+        && durableStorage == that.durableStorage
+        && faultTolerance == that.faultTolerance
+        && Objects.equals(controllerId, that.controllerId)
+        && Objects.equals(workerIds, that.workerIds);
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(
+        maxRetainedPartitionSketchBytes,
+        maxConcurrentStages,
+        pipeline,
+        durableStorage,
+        faultTolerance,
+        controllerId,
+        workerIds
+    );
+  }
+
+  @Override
+  public String toString()
+  {
+    return "ControllerQueryKernelConfig{" +
+           "maxRetainedPartitionSketchBytes=" + maxRetainedPartitionSketchBytes +
+           ", maxConcurrentStages=" + maxConcurrentStages +
+           ", pipeline=" + pipeline +
+           ", durableStorage=" + durableStorage +
+           ", faultTolerance=" + faultTolerance +
+           ", controllerId='" + controllerId + '\'' +
+           ", workerIds=" + workerIds +
+           '}';
+  }
+
+  public static class Builder
+  {
+    private int maxRetainedPartitionSketchBytes = -1;
+    private int maxConcurrentStages = 1;
+    private boolean pipeline;
+    private boolean durableStorage;
+    private boolean faultTolerant;
+    private MSQDestination destination;
+    private String controllerId;
+    private List<String> workerIds;
+
+    /**
+     * Use {@link #builder()}.
+     */
+    private Builder()
+    {
+    }
+
+    public Builder maxRetainedPartitionSketchBytes(final int maxRetainedPartitionSketchBytes)
+    {
+      this.maxRetainedPartitionSketchBytes = maxRetainedPartitionSketchBytes;
+      return this;
+    }
+
+    public Builder maxConcurrentStages(final int maxConcurrentStages)
+    {
+      this.maxConcurrentStages = maxConcurrentStages;
+      return this;
+    }
+
+    public Builder pipeline(final boolean pipeline)
+    {
+      this.pipeline = pipeline;
+      return this;
+    }
+
+    public Builder durableStorage(final boolean durableStorage)
+    {
+      this.durableStorage = durableStorage;
+      return this;
+    }
+
+    public Builder faultTolerance(final boolean faultTolerant)
+    {
+      this.faultTolerant = faultTolerant;
+      return this;
+    }
+
+    public Builder destination(final MSQDestination destination)
+    {
+      this.destination = destination;
+      return this;
+    }
+
+    public Builder controllerId(final String controllerId)
+    {
+      this.controllerId = controllerId;
+      return this;
+    }
+
+    public Builder workerIds(final List<String> workerIds)
+    {
+      this.workerIds = workerIds;
+      return this;
+    }
+
+    public ControllerQueryKernelConfig build()
+    {
+      return new ControllerQueryKernelConfig(
+          maxRetainedPartitionSketchBytes,
+          maxConcurrentStages,
+          pipeline,
+          durableStorage,
+          faultTolerant,
+          destination,
+          controllerId,
+          workerIds
+      );
+    }
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtils.java
new file mode 100644
index 000000000000..d971f33a9f2d
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtils.java
@@ -0,0 +1,406 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.kernel.controller;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import it.unimi.dsi.fastutil.ints.IntSet;
+import org.apache.druid.msq.exec.OutputChannelMode;
+import org.apache.druid.msq.indexing.destination.MSQDestination;
+import org.apache.druid.msq.indexing.destination.MSQSelectDestination;
+import org.apache.druid.msq.input.InputSpec;
+import org.apache.druid.msq.input.InputSpecs;
+import org.apache.druid.msq.kernel.QueryDefinition;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.kernel.StageId;
+
+import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
+
+/**
+ * Utilities for {@link ControllerQueryKernel}.
+ */
+public class ControllerQueryKernelUtils
+{
+  /**
+   * Put stages from {@link QueryDefinition} into groups that must each be launched simultaneously.
+   *
+   * This method's goal is to maximize the usage of {@link OutputChannelMode#MEMORY} channels, subject to constraints
+   * provided by {@link ControllerQueryKernelConfig#isPipeline()},
+   * {@link ControllerQueryKernelConfig#getMaxConcurrentStages()}, and
+   * {@link ControllerQueryKernelConfig#isFaultTolerant()}.
+   *
+   * An important part of the logic here is determining the output channel mode of the final stage in a group, i.e.
+   * {@link StageGroup#lastStageOutputChannelMode()}.
+   *
+   * If the {@link StageGroup#lastStageOutputChannelMode()} is not {@link OutputChannelMode#MEMORY}, then the stage
+   * group is fully executed, and fully generates its output, prior to any downstream stage groups starting.
+   *
+   * On the other hand, if {@link StageGroup#lastStageOutputChannelMode()} is {@link OutputChannelMode#MEMORY}, the
+   * stage group executes up to such a point that the group's last stage has results ready-to-read; see
+   * {@link ControllerQueryKernel#readyToReadResults(StageId, ControllerStagePhase)}. A downstream stage group, if
+   * any, is started while the current group is still running. This enables them to transfer data in memory.
+   *
+   * Stage groups always end when some stage in them sorts during shuffle, i.e. returns true from
+   * {@link StageDefinition#doesSortDuringShuffle()}. This enables "leapfrog" execution, where a sequence
+   * of sorting stages in separate groups can all run with {@link OutputChannelMode#MEMORY}, even when there are more
+   * stages than the maxConcurrentStages parameter. To achieve this, we wind down upstream stage groups prior to
+   * starting downstream stage groups, such that only two groups are ever running at a time.
+   *
+   * For example, consider a case where pipeline = true and maxConcurrentStages = 2, and the query has three stages,
+   * all of which sort during shuffle. The expected return from this method is a list of 3 stage groups, each with
+   * one stage, and each with {@link StageGroup#lastStageOutputChannelMode()} set to {@link OutputChannelMode#MEMORY}.
+   * To stay within maxConcurrentStages = 2, execution leapfrogs in the following manner (note: not all transitions
+   * are shown here, for brevity):
+   *
+   * <pre>
+   *   1. Stage 0 starts
+   *   2. Stage 0 enters {@link ControllerStagePhase#POST_READING}, finishes sorting
+   *   3. Stage 1 enters {@link ControllerStagePhase#READING_INPUT}
+   *   4. Stage 1 enters {@link ControllerStagePhase#POST_READING}, finishes sorting
+   *   5. Stage 0 stops, ends in {@link ControllerStagePhase#FINISHED}
+   *   6. Stage 2 starts
+   *   7. Stage 2 enters {@link ControllerStagePhase#POST_READING}, finishes sorting
+   *   8. Stage 1 stops, ends in {@link ControllerStagePhase#FINISHED}
+   *   9. Stage 2 stops and query completes
+   * </pre>
+ * + * When maxConcurrentStages = 2, leapfrogging is only possible with stage groups containing a single stage each. + * When maxConcurrentStages > 2, leapfrogging can happen with larger stage groups containing longer chains. + */ + public static List computeStageGroups( + final QueryDefinition queryDef, + final ControllerQueryKernelConfig config + ) + { + final MSQDestination destination = config.getDestination(); + final List stageGroups = new ArrayList<>(); + final boolean useDurableStorage = config.isDurableStorage(); + final Map> inflow = computeStageInflowMap(queryDef); + final Map> outflow = computeStageOutflowMap(queryDef); + final Set stagesRun = new HashSet<>(); + + // This loop simulates execution of all stages, such that we arrive at an order of execution that is compatible + // with all relevant constraints. + + while (stagesRun.size() < queryDef.getStageDefinitions().size()) { + // 1) Find all stages that are ready to run, and that cannot use MEMORY output modes. Run them as solo groups. + boolean didRun; + do { + didRun = false; + + for (final StageId stageId : ImmutableList.copyOf(inflow.keySet())) { + if (!stagesRun.contains(stageId) /* stage has not run yet */ + && inflow.get(stageId).isEmpty() /* stage data is fully available */ + && !canUseMemoryOutput(queryDef, stageId.getStageNumber(), config, outflow)) { + stagesRun.add(stageId); + stageGroups.add( + new StageGroup( + Collections.singletonList(stageId), + getOutputChannelMode( + queryDef, + stageId.getStageNumber(), + destination.toSelectDestination(), + useDurableStorage, + false + ) + ) + ); + + // Simulate this stage finishing. + removeStageFlow(stageId, inflow, outflow); + didRun = true; + } + } + } while (didRun); + + // 2) Pick some stage that can use MEMORY output mode, and run that as well as all ready-to-run dependents. + StageId currentStageId = null; + for (final StageId stageId : ImmutableList.copyOf(inflow.keySet())) { + if (!stagesRun.contains(stageId) + && inflow.get(stageId).isEmpty() + && canUseMemoryOutput(queryDef, stageId.getStageNumber(), config, outflow)) { + currentStageId = stageId; + break; + } + } + + if (currentStageId == null) { + // Didn't find a stage that can use MEMORY output mode. + continue; + } + + // Found a stage that can use MEMORY output mode. Build a maximally-sized StageGroup around it. + final List currentStageGroup = new ArrayList<>(); + + // maxStageGroupSize => largest size this stage group can be while respecting maxConcurrentStages and leaving + // room for a priorGroup to run concurrently (if priorGroup uses MEMORY output mode). + final int maxStageGroupSize; + + if (stageGroups.isEmpty()) { + maxStageGroupSize = config.getMaxConcurrentStages(); + } else { + final StageGroup priorGroup = stageGroups.get(stageGroups.size() - 1); + if (priorGroup.lastStageOutputChannelMode() == OutputChannelMode.MEMORY) { + // Prior group runs concurrently with this group. (Can happen when leapfrogging; see class-level Javadoc.) + + // Note: priorGroup.size() is strictly less than config.getMaxConcurrentStages(), because the prior group + // would have its size limited by maxStageGroupSizeAllowingForDownstreamConsumer below. + + maxStageGroupSize = config.getMaxConcurrentStages() - priorGroup.size(); + } else { + // Prior group exits before this group starts. 
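          // (Editorial aside, not part of the patch: worked numbers for the budget assigned below.
          // With config.getMaxConcurrentStages() == 3:
          //   - no prior group, or prior group ending in storage:  maxStageGroupSize = 3
          //   - prior group of size 1 ending in MEMORY (it must stay running while this group reads it):
          //     maxStageGroupSize = 3 - 1 = 2
          // Further down, maxStageGroupSizeAllowingForDownstreamConsumer reserves one more slot whenever
          // this group itself wants to end in MEMORY, since its own consumer must then run concurrently.)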
+          maxStageGroupSize = config.getMaxConcurrentStages();
+        }
+      }
+
+      OutputChannelMode currentOutputChannelMode = null;
+      while (currentStageId != null) {
+        final boolean canUseMemoryOutput =
+            canUseMemoryOutput(queryDef, currentStageId.getStageNumber(), config, outflow);
+        final Set<StageId> currentOutflow = outflow.get(currentStageId);
+
+        // maxStageGroupSizeAllowingForDownstreamConsumer => largest size this stage group can be while still being
+        // able to use MEMORY output mode. (With MEMORY output mode, the downstream consumer must run concurrently.)
+        final int maxStageGroupSizeAllowingForDownstreamConsumer;
+
+        if (queryDef.getStageDefinition(currentStageId).doesSortDuringShuffle()) {
+          // When the current group sorts, there's a pipeline break, so we can leapfrog: close the prior group before
+          // starting the downstream group. In this case, we only need to reserve a single concurrent-stage slot for
+          // a downstream consumer.
+
+          // Note: the only stage that can possibly sort is the final stage, because of the check below that closes
+          // off the stage group if currentStageId "doesSortDuringShuffle()".
+
+          maxStageGroupSizeAllowingForDownstreamConsumer = config.getMaxConcurrentStages() - 1;
+        } else {
+          // When the final stage in the current group doesn't sort, we can't leapfrog. We need to reserve a single
+          // concurrent-stage slot for a downstream consumer, plus enough space to keep the priorGroup running (which
+          // is accounted for in maxStageGroupSize).
+          maxStageGroupSizeAllowingForDownstreamConsumer = maxStageGroupSize - 1;
+        }
+
+        currentOutputChannelMode =
+            getOutputChannelMode(
+                queryDef,
+                currentStageId.getStageNumber(),
+                config.getDestination().toSelectDestination(),
+                config.isDurableStorage(),
+                canUseMemoryOutput
+
+                // Stages can only use MEMORY output mode if they have <= 1 downstream stage (checked by
+                // "canUseMemoryOutput") and if that downstream stage has all of its other inputs available.
+                && (currentOutflow.isEmpty() ||
+                    Collections.singleton(currentStageId)
+                               .equals(inflow.get(Iterables.getOnlyElement(currentOutflow))))
+
+                // And, stages can only use MEMORY output mode if their downstream consumer is able to start
+                // concurrently. So, once the stage group gets too big, we must stop using MEMORY output mode.
+                && (currentOutflow.isEmpty()
+                    || currentStageGroup.size() < maxStageGroupSizeAllowingForDownstreamConsumer)
+            );
+
+        currentStageGroup.add(currentStageId);
+
+        if (currentOutflow.size() == 1
+            && currentStageGroup.size() < maxStageGroupSize
+            && currentOutputChannelMode == OutputChannelMode.MEMORY
+
+            // Sorting causes a pipeline break: a sorting stage must read all its input before writing any output.
+            // Continue the stage group only if this stage does not sort its output.
+            && !queryDef.getStageDefinition(currentStageId).doesSortDuringShuffle()) {
+          currentStageId = Iterables.getOnlyElement(currentOutflow);
+        } else {
+          currentStageId = null;
+        }
+      }
+
+      stageGroups.add(new StageGroup(currentStageGroup, currentOutputChannelMode));
+
+      // Simulate this stage group finishing.
+      for (final StageId stageId : currentStageGroup) {
+        stagesRun.add(stageId);
+        removeStageFlow(stageId, inflow, outflow);
+      }
+    }
+
+    return stageGroups;
+  }
+
+  /**
+   * Returns a mapping of stage -> stages that flow *into* that stage. Uses TreeMaps and TreeSets so we have a
+   * consistent order for analyzing and running stages.
+ */ + public static Map> computeStageInflowMap(final QueryDefinition queryDefinition) + { + final Map> retVal = new TreeMap<>(); + + for (final StageDefinition stageDef : queryDefinition.getStageDefinitions()) { + final StageId stageId = stageDef.getId(); + retVal.computeIfAbsent(stageId, ignored -> new TreeSet<>()); + + for (final int inputStageNumber : queryDefinition.getStageDefinition(stageId).getInputStageNumbers()) { + final StageId inputStageId = new StageId(queryDefinition.getQueryId(), inputStageNumber); + retVal.computeIfAbsent(stageId, ignored -> new TreeSet<>()).add(inputStageId); + } + } + + return retVal; + } + + /** + * Returns a mapping of stage -> stages that depend on that stage. Uses TreeMaps and TreeSets so we have a consistent + * order for analyzing and running stages. + */ + public static Map> computeStageOutflowMap(final QueryDefinition queryDefinition) + { + final Map> retVal = new TreeMap<>(); + + for (final StageDefinition stageDef : queryDefinition.getStageDefinitions()) { + final StageId stageId = stageDef.getId(); + retVal.computeIfAbsent(stageId, ignored -> new TreeSet<>()); + + for (final int inputStageNumber : queryDefinition.getStageDefinition(stageId).getInputStageNumbers()) { + final StageId inputStageId = new StageId(queryDefinition.getQueryId(), inputStageNumber); + retVal.computeIfAbsent(inputStageId, ignored -> new TreeSet<>()).add(stageId); + } + } + + return retVal; + } + + /** + * Whether output of a stage can possibly use {@link OutputChannelMode#MEMORY}. Returning true does not guarantee + * that the stage *will* use {@link OutputChannelMode#MEMORY}. Additional requirements are checked in + * {@link #computeStageGroups(QueryDefinition, ControllerQueryKernelConfig)}. + */ + public static boolean canUseMemoryOutput( + final QueryDefinition queryDefinition, + final int stageNumber, + final ControllerQueryKernelConfig config, + final Map> outflowMap + ) + { + if (config.isFaultTolerant()) { + // Cannot use MEMORY output mode if fault tolerance is enabled: durable storage is required. + return false; + } + + if (!config.isPipeline() || config.getMaxConcurrentStages() < 2) { + // Cannot use MEMORY output mode if pipelining (& running multiple stages at once) is not enabled. + return false; + } + + final StageId stageId = queryDefinition.getStageDefinition(stageNumber).getId(); + final Set outflowStageIds = outflowMap.get(stageId); + + if (outflowStageIds.isEmpty()) { + return true; + } else if (outflowStageIds.size() == 1) { + final StageDefinition outflowStageDef = + queryDefinition.getStageDefinition(Iterables.getOnlyElement(outflowStageIds)); + + // Two things happening here: + // 1) Stages cannot use output mode MEMORY when broadcasting. This is because when using output mode MEMORY, we + // can only support a single reader. + // 2) Downstream stages can only have a single input stage with output mode MEMORY. This isn't strictly + // necessary, but it simplifies the logic around concurrently launching stages. + return stageId.equals(getOnlyNonBroadcastInputAsStageId(outflowStageDef)); + } else { + return false; + } + } + + /** + * Return an {@link OutputChannelMode} to use for a given stage, based on query and context. 
+ */ + public static OutputChannelMode getOutputChannelMode( + final QueryDefinition queryDef, + final int stageNumber, + @Nullable final MSQSelectDestination selectDestination, + final boolean durableStorage, + final boolean canStream + ) + { + final boolean isFinalStage = queryDef.getFinalStageDefinition().getStageNumber() == stageNumber; + + if (isFinalStage && selectDestination == MSQSelectDestination.DURABLESTORAGE) { + return OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS; + } else if (canStream) { + return OutputChannelMode.MEMORY; + } else if (durableStorage) { + return OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE; + } else { + return OutputChannelMode.LOCAL_STORAGE; + } + } + + /** + * If a stage has a single non-broadcast input stage, returns that input stage. Otherwise, returns null. + * This is a helper used by {@link #canUseMemoryOutput}. + */ + @Nullable + public static StageId getOnlyNonBroadcastInputAsStageId(final StageDefinition downstreamStageDef) + { + final List inputSpecs = downstreamStageDef.getInputSpecs(); + final IntSet broadcastInputNumbers = downstreamStageDef.getBroadcastInputNumbers(); + + if (inputSpecs.size() - broadcastInputNumbers.size() != 1) { + return null; + } + + for (int i = 0; i < inputSpecs.size(); i++) { + if (!broadcastInputNumbers.contains(i)) { + final IntSet stageNumbers = InputSpecs.getStageNumbers(Collections.singletonList(inputSpecs.get(i))); + if (stageNumbers.size() == 1) { + return new StageId(downstreamStageDef.getId().getQueryId(), stageNumbers.iterator().nextInt()); + } + } + } + + return null; + } + + /** + * Remove all outflow from "stageId". At the conclusion of this method, "outflow" has an empty set for "stageId", + * and no sets in "inflow" reference "stageId". Outflow and inflow sets may become empty as a result of this + * operation. Sets that become empty are left empty, not removed from the inflow and outflow maps. + */ + private static void removeStageFlow( + final StageId stageId, + final Map> inflow, + final Map> outflow + ) + { + for (final StageId outStageId : outflow.get(stageId)) { + inflow.get(outStageId).remove(stageId); + } + + outflow.get(stageId).clear(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStagePhase.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStagePhase.java index 3f8f3d19b3f8..eb124ab5b9f3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStagePhase.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStagePhase.java @@ -19,9 +19,12 @@ package org.apache.druid.msq.kernel.controller; -import com.google.common.collect.ImmutableSet; - -import java.util.Set; +import org.apache.druid.msq.exec.ClusterStatisticsMergeMode; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.kernel.ShuffleKind; +import org.apache.druid.msq.kernel.ShuffleSpec; +import org.apache.druid.msq.kernel.StageDefinition; +import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; /** * Phases that a stage can be in, as far as the controller is concerned. @@ -30,7 +33,12 @@ */ public enum ControllerStagePhase { - // Not doing anything yet. Just recently initialized. + /** + * Not doing anything yet. Just recently initialized. 
+ * + * When using {@link OutputChannelMode#MEMORY}, entering this phase tells us that it is time to launch the consumer + * stage (see {@link ControllerQueryKernel#readyToReadResults}). + */ NEW { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -39,7 +47,12 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) } }, - // Reading and mapping inputs (using "stateless" operators like filters, transforms which operate on individual records). + /** + * Reading inputs. + * + * Stages may transition directly from here to {@link #RESULTS_READY}, or they may go through + * {@link #MERGING_STATISTICS} and {@link #POST_READING}, depending on the {@link ShuffleSpec}. + */ READING_INPUT { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -48,12 +61,16 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) } }, - // Waiting to fetch key statistics in the background from the workers and incrementally generate partitions. - // This phase is only transitioned to once all partialKeyInformation are received from workers. - // Transitioning to this phase should also enqueue the task to fetch key statistics if `SEQUENTIAL` strategy is used. - // In `PARALLEL` strategy, we start fetching the key statistics as soon as they are available on the worker. - // This stage is not required in non-pre shuffle contexts - + /** + * Waiting to fetch key statistics in the background from the workers and incrementally generate partitions. + * + * This phase is only transitioned to once all {@link PartialKeyStatisticsInformation} are received from workers. + * Transitioning to this phase should also enqueue the task to fetch key statistics if + * {@link ClusterStatisticsMergeMode#SEQUENTIAL} strategy is used. In {@link ClusterStatisticsMergeMode#PARALLEL} + * strategy, we start fetching the key statistics as soon as they are available on the worker. + * + * This stage is used if, and only if, {@link StageDefinition#mustGatherResultKeyStatistics()}. + */ MERGING_STATISTICS { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -62,18 +79,29 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) } }, - // Post the inputs have been read and mapped to frames, in the `POST_READING` stage, we pre-shuffle and determining the partition boundaries. - // This step for a stage spits out the statistics of the data as a whole (and not just the individual records). This - // phase is not required in non-pre shuffle contexts. + /** + * Inputs have been completely read, and sorting is in progress. + * + * When using {@link OutputChannelMode#MEMORY} with {@link StageDefinition#doesSortDuringShuffle()}, entering this + * phase tells us that it is time to launch the consumer stage (see {@link ControllerQueryKernel#readyToReadResults}). + * + * This phase is only used when {@link ShuffleKind#isSort()}. Note that it may not *always* be used even when sorting; + * for example, when not using {@link OutputChannelMode#MEMORY} and also not gathering statistics + * ({@link StageDefinition#mustGatherResultKeyStatistics()}), a stage phase may transition directly from + * {@link #READING_INPUT} to {@link #RESULTS_READY}. 
+ */ POST_READING { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) { - return priorPhase == MERGING_STATISTICS; + return priorPhase == READING_INPUT /* when sorting locally */ + || priorPhase == MERGING_STATISTICS /* when sorting globally */; } }, - // Done doing work and all results have been generated. + /** + * Done doing work, and all results have been generated. + */ RESULTS_READY { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -82,17 +110,25 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) } }, - // The worker outputs for this stage might have been cleaned up in the workers, and they cannot be used by - // any other phase. "Metadata" for the stage such as counters are still available however + /** + * Stage has completed successfully and has been cleaned up. Worker outputs for this stage are no longer + * available and cannot be used by any other stage. Metadata such as counters are still available. + * + * Any non-terminal phase can transition to FINISHED. This can even happen prior to RESULTS_READY, if the + * controller determines that the outputs of the stage are no longer needed. For example, this happens when + * a downstream consumer is reading with limit, and decides it's finished processing. + */ FINISHED { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) { - return priorPhase == RESULTS_READY; + return !priorPhase.isTerminal(); } }, - // Something went wrong. + /** + * Something went wrong. + */ FAILED { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -101,9 +137,11 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) } }, - // Stages whose workers are currently under relaunch. We can transition out of Retrying state only when all the work orders - // of this stage have been sent. - // We can transition into Retrying phase when the prior phase did not publish its final results yet. + /** + * Stages whose workers are currently under relaunch. We can transition out of this phase only when all the work + * orders of this stage have been sent. We can transition into this phase when the prior phase did not + * publish its final results yet. + */ RETRYING { @Override public boolean canTransitionFrom(final ControllerStagePhase priorPhase) @@ -117,30 +155,40 @@ public boolean canTransitionFrom(final ControllerStagePhase priorPhase) public abstract boolean canTransitionFrom(ControllerStagePhase priorPhase); - private static final Set TERMINAL_PHASES = ImmutableSet.of( - RESULTS_READY, - FINISHED - ); + /** + * Whether this phase indicates that the stage has been started and is still running. (It hasn't been cleaned up + * or failed yet.) + */ + public boolean isRunning() + { + return this == READING_INPUT + || this == MERGING_STATISTICS + || this == POST_READING + || this == RESULTS_READY + || this == RETRYING; + } /** - * @return true if the phase indicates that the stage has completed its work and produced results successfully + * Whether this phase indicates that the stage has consumed its inputs from the previous stages successfully. 
*/ - public static boolean isSuccessfulTerminalPhase(final ControllerStagePhase phase) + public boolean isDoneReadingInput() { - return TERMINAL_PHASES.contains(phase); + return this == POST_READING || this == RESULTS_READY || this == FINISHED; } - private static final Set POST_READING_PHASES = ImmutableSet.of( - POST_READING, - RESULTS_READY, - FINISHED - ); + /** + * Whether this phase indicates that the stage has completed its work and produced results successfully. + */ + public boolean isSuccess() + { + return this == RESULTS_READY || this == FINISHED; + } /** - * @return true if the phase indicates that the stage has consumed its inputs from the previous stages successfully + * Whether this phase indicates that the stage is no longer running. */ - public static boolean isPostReadingPhase(final ControllerStagePhase phase) + public boolean isTerminal() { - return POST_READING_PHASES.contains(phase); + return this == FINISHED || this == FAILED; } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java index e0190bfacb34..0a62ba24b639 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java @@ -26,11 +26,13 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.key.ClusterByPartition; import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.java.util.common.Either; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.exec.ClusterStatisticsMergeMode; @@ -171,7 +173,14 @@ static ControllerStageTracker create( final long maxInputBytesPerWorker ) { - final WorkerInputs workerInputs = WorkerInputs.create(stageDef, stageWorkerCountMap, slicer, assignmentStrategy, maxInputBytesPerWorker); + final WorkerInputs workerInputs = WorkerInputs.create( + stageDef, + stageWorkerCountMap, + slicer, + assignmentStrategy, + maxInputBytesPerWorker + ); + return new ControllerStageTracker( stageDef, workerInputs, @@ -331,12 +340,15 @@ boolean collectorEncounteredAnyMultiValueField() */ Object getResultObject() { - if (phase == ControllerStagePhase.FINISHED) { - throw new ISE("Result object has been cleaned up prematurely"); - } else if (phase != ControllerStagePhase.RESULTS_READY) { - throw new ISE("Result object is not ready yet"); + if (!phase.isSuccess()) { + throw new ISE("Result object for stage[%s] is not ready yet", stageDef.getId()); } else if (resultObject == null) { - throw new NullPointerException("resultObject was unexpectedly null"); + throw new NullPointerException( + StringUtils.format( + "Result object for stage[%s] was unexpectedly null", + stageDef.getId() + ) + ); } else { return resultObject; } @@ -382,7 +394,7 @@ public CompleteKeyStatisticsInformation getCompleteKeyStatisticsInformation() * @param workerNumber the worker * @param partialKeyStatisticsInformation partial key 
statistics */ - ControllerStagePhase addPartialKeyInformationForWorker( + void addPartialKeyInformationForWorker( final int workerNumber, final PartialKeyStatisticsInformation partialKeyStatisticsInformation ) @@ -412,7 +424,7 @@ ControllerStagePhase addPartialKeyInformationForWorker( if (partialKeyStatisticsInformation.getTimeSegments().contains(null)) { // Time should not contain null value failForReason(InsertTimeNullFault.instance()); - return getPhase(); + return; } completeKeyStatisticsInformation.mergePartialInformation(workerNumber, partialKeyStatisticsInformation); } @@ -470,7 +482,6 @@ ControllerStagePhase addPartialKeyInformationForWorker( fail(); throw e; } - return getPhase(); } private void initializeTimeChunkWorkerTrackers() @@ -502,7 +513,6 @@ private void initializeTimeChunkWorkerTrackers() *

* If all the stats from all the workers are merged, we transition the stage to {@link ControllerStagePhase#POST_READING} */ - void mergeClusterByStatisticsCollectorForTimeChunk( int workerNumber, Long timeChunk, @@ -762,6 +772,58 @@ void setClusterByPartitionBoundaries(ClusterByPartitions clusterByPartitions) transitionTo(ControllerStagePhase.POST_READING); } + /** + * Transitions phase directly from {@link ControllerStagePhase#READING_INPUT} to + * {@link ControllerStagePhase#POST_READING}, skipping {@link ControllerStagePhase#MERGING_STATISTICS}. + * This method is used for stages that sort but do not need to gather result key statistics. + */ + void setDoneReadingInputForWorker(final int workerNumber) + { + if (stageDef.mustGatherResultKeyStatistics()) { + throw DruidException.defensive( + "Cannot setDoneReadingInput for stage[%s], it should send partial key information instead", + stageDef.getId() + ); + } + + if (!stageDef.doesSortDuringShuffle()) { + throw DruidException.defensive("Cannot setDoneReadingInput for stage[%s], it is not sorting", stageDef.getId()); + } + + if (workerNumber < 0 || workerNumber >= workerCount) { + throw new IAE("Invalid workerNumber[%s] for stage[%s]", workerNumber, stageDef.getId()); + } + + ControllerWorkerStagePhase currentPhase = workerToPhase.get(workerNumber); + + if (currentPhase == null) { + throw new ISE("Worker[%d] not found for stage[%s]", workerNumber, stageDef.getId()); + } + + try { + if (ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT.canTransitionFrom(currentPhase)) { + workerToPhase.put(workerNumber, ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT); + + if (allWorkersDoneReadingInput()) { + transitionTo(ControllerStagePhase.POST_READING); + } + } else { + throw new ISE( + "Worker[%d] for stage[%d] expected to be in phase that can transition to[%s]. Found phase[%s]", + workerNumber, + stageDef.getStageNumber(), + ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT, + currentPhase + ); + } + } + catch (Exception e) { + // If this op fails, we're in an inconsistent state and must cancel the stage. + fail(); + throw e; + } + } + /** * Accepts and sets the results that each worker produces for this particular stage * @@ -937,6 +999,21 @@ public boolean allPartialKeyInformationFetched() == workerCount; } + /** + * True if all workers are done reading their inputs. + */ + public boolean allWorkersDoneReadingInput() + { + for (final ControllerWorkerStagePhase phase : workerToPhase.values()) { + if (phase != ControllerWorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT + && phase != ControllerWorkerStagePhase.RESULTS_READY) { + return false; + } + } + + return true; + } + /** * True if all {@link org.apache.druid.msq.kernel.WorkOrder} are sent else false. 
*/ @@ -973,7 +1050,7 @@ private void transitionTo(final ControllerStagePhase newPhase) if (newPhase.canTransitionFrom(phase)) { phase = newPhase; } else { - throw new IAE("Cannot transition from [%s] to [%s]", phase, newPhase); + throw new IAE("Cannot transition stage[%s] from[%s] to[%s]", stageDef.getId(), phase, newPhase); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerWorkerStagePhase.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerWorkerStagePhase.java index 1c3e370dc80e..89eca8f83755 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerWorkerStagePhase.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerWorkerStagePhase.java @@ -69,7 +69,8 @@ public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase) @Override public boolean canTransitionFrom(final ControllerWorkerStagePhase priorPhase) { - return priorPhase == PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES; + return priorPhase == READING_INPUT /* when sorting locally */ + || priorPhase == PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES /* when sorting globally */; } }, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/StageGroup.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/StageGroup.java new file mode 100644 index 000000000000..f58eb3ee9c39 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/StageGroup.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.kernel.controller; + +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.kernel.StageId; + +import java.util.List; +import java.util.Objects; + +/** + * Group of stages that must be launched as a unit. Within each group, stages communicate with each other using + * {@link OutputChannelMode#MEMORY} channels. The final stage in a group writes its own output using + * {@link #lastStageOutputChannelMode()}. + * + * Stages in a group have linear (non-branching) data flow: the first stage is an input to the second stage, the second + * stage is an input to the third stage, and so on. This is done to simplify the logic. In the future, it is possible + * that stage groups may contain branching data flow. 
+ */ +public class StageGroup +{ + private final List<StageId> stageIds; + private final OutputChannelMode groupOutputChannelMode; + + public StageGroup(final List<StageId> stageIds, final OutputChannelMode groupOutputChannelMode) + { + this.stageIds = stageIds; + this.groupOutputChannelMode = groupOutputChannelMode; + } + + /** + * List of stage IDs in this group. + * + * The first stage is an input to the second stage, the second stage is an input to the third stage, and so on. + * See class-level javadocs for more details. + */ + public List<StageId> stageIds() + { + return stageIds; + } + + /** + * Output mode of the final stage in this group. + */ + public OutputChannelMode lastStageOutputChannelMode() + { + return stageOutputChannelMode(last()); + } + + /** + * Output mode of the given stage. + */ + public OutputChannelMode stageOutputChannelMode(final StageId stageId) + { + if (last().equals(stageId)) { + return groupOutputChannelMode; + } else if (stageIds.contains(stageId)) { + return OutputChannelMode.MEMORY; + } else { + throw new IAE("Stage[%s] not in group", stageId); + } + } + + /** + * First stage in this group. + */ + public StageId first() + { + return stageIds.get(0); + } + + /** + * Last stage in this group. + */ + public StageId last() + { + return stageIds.get(stageIds.size() - 1); + } + + /** + * Number of stages in this group. + */ + public int size() + { + return stageIds.size(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + StageGroup that = (StageGroup) o; + return Objects.equals(stageIds, that.stageIds) && groupOutputChannelMode == that.groupOutputChannelMode; + } + + @Override + public int hashCode() + { + return Objects.hash(stageIds, groupOutputChannelMode); + } + + @Override + public String toString() + { + return "StageGroup{" + + "stageIds=" + stageIds + + ", groupOutputChannelMode=" + groupOutputChannelMode + + '}'; + } +}
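[Editor's note, not part of the patch: a quick illustration of how StageGroup answers output-mode questions. A minimal sketch; the query ID, stage numbers, and the LOCAL_STORAGE constant are assumptions for the example only, and imports are elided.]

    // Hypothetical linear chain 0 -> 1 -> 2 grouped together: interior edges are in-memory,
    // and only the final stage writes the group's own output mode.
    final StageGroup group = new StageGroup(
        ImmutableList.of(new StageId("q1", 0), new StageId("q1", 1), new StageId("q1", 2)),
        OutputChannelMode.LOCAL_STORAGE // assumed enum constant for the example
    );
    group.stageOutputChannelMode(new StageId("q1", 0)); // MEMORY: not the last stage in the group
    group.stageOutputChannelMode(new StageId("q1", 1)); // MEMORY: not the last stage in the group
    group.lastStageOutputChannelMode();                 // LOCAL_STORAGE: the group's output

Because interior edges are MEMORY channels, which are not fully buffered, every stage in a group must run concurrently; that is why a group is launched as a unit.
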
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java index ed0807475ef6..09c48aa942c9 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java @@ -125,8 +125,7 @@ public ProcessorsAndChannels makeProcessors( final OutputChannel outputChannel = outputChannelFactory.openChannel(0 /* Partition number doesn't matter */); outputChannels.add(outputChannel); channelQueue.add(outputChannel.getWritableChannel()); - frameWriterFactoryQueue.add(stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator()) - ); + frameWriterFactoryQueue.add(stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator())); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index c532dcee56e8..831c9b139d3d 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -555,7 +555,7 @@ private static DataSourcePlan forUnion( // This is done to prevent loss of generality since MSQ can plan any type of DataSource. List<DataSource> children = unionDataSource.getDataSources(); - final QueryDefinitionBuilder subqueryDefBuilder = QueryDefinition.builder(); + final QueryDefinitionBuilder subqueryDefBuilder = QueryDefinition.builder(queryId); final List<DataSource> newChildren = new ArrayList<>(); final List<InputSpec> inputSpecs = new ArrayList<>(); final IntSet broadcastInputs = new IntOpenHashSet(); @@ -605,7 +605,7 @@ private static DataSourcePlan forBroadcastHashJoin( final boolean broadcast ) { - final QueryDefinitionBuilder subQueryDefBuilder = QueryDefinition.builder(); + final QueryDefinitionBuilder subQueryDefBuilder = QueryDefinition.builder(queryId); final DataSourceAnalysis analysis = dataSource.getAnalysis(); final DataSourcePlan basePlan = forDataSource( @@ -683,7 +683,7 @@ private static DataSourcePlan forSortMergeJoin( SortMergeJoinFrameProcessorFactory.validateCondition(dataSource.getConditionAnalysis()) ); - final QueryDefinitionBuilder subQueryDefBuilder = QueryDefinition.builder(); + final QueryDefinitionBuilder subQueryDefBuilder = QueryDefinition.builder(queryId); // Plan the left input. // We're confident that we can cast dataSource.getLeft() to QueryDataSource, because DruidJoinQueryRel creates diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java index b9c0f1a0d262..85c0c14e16e9 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactory.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.querykit; import org.apache.druid.frame.key.ClusterBy; +import org.apache.druid.msq.kernel.GlobalSortShuffleSpec; import org.apache.druid.msq.kernel.ShuffleSpec; /** @@ -29,7 +30,7 @@ public interface ShuffleSpecFactory { /** * Build a {@link ShuffleSpec} for given {@link ClusterBy}. The {@code aggregate} flag is used to populate - * {@link ShuffleSpec#doesAggregate()}. + * {@link GlobalSortShuffleSpec#doesAggregate()}.
*/ ShuffleSpec build(ClusterBy clusterBy, boolean aggregate); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java index 0bbe8eb91aed..d08d78ef791f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java @@ -76,7 +76,7 @@ public QueryDefinition makeQueryDefinition( ShuffleSpec nextShuffleSpec = findShuffleSpecForNextWindow(operatorList.get(0), maxWorkerCount); // add this shuffle spec to the last stage of the inner query - final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder().queryId(queryId); + final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId); if (nextShuffleSpec != null) { final ClusterBy windowClusterBy = nextShuffleSpec.clusterBy(); originalQuery = (WindowOperatorQuery) originalQuery.withOverriddenContext(ImmutableMap.of( @@ -178,7 +178,7 @@ public QueryDefinition makeQueryDefinition( ); } } - return queryDefBuilder.queryId(queryId).build(); + return queryDefBuilder.build(); } /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java index 96b4b77f159b..f02e505d0c5a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java @@ -78,7 +78,7 @@ public QueryDefinition makeQueryDefinition( { validateQuery(originalQuery); - final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder().queryId(queryId); + final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId); final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource( queryKit, queryId, @@ -240,7 +240,7 @@ public QueryDefinition makeQueryDefinition( } } - return queryDefBuilder.queryId(queryId).build(); + return queryDefBuilder.build(); } /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java index 8bc6f0bfa96d..2927264382a4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java @@ -92,7 +92,7 @@ public QueryDefinition makeQueryDefinition( final int minStageNumber ) { - final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder().queryId(queryId); + final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId); final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource( queryKit, queryId, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java new file mode 100644 index 000000000000..92042d59a8a8 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java @@ -0,0 +1,270 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.rpc; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.frame.channel.ReadableByteChunksFrameChannel; +import org.apache.druid.frame.file.FrameFileHttpResponseHandler; +import org.apache.druid.frame.file.FrameFilePartialFetch; +import org.apache.druid.frame.key.ClusterByPartitions; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.java.util.http.client.response.BytesFullResponseHandler; +import org.apache.druid.java.util.http.client.response.BytesFullResponseHolder; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.exec.WorkerClient; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; +import org.apache.druid.rpc.IgnoreHttpResponseHandler; +import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.rpc.ServiceClient; +import org.jboss.netty.handler.codec.http.HttpMethod; + +import javax.annotation.Nonnull; +import javax.ws.rs.core.HttpHeaders; +import java.io.IOException; + +/** + * Base worker client. Subclasses override {@link #getClient(String)} and {@link #close()} to build a complete client + * for talking to specific types of workers. 
+ */ +public abstract class BaseWorkerClientImpl implements WorkerClient +{ + private final ObjectMapper objectMapper; + private final String contentType; + + protected BaseWorkerClientImpl(final ObjectMapper objectMapper, final String contentType) + { + this.objectMapper = objectMapper; + this.contentType = contentType; + } + + @Nonnull + public static String getStagePartitionPath(StageId stageId, int partitionNumber) + { + return StringUtils.format( + "/channels/%s/%d/%d", + StringUtils.urlEncode(stageId.getQueryId()), + stageId.getStageNumber(), + partitionNumber + ); + } + + @Override + public ListenableFuture<Void> postWorkOrder(String workerId, WorkOrder workOrder) + { + return getClient(workerId).asyncRequest( + new RequestBuilder(HttpMethod.POST, "/workOrder") + .objectContent(objectMapper, contentType, workOrder), + IgnoreHttpResponseHandler.INSTANCE + ); + } + + @Override + public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot( + String workerId, + StageId stageId + ) + { + String path = StringUtils.format( + "/keyStatistics/%s/%d", + StringUtils.urlEncode(stageId.getQueryId()), + stageId.getStageNumber() + ); + + return FutureUtils.transform( + getClient(workerId).asyncRequest( + new RequestBuilder(HttpMethod.POST, path).header(HttpHeaders.ACCEPT, contentType), + new BytesFullResponseHandler() + ), + holder -> deserialize(holder, new TypeReference<ClusterByStatisticsSnapshot>() {}) + ); + } + + @Override + public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk( + String workerId, + StageId stageId, + long timeChunk + ) + { + String path = StringUtils.format( + "/keyStatisticsForTimeChunk/%s/%d/%d", + StringUtils.urlEncode(stageId.getQueryId()), + stageId.getStageNumber(), + timeChunk + ); + + return FutureUtils.transform( + getClient(workerId).asyncRequest( + new RequestBuilder(HttpMethod.POST, path).header(HttpHeaders.ACCEPT, contentType), + new BytesFullResponseHandler() + ), + holder -> deserialize(holder, new TypeReference<ClusterByStatisticsSnapshot>() {}) + ); + } + + @Override + public ListenableFuture<Void> postResultPartitionBoundaries( + String workerId, + StageId stageId, + ClusterByPartitions partitionBoundaries + ) + { + final String path = StringUtils.format( + "/resultPartitionBoundaries/%s/%d", + StringUtils.urlEncode(stageId.getQueryId()), + stageId.getStageNumber() + ); + + return getClient(workerId).asyncRequest( + new RequestBuilder(HttpMethod.POST, path) + .objectContent(objectMapper, contentType, partitionBoundaries), + IgnoreHttpResponseHandler.INSTANCE + ); + }
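[Editor's note, not part of the patch: each endpoint method in this class follows the same pattern — build a path, issue asyncRequest against getClient(workerId), optionally deserialize. A concrete subclass only has to supply getClient and close. A minimal sketch; the class name, map-based lookup, and JSON content type are assumptions for illustration:]

    import com.fasterxml.jackson.databind.ObjectMapper;
    import com.google.common.base.Preconditions;
    import org.apache.druid.msq.rpc.BaseWorkerClientImpl;
    import org.apache.druid.rpc.ServiceClient;
    import java.util.Map;

    // Hypothetical subclass: workers are resolved from a caller-supplied map.
    public class MapBackedWorkerClient extends BaseWorkerClientImpl
    {
      private final Map<String, ServiceClient> clients;

      public MapBackedWorkerClient(final ObjectMapper mapper, final Map<String, ServiceClient> clients)
      {
        super(mapper, "application/json"); // content type assumed; a production client may use Smile
        this.clients = clients;
      }

      @Override
      protected ServiceClient getClient(final String workerId)
      {
        // Fail fast if the controller asks for a worker we do not know about.
        return Preconditions.checkNotNull(clients.get(workerId), "worker[%s]", workerId);
      }

      @Override
      public void close()
      {
        // Nothing pooled in this sketch; real implementations release their HTTP clients here.
      }
    }
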
+ + /** + * Client-side method for {@link org.apache.druid.msq.indexing.client.WorkerChatHandler#httpPostCleanupStage}. + */ + @Override + public ListenableFuture<Void> postCleanupStage( + final String workerId, + final StageId stageId + ) + { + final String path = StringUtils.format( + "/cleanupStage/%s/%d", + StringUtils.urlEncode(stageId.getQueryId()), + stageId.getStageNumber() + ); + + return getClient(workerId).asyncRequest( + new RequestBuilder(HttpMethod.POST, path), + IgnoreHttpResponseHandler.INSTANCE + ); + } + + @Override + public ListenableFuture<Void> postFinish(String workerId) + { + return getClient(workerId).asyncRequest( + new RequestBuilder(HttpMethod.POST, "/finish"), + IgnoreHttpResponseHandler.INSTANCE + ); + } + + @Override + public ListenableFuture<CounterSnapshotsTree> getCounters(String workerId) + { + return FutureUtils.transform( + getClient(workerId).asyncRequest( + new RequestBuilder(HttpMethod.GET, "/counters").header(HttpHeaders.ACCEPT, contentType), + new BytesFullResponseHandler() + ), + holder -> deserialize(holder, new TypeReference<CounterSnapshotsTree>() {}) + ); + } + + private static final Logger log = new Logger(BaseWorkerClientImpl.class); + + @Override + public ListenableFuture<Boolean> fetchChannelData( + String workerId, + StageId stageId, + int partitionNumber, + long offset, + ReadableByteChunksFrameChannel channel + ) + { + final ServiceClient client = getClient(workerId); + final String path = getStagePartitionPath(stageId, partitionNumber); + + final SettableFuture<Boolean> retVal = SettableFuture.create(); + final ListenableFuture<FrameFilePartialFetch> clientFuture = + client.asyncRequest( + new RequestBuilder(HttpMethod.GET, StringUtils.format("%s?offset=%d", path, offset)) + .header(HttpHeaders.ACCEPT_ENCODING, "identity"), // Data is compressed at app level + new FrameFileHttpResponseHandler(channel) + ); + + Futures.addCallback( + clientFuture, + new FutureCallback<FrameFilePartialFetch>() + { + @Override + public void onSuccess(FrameFilePartialFetch partialFetch) + { + if (partialFetch.isExceptionCaught()) { + // Exception while reading channel. Recoverable. + log.noStackTrace().info( + partialFetch.getExceptionCaught(), + "Encountered exception while reading channel [%s]", + channel.getId() + ); + } + + // Empty fetch means this is the last fetch for the channel. + partialFetch.backpressureFuture().addListener( + () -> retVal.set(partialFetch.isLastFetch()), + Execs.directExecutor() + ); + } + + @Override + public void onFailure(Throwable t) + { + retVal.setException(t); + } + }, + Execs.directExecutor() + ); + + return retVal; + } + + /** + * Create a client to communicate with a given worker ID. + */ + protected abstract ServiceClient getClient(String workerId); + + /** + * Deserialize a {@link BytesFullResponseHolder} as JSON. + * + * It would be reasonable to move this to {@link BytesFullResponseHolder} itself, or some shared utility class. + */ + protected <T> T deserialize(final BytesFullResponseHolder bytesHolder, final TypeReference<T> typeReference) + { + try { + return objectMapper.readValue(bytesHolder.getContent(), typeReference); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java new file mode 100644 index 000000000000..d3e9eefa86d2 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.rpc; + +import org.apache.druid.indexer.report.TaskReport; +import org.apache.druid.msq.counters.CounterSnapshots; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.indexing.MSQTaskList; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; +import org.apache.druid.server.security.AuthorizerMapper; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import java.util.List; + +public class ControllerResource +{ + private final Controller controller; + private final ResourcePermissionMapper permissionMapper; + private final AuthorizerMapper authorizerMapper; + + public ControllerResource( + final Controller controller, + final ResourcePermissionMapper permissionMapper, + final AuthorizerMapper authorizerMapper + ) + { + this.controller = controller; + this.permissionMapper = permissionMapper; + this.authorizerMapper = authorizerMapper; + } + + /** + * Used by subtasks to post {@link PartialKeyStatisticsInformation} for shuffling stages. + * + * See {@link ControllerClient#postPartialKeyStatistics(StageId, int, PartialKeyStatisticsInformation)} + * for the client-side code that calls this API. + */ + @POST + @Path("/partialKeyStatisticsInformation/{queryId}/{stageNumber}/{workerNumber}") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public Response httpPostPartialKeyStatistics( + final Object partialKeyStatisticsObject, + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @PathParam("workerNumber") final int workerNumber, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + controller.updatePartialKeyStatisticsInformation(stageNumber, workerNumber, partialKeyStatisticsObject); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * Used by subtasks to post system errors. Note that the errors are organized by taskId, not by query/stage/worker, + * because system errors are associated with a task rather than a specific query/stage/worker execution context. + * + * See {@link ControllerClient#postWorkerError} for the client-side code that calls this API. 
+ */ + @POST + @Path("/workerError/{taskId}") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public Response httpPostWorkerError( + final MSQErrorReport errorReport, + @PathParam("taskId") final String taskId, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + controller.workerError(errorReport); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * Used by subtasks to post system warnings. + * + * See {@link ControllerClient#postWorkerWarning} for the client-side code that calls this API. + */ + @POST + @Path("/workerWarning") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public Response httpPostWorkerWarning( + final List errorReport, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + controller.workerWarning(errorReport); + return Response.status(Response.Status.ACCEPTED).build(); + } + + + /** + * Used by subtasks to post {@link CounterSnapshots} periodically. + * + * See {@link ControllerClient#postCounters} for the client-side code that calls this API. + */ + @POST + @Path("/counters/{taskId}") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public Response httpPostCounters( + @PathParam("taskId") final String taskId, + final CounterSnapshotsTree snapshotsTree, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + controller.updateCounters(taskId, snapshotsTree); + return Response.status(Response.Status.OK).build(); + } + + /** + * Used by subtasks to post notifications that their results are ready. + * + * See {@link ControllerClient#postResultsComplete} for the client-side code that calls this API. + */ + @POST + @Path("/resultsComplete/{queryId}/{stageNumber}/{workerNumber}") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public Response httpPostResultsComplete( + final Object resultObject, + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @PathParam("workerNumber") final int workerNumber, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + controller.resultsComplete(queryId, stageNumber, workerNumber, resultObject); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * See {@link ControllerClient#getTaskList()} for the client-side code that calls this API. + */ + @GET + @Path("/taskList") + @Produces(MediaType.APPLICATION_JSON) + public Response httpGetTaskList(@Context final HttpServletRequest req) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + return Response.ok(new MSQTaskList(controller.getTaskIds())).build(); + } + + /** + * See {@link org.apache.druid.indexing.overlord.RemoteTaskRunner#streamTaskReports} for the client-side code that + * calls this API. 
+ */ + @GET + @Path("/liveReports") + @Produces(MediaType.APPLICATION_JSON) + public Response httpGetLiveReports(@Context final HttpServletRequest req) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + final TaskReport.ReportMap reports = controller.liveReports(); + if (reports == null) { + return Response.status(Response.Status.NOT_FOUND).build(); + } + return Response.ok(reports).build(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java new file mode 100644 index 000000000000..30a8179fe0f0 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.rpc; + +import org.apache.druid.server.security.Access; +import org.apache.druid.server.security.AuthorizationUtils; +import org.apache.druid.server.security.AuthorizerMapper; +import org.apache.druid.server.security.ForbiddenException; +import org.apache.druid.server.security.ResourceAction; + +import javax.servlet.http.HttpServletRequest; +import java.util.List; + +/** + * Utility methods for MSQ resources such as {@link ControllerResource}. + */ +public class MSQResourceUtils +{ + public static void authorizeAdminRequest( + final ResourcePermissionMapper permissionMapper, + final AuthorizerMapper authorizerMapper, + final HttpServletRequest request + ) + { + final List resourceActions = permissionMapper.getAdminPermissions(); + + Access access = AuthorizationUtils.authorizeAllResourceActions(request, resourceActions, authorizerMapper); + + if (!access.isAllowed()) { + throw new ForbiddenException(access.toString()); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/guice/annotations/MSQ.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java similarity index 59% rename from extensions-core/multi-stage-query/src/main/java/org/apache/druid/guice/annotations/MSQ.java rename to extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java index c480168de258..8c79f4fa0e05 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/guice/annotations/MSQ.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java @@ -17,24 +17,17 @@ * under the License. 
*/ -package org.apache.druid.guice.annotations; +package org.apache.druid.msq.rpc; -import com.google.inject.BindingAnnotation; +import org.apache.druid.server.security.ResourceAction; -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; +import java.util.List; /** - * Binding annotation for implements of interfaces that are MSQ (MultiStageQuery) focused. This is generally - * contrasted with the NativeQ annotation. - * - * @see Parent + * Provides HTTP resources such as {@link ControllerResource} with information about which permissions are needed + * for requests. */ -@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.METHOD}) -@Retention(RetentionPolicy.RUNTIME) -@BindingAnnotation -public @interface MSQ +public interface ResourcePermissionMapper { + List<ResourceAction> getAdminPermissions(); }
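[Editor's note, not part of the patch: an implementation of this interface simply enumerates the ResourceActions that MSQResourceUtils.authorizeAdminRequest checks. A minimal sketch; the class name and the choice of the server STATE resource are assumptions for illustration only:]

    import com.google.common.collect.ImmutableList;
    import org.apache.druid.server.security.Action;
    import org.apache.druid.server.security.Resource;
    import org.apache.druid.server.security.ResourceAction;
    import java.util.List;

    // Hypothetical mapper: gate the admin APIs on read+write access to the server STATE resource.
    public class StatePermissionMapper implements ResourcePermissionMapper
    {
      @Override
      public List<ResourceAction> getAdminPermissions()
      {
        return ImmutableList.of(
            new ResourceAction(Resource.STATE_RESOURCE, Action.READ),
            new ResourceAction(Resource.STATE_RESOURCE, Action.WRITE)
        );
      }
    }
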
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/DurableStorageInputChannelFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/DurableStorageInputChannelFactory.java index c0e892b99bf0..f913dbb1858a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/DurableStorageInputChannelFactory.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/DurableStorageInputChannelFactory.java @@ -25,6 +25,7 @@ import org.apache.druid.frame.channel.ReadableInputStreamFrameChannel; import org.apache.druid.java.util.common.IOE; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; @@ -73,8 +74,11 @@ public static DurableStorageInputChannelFactory createStandardImplementation( final boolean isQueryResults ) { + final String threadNameFormat = + StringUtils.encodeForFormat(Preconditions.checkNotNull(controllerTaskId, "controllerTaskId")) + + "-remote-fetcher-%d"; final ExecutorService remoteInputStreamPool = - Executors.newCachedThreadPool(Execs.makeThreadFactory(controllerTaskId + "-remote-fetcher-%d")); + Executors.newCachedThreadPool(Execs.makeThreadFactory(threadNameFormat)); closer.register(remoteInputStreamPool::shutdownNow); if (isQueryResults) { return new DurableStorageQueryResultsInputChannelFactory( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java index cc360a48ede2..4beb2a869ef0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java @@ -37,7 +37,6 @@ import org.apache.druid.error.NotFound; import org.apache.druid.error.QueryExceptionCompat; import org.apache.druid.frame.channel.FrameChannelSequence; -import org.apache.druid.guice.annotations.MSQ; import org.apache.druid.indexer.TaskStatusPlus; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.RE; @@ -48,6 +47,7 @@ import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.exec.ResultsContext; import org.apache.druid.msq.guice.MultiStageQuery; import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.MSQSpec; @@ -131,7 +131,7 @@ public class SqlStatementResource @Inject public SqlStatementResource( - final @MSQ SqlStatementFactory msqSqlStatementFactory, + final @MultiStageQuery SqlStatementFactory msqSqlStatementFactory, final ObjectMapper jsonMapper, final OverlordClient overlordClient, final @MultiStageQuery StorageConnector storageConnector, @@ -540,27 +540,9 @@ private Optional<ResultSetInformation> getResultSetInformation( List<Object[]> results = null; if (isSelectQuery) { results = new ArrayList<>(); - Yielder<Object[]> yielder = null; if (msqTaskReportPayload.getResults() != null) { - yielder = msqTaskReportPayload.getResults().getResultYielder(); + results = msqTaskReportPayload.getResults().getResults(); } - try { - while (yielder != null && !yielder.isDone()) { - results.add(yielder.get()); - yielder = yielder.next(null); - } - } - finally { - if (yielder != null) { - try { - yielder.close(); - } - catch (IOException e) { - log.warn(e, StringUtils.format("Unable to close yielder for query[%s]", queryId)); - } - } - } - } return Optional.of( @@ -739,10 +721,10 @@ private Optional<Yielder<Object[]>> getResultYielder( contactOverlord(overlordClient.taskReportAsMap(queryId), queryId) ); - if (msqTaskReportPayload.getResults().getResultYielder() == null) { + if (msqTaskReportPayload.getResults().getResults() == null) { results = Optional.empty(); } else { - results = Optional.of(msqTaskReportPayload.getResults().getResultYielder()); + results = Optional.of(Yielders.each(Sequences.simple(msqTaskReportPayload.getResults().getResults()))); } } else if (msqControllerTask.getQuerySpec().getDestination() instanceof DurableStorageMSQDestination) { @@ -801,12 +783,17 @@ private Optional<Yielder<Object[]>> getResultYielder( } }) .collect(Collectors.toList())) - .flatMap(frame -> SqlStatementResourceHelper.getResultSequence( - msqControllerTask, - finalStage, - frame, - jsonMapper - ) + .flatMap(frame -> + SqlStatementResourceHelper.getResultSequence( + frame, + finalStage.getFrameReader(), + msqControllerTask.getQuerySpec().getColumnMappings(), + new ResultsContext( + msqControllerTask.getSqlTypeNames(), + msqControllerTask.getSqlResultsContext() + ), + jsonMapper + ) ) .withBaggage(closer))); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java index e9b4c61cef23..7a51bc8d26a4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java @@ -26,12 +26,12 @@ import org.apache.druid.error.DruidException; import org.apache.druid.error.ErrorResponse; import org.apache.druid.error.QueryExceptionCompat; -import org.apache.druid.guice.annotations.MSQ; import org.apache.druid.indexer.TaskState; import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.guice.MultiStageQuery; import org.apache.druid.msq.sql.MSQTaskSqlEngine; import org.apache.druid.msq.sql.SqlTaskStatus; import
org.apache.druid.query.QueryException; @@ -86,7 +86,7 @@ public class SqlTaskResource @Inject public SqlTaskResource( - final @MSQ SqlStatementFactory sqlStatementFactory, + final @MultiStageQuery SqlStatementFactory sqlStatementFactory, final ServerConfig serverConfig, final AuthorizerMapper authorizerMapper, final ObjectMapper jsonMapper diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/PartialKeyStatisticsInformation.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/PartialKeyStatisticsInformation.java index 535af8dafb0a..9a6e256d9add 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/PartialKeyStatisticsInformation.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/PartialKeyStatisticsInformation.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Objects; import java.util.Set; /** @@ -64,4 +65,35 @@ public double getBytesRetained() { return bytesRetained; } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PartialKeyStatisticsInformation that = (PartialKeyStatisticsInformation) o; + return multipleValues == that.multipleValues + && Double.compare(bytesRetained, that.bytesRetained) == 0 + && Objects.equals(timeSegments, that.timeSegments); + } + + @Override + public int hashCode() + { + return Objects.hash(timeSegments, multipleValues, bytesRetained); + } + + @Override + public String toString() + { + return "PartialKeyStatisticsInformation{" + + "timeSegments=" + timeSegments + + ", multipleValues=" + multipleValues + + ", bytesRetained=" + bytesRetained + + '}'; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java index 60734b5b1dad..4b599cd32d5b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java @@ -33,6 +33,7 @@ import org.apache.druid.msq.exec.Limits; import org.apache.druid.msq.exec.SegmentSource; import org.apache.druid.msq.indexing.destination.MSQSelectDestination; +import org.apache.druid.msq.indexing.error.MSQWarnings; import org.apache.druid.msq.kernel.WorkerAssignmentStrategy; import org.apache.druid.msq.sql.MSQMode; import org.apache.druid.query.QueryContext; @@ -112,6 +113,8 @@ public class MultiStageQueryContext public static final String CTX_INCLUDE_SEGMENT_SOURCE = "includeSegmentSource"; public static final SegmentSource DEFAULT_INCLUDE_SEGMENT_SOURCE = SegmentSource.NONE; + public static final String CTX_MAX_CONCURRENT_STAGES = "maxConcurrentStages"; + public static final int DEFAULT_MAX_CONCURRENT_STAGES = 1; public static final String CTX_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage"; private static final boolean DEFAULT_DURABLE_SHUFFLE_STORAGE = false; public static final String CTX_SELECT_DESTINATION = "selectDestination"; @@ -173,6 +176,14 @@ public static String getMSQMode(final QueryContext queryContext) ); } + public static int getMaxConcurrentStages(final QueryContext queryContext) + { + return queryContext.getInt( + CTX_MAX_CONCURRENT_STAGES, + 
DEFAULT_MAX_CONCURRENT_STAGES ); + } + public static boolean isDurableStorageEnabled(final QueryContext queryContext) { return queryContext.getBoolean( @@ -316,6 +327,14 @@ public static IndexSpec getIndexSpec(final QueryContext queryContext, final Obje return decodeIndexSpec(queryContext.get(CTX_INDEX_SPEC), objectMapper); } + public static long getMaxParseExceptions(final QueryContext queryContext) + { + return queryContext.getLong( + MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, + MSQWarnings.DEFAULT_MAX_PARSE_EXCEPTIONS_ALLOWED + ); + } + public static boolean useAutoColumnSchemas(final QueryContext queryContext) { return queryContext.getBoolean(CTX_USE_AUTO_SCHEMAS, DEFAULT_USE_AUTO_SCHEMAS);
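[Editor's note, not part of the patch: the new maxConcurrentStages parameter is an ordinary query-context entry. A hedged sketch of how it round-trips; the value 2 and the use of QueryContext.of are illustrative assumptions:]

    // Hypothetical: ask for two concurrently-running stages, enabling in-memory shuffles between them.
    final Map<String, Object> ctx = ImmutableMap.of(MultiStageQueryContext.CTX_MAX_CONCURRENT_STAGES, 2);
    final int maxConcurrentStages = MultiStageQueryContext.getMaxConcurrentStages(QueryContext.of(ctx));
    // maxConcurrentStages == 2 here; when the key is absent, DEFAULT_MAX_CONCURRENT_STAGES (1) applies.
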
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java index f90959a56667..4f07dcb2cc02 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java @@ -29,6 +29,7 @@ import org.apache.druid.error.NotFound; import org.apache.druid.frame.Frame; import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.read.FrameReader; import org.apache.druid.indexer.TaskLocation; import org.apache.druid.indexer.TaskState; import org.apache.druid.indexer.TaskStatusPlus; @@ -42,6 +43,7 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.counters.QueryCounterSnapshot; import org.apache.druid.msq.counters.SegmentGenerationProgressCounter; +import org.apache.druid.msq.exec.ResultsContext; import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination; import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; @@ -52,7 +54,6 @@ import org.apache.druid.msq.indexing.report.MSQStagesReport; import org.apache.druid.msq.indexing.report.MSQTaskReport; import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; -import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.sql.SqlStatementState; import org.apache.druid.msq.sql.entity.ColumnNameAndTypes; import org.apache.druid.msq.sql.entity.PageInformation; @@ -71,6 +72,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.stream.Collectors; @@ -296,25 +298,23 @@ protected DruidException makeException(DruidException.DruidExceptionBuilder bob) } public static Sequence<Object[]> getResultSequence( - MSQControllerTask msqControllerTask, - StageDefinition finalStage, - Frame frame, - ObjectMapper jsonMapper + final Frame resultsFrame, + final FrameReader resultFrameReader, + final ColumnMappings resultColumnMappings, + final ResultsContext resultsContext, + final ObjectMapper jsonMapper ) { - final Cursor cursor = FrameProcessors.makeCursor(frame, finalStage.getFrameReader()); - + final Cursor cursor = FrameProcessors.makeCursor(resultsFrame, resultFrameReader); final ColumnSelectorFactory columnSelectorFactory = cursor.getColumnSelectorFactory(); - final ColumnMappings columnMappings = msqControllerTask.getQuerySpec().getColumnMappings(); @SuppressWarnings("rawtypes") - final List<ColumnValueSelector> selectors = columnMappings.getMappings() - .stream() - .map(mapping -> columnSelectorFactory.makeColumnValueSelector( mapping.getQueryColumn())) .collect(Collectors.toList()); - - final List<SqlTypeName> sqlTypeNames = msqControllerTask.getSqlTypeNames(); - Iterable<Object[]> retVal = () -> new Iterator<Object[]>() + final List<ColumnValueSelector> selectors = + resultColumnMappings.getMappings() + .stream() + .map(mapping -> columnSelectorFactory.makeColumnValueSelector(mapping.getQueryColumn())) + .collect(Collectors.toList()); + + final Iterable<Object[]> retVal = () -> new Iterator<Object[]>() { @Override public boolean hasNext() @@ -325,19 +325,23 @@ public boolean hasNext() @Override public Object[] next() { - final Object[] row = new Object[columnMappings.size()]; + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final Object[] row = new Object[resultColumnMappings.size()]; for (int i = 0; i < row.length; i++) { final Object value = selectors.get(i).getObject(); - if (sqlTypeNames == null || msqControllerTask.getSqlResultsContext() == null) { + if (resultsContext.getSqlTypeNames() == null || resultsContext.getSqlResultsContext() == null) { // SQL type unknown, or no SQL results context: pass-through as is. row[i] = value; } else { row[i] = SqlResults.coerce( jsonMapper, - msqControllerTask.getSqlResultsContext(), + resultsContext.getSqlResultsContext(), value, - sqlTypeNames.get(i), - columnMappings.getOutputColumnName(i) + resultsContext.getSqlTypeNames().get(i), + resultColumnMappings.getOutputColumnName(i) ); } } diff --git a/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule b/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule index cabd131fb758..92be5604cb8a 100644 --- a/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule +++ b/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +org.apache.druid.msq.guice.IndexerMemoryManagementModule +org.apache.druid.msq.guice.MSQDurableStorageModule org.apache.druid.msq.guice.MSQExternalDataSourceModule org.apache.druid.msq.guice.MSQIndexingModule -org.apache.druid.msq.guice.MSQDurableStorageModule org.apache.druid.msq.guice.MSQSqlModule +org.apache.druid.msq.guice.PeonMemoryManagementModule org.apache.druid.msq.guice.SqlTaskModule diff --git a/extensions-core/multi-stage-query/src/main/resources/log4j2.xml b/extensions-core/multi-stage-query/src/main/resources/log4j2.xml index e99abd743366..d98bb05ef6cd 100644 --- a/extensions-core/multi-stage-query/src/main/resources/log4j2.xml +++ b/extensions-core/multi-stage-query/src/main/resources/log4j2.xml @@ -31,6 +31,9 @@ + + + diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerImplTest.java index 41c3cff66a50..db5ef1a089c2 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerImplTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerImplTest.java @@ -72,6 +72,7 @@ public void test_performSegmentPublish_ok() throws IOException // All OK.
ControllerImpl.performSegmentPublish(taskActionClient, action); + EasyMock.verify(taskActionClient); } @Test @@ -90,6 +91,7 @@ public void test_performSegmentPublish_publishFail() throws IOException ); Assert.assertEquals(InsertLockPreemptedFault.instance(), e.getFault()); + EasyMock.verify(taskActionClient); } @Test @@ -108,6 +110,7 @@ public void test_performSegmentPublish_publishException() throws IOException ); Assert.assertEquals("oops", e.getMessage()); + EasyMock.verify(taskActionClient); } @Test @@ -126,6 +129,7 @@ public void test_performSegmentPublish_publishLockPreemptedException() throws IO ); Assert.assertEquals(InsertLockPreemptedFault.instance(), e.getFault()); + EasyMock.verify(taskActionClient); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerMemoryParametersTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerMemoryParametersTest.java new file mode 100644 index 000000000000..9d27dcca666b --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ControllerMemoryParametersTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.exec; + +import com.google.common.collect.ImmutableMap; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.NotEnoughMemoryFault; +import org.apache.druid.sql.calcite.util.TestLookupProvider; +import org.junit.Assert; +import org.junit.Test; + +public class ControllerMemoryParametersTest +{ + private static final double USABLE_MEMORY_FRACTION = 0.8; + private static final int NUM_PROCESSORS_IN_JVM = 2; + + @Test + public void test_oneQueryInJvm() + { + final ControllerMemoryParameters memoryParameters = ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(128_000_000, 1), + 1 + ); + + Assert.assertEquals(100_400_000, memoryParameters.getPartitionStatisticsMaxRetainedBytes()); + } + + @Test + public void test_oneQueryInJvm_oneHundredWorkers() + { + final ControllerMemoryParameters memoryParameters = ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(256_000_000, 1), + 100 + ); + + Assert.assertEquals(103_800_000, memoryParameters.getPartitionStatisticsMaxRetainedBytes()); + } + + @Test + public void test_twoQueriesInJvm() + { + final ControllerMemoryParameters memoryParameters = ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(128_000_000, 2), + 1 + ); + + Assert.assertEquals(49_200_000, memoryParameters.getPartitionStatisticsMaxRetainedBytes()); + } + + @Test + public void test_maxSized() + { + final ControllerMemoryParameters memoryParameters = ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(1_000_000_000, 1), + 1 + ); + + Assert.assertEquals(300_000_000, memoryParameters.getPartitionStatisticsMaxRetainedBytes()); + } + + @Test + public void test_notEnoughMemory() + { + final MSQException e = Assert.assertThrows( + MSQException.class, + () -> ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(30_000_000, 1), + 1 + ) + ); + + final NotEnoughMemoryFault fault = (NotEnoughMemoryFault) e.getFault(); + Assert.assertEquals(30_000_000, fault.getServerMemory()); + Assert.assertEquals(1, fault.getServerWorkers()); + Assert.assertEquals(NUM_PROCESSORS_IN_JVM, fault.getServerThreads()); + Assert.assertEquals(24_000_000, fault.getUsableMemory()); + Assert.assertEquals(33_750_000, fault.getSuggestedServerMemory()); + } + + @Test + public void test_minimalMemory() + { + final ControllerMemoryParameters memoryParameters = ControllerMemoryParameters.createProductionInstance( + makeMemoryIntrospector(33_750_000, 1), + 1 + ); + + Assert.assertEquals(25_000_000, memoryParameters.getPartitionStatisticsMaxRetainedBytes()); + } + + private MemoryIntrospector makeMemoryIntrospector( + final long totalMemoryInJvm, + final int numQueriesInJvm + ) + { + return new MemoryIntrospectorImpl( + new TestLookupProvider(ImmutableMap.of()), + totalMemoryInJvm, + USABLE_MEMORY_FRACTION, + numQueriesInJvm, + NUM_PROCESSORS_IN_JVM + ); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java index b3b1442074b7..425609628b3a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java @@ -27,6 +27,7 @@ import org.apache.druid.indexing.common.TaskLockType; import 
org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; import org.apache.druid.indexing.common.actions.SegmentAllocateAction; +import org.apache.druid.indexing.common.actions.TaskAction; import org.apache.druid.indexing.common.task.Tasks; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; @@ -51,6 +52,7 @@ import org.hamcrest.CoreMatchers; import org.junit.internal.matchers.ThrowableMessageMatcher; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatcher; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; @@ -532,7 +534,10 @@ public void testReplaceTombstonesWithTooManyBucketsThrowsFault() Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher<TaskAction<?>>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()))); String expectedError = new TooManyBucketsFault(Limits.MAX_PARTITION_BUCKETS).getErrorMessage(); @@ -578,7 +583,10 @@ public void testReplaceTombstonesWithTooManyBucketsThrowsFault2() Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher<TaskAction<?>>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()))); String expectedError = new TooManyBucketsFault(Limits.MAX_PARTITION_BUCKETS).getErrorMessage(); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java index f05e35c304c0..ecdc30294dbd 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java @@ -1343,7 +1343,7 @@ public void testInsertWithTooLargeRowShouldThrowException(String contextName, Ma final File toRead = getResourceAsTemporaryFile("/wikipedia-sampled.json"); final String toReadFileNameAsJson = queryFramework().queryJsonMapper().writeValueAsString(toRead.getAbsolutePath()); - Mockito.doReturn(500).when(workerMemoryParameters).getLargeFrameSize(); + Mockito.doReturn(500).when(workerMemoryParameters).getStandardFrameSize(); testIngestQuery().setSql(" insert into foo1 SELECT\n" + " floor(TIME_PARSE(\"timestamp\") to day) AS __time,\n" diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java index 9a4fb98666b3..227a9656a142 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java @@ -34,6 +34,7 @@ import org.apache.druid.indexer.partitions.PartitionsSpec; import org.apache.druid.indexing.common.TaskLockType; import org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; +import org.apache.druid.indexing.common.actions.TaskAction; import org.apache.druid.indexing.common.task.Tasks; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Intervals; @@ -58,6 +59,7 @@ import
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java index 9a4fb98666b3..227a9656a142 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java @@ -34,6 +34,7 @@ import org.apache.druid.indexer.partitions.PartitionsSpec; import org.apache.druid.indexing.common.TaskLockType; import org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; +import org.apache.druid.indexing.common.actions.TaskAction; import org.apache.druid.indexing.common.task.Tasks; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Intervals; @@ -58,6 +59,7 @@ import org.joda.time.Interval; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentMatcher; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; @@ -1650,7 +1652,12 @@ public void testEmptyReplaceAllOverEternitySegment(String contextName, Map< Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher<TaskAction<?>>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()) + )); // Insert with a condition which results in 0 rows being inserted -- do nothing. testIngestQuery().setSql( @@ -1683,7 +1690,10 @@ public void testEmptyReplaceAllWithAllGrainOverFiniteIntervalSegment(String cont .build(); Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher<TaskAction<?>>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()))); // Insert with a condition which results in 0 rows being inserted -- do nothing. testIngestQuery().setSql( @@ -1716,7 +1726,10 @@ public void testEmptyReplaceAllWithAllGrainOverEternitySegment(String contextNam Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher<TaskAction<?>>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()))); // Insert with a condition which results in 0 rows being inserted -- do nothing. testIngestQuery().setSql( @@ -1800,7 +1813,10 @@ public void testEmptyReplaceIntervalOverEternitySegment(String contextName, Map< Mockito.doReturn(ImmutableSet.of(existingDataSegment)) .when(testTaskActionClient) - .submit(ArgumentMatchers.isA(RetrieveUsedSegmentsAction.class)); + .submit(ArgumentMatchers.argThat( + (ArgumentMatcher<TaskAction<?>>) argument -> + argument instanceof RetrieveUsedSegmentsAction + && "foo1".equals(((RetrieveUsedSegmentsAction) argument).getDataSource()))); // Insert with a condition which results in 0 rows being inserted -- do nothing! testIngestQuery().setSql( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index 7c4af7389f6c..56f1ce986965 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -2220,10 +2220,7 @@ public void testSelectRowsGetUntruncatedByDefault(String contextName, Map<String, Object> context) { - - - - // This test asserts that the join algorithnm used is a different one from that supplied. In sqlCompatible() mode + // This test asserts that the join algorithm used is a different one from that supplied. In sqlCompatible() mode // the query gets planned differently, therefore we do use the sortMerge processor.
Instead of having separate // handling, a similar test has been described in CalciteJoinQueryMSQTest, therefore we don't want to repeat that // here, hence ignoring in sqlCompatible() mode diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java index 73a443db8a25..904510408ff0 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java @@ -20,10 +20,19 @@ package org.apache.druid.msq.exec; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.druid.client.indexing.NoopOverlordClient; +import org.apache.druid.client.indexing.TaskStatusResponse; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.indexer.RunnerTaskState; import org.apache.druid.indexer.TaskLocation; import org.apache.druid.indexer.TaskState; import org.apache.druid.indexer.TaskStatus; +import org.apache.druid.indexer.TaskStatusPlus; +import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.ISE; import org.apache.druid.msq.indexing.MSQWorkerTask; import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; @@ -37,6 +46,7 @@ import org.apache.druid.msq.indexing.error.TooManyWorkersFault; import org.apache.druid.msq.indexing.error.UnknownFault; import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault; +import org.apache.druid.utils.CollectionUtils; import org.junit.Assert; import org.junit.Test; @@ -47,8 +57,6 @@ import java.util.concurrent.TimeUnit; import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class MSQTasksTest { @@ -214,12 +222,10 @@ public void test_queryWithoutEnoughSlots_shouldThrowException() final int numSlots = 5; final int numTasks = 10; - ControllerContext controllerContext = mock(ControllerContext.class); - when(controllerContext.workerManager()).thenReturn(new TasksTestWorkerManagerClient(numSlots)); MSQWorkerTaskLauncher msqWorkerTaskLauncher = new MSQWorkerTaskLauncher( CONTROLLER_ID, "foo", - controllerContext, + new TasksTestOverlordClient(numSlots), (task, fault) -> {}, ImmutableMap.of(), TimeUnit.SECONDS.toMillis(5) @@ -227,7 +233,7 @@ public void test_queryWithoutEnoughSlots_shouldThrowException() try { msqWorkerTaskLauncher.start(); - msqWorkerTaskLauncher.launchTasksIfNeeded(numTasks); + msqWorkerTaskLauncher.launchWorkersIfNeeded(numTasks); fail(); } catch (Exception e) { @@ -238,7 +244,7 @@ } } - static class TasksTestWorkerManagerClient implements WorkerManagerClient + static class TasksTestOverlordClient extends NoopOverlordClient { // Num of slots available for tasks final int numSlots; @@ -252,13 +258,13 @@ static class TasksTestWorkerManagerClient implements WorkerManagerClient @GuardedBy("this") final Set<String> canceledTasks = new HashSet<>(); - public TasksTestWorkerManagerClient(final int numSlots) + public TasksTestOverlordClient(final int numSlots) { this.numSlots = numSlots; } @Override - public synchronized Map<String, TaskStatus> statuses(final Set<String> taskIds) + public synchronized ListenableFuture<Map<String, TaskStatus>>
taskStatuses(final Set<String> taskIds) { final Map<String, TaskStatus> retVal = new HashMap<>(); @@ -277,42 +283,66 @@ public synchronized Map<String, TaskStatus> statuses(final Set<String> taskIds) } } - return retVal; + return Futures.immediateFuture(retVal); } @Override - public synchronized TaskLocation location(String workerId) + public synchronized ListenableFuture<TaskStatusResponse> taskStatus(String workerId) { + final TaskStatus status = CollectionUtils.getOnlyElement( + FutureUtils.getUnchecked(taskStatuses(ImmutableSet.of(workerId)), true).values(), + xs -> new ISE("Expected one worker with id[%s] but saw[%s]", workerId, xs) + ); + + final TaskLocation location; + if (runningTasks.contains(workerId)) { - return TaskLocation.create("host-" + workerId, 1, -1); + location = TaskLocation.create("host-" + workerId, 1, -1); } else { - return TaskLocation.unknown(); + location = TaskLocation.unknown(); } + + return Futures.immediateFuture( + new TaskStatusResponse( + status.getId(), + new TaskStatusPlus( + status.getId(), + null, + null, + DateTimes.utc(0), + DateTimes.utc(0), + status.getStatusCode(), + status.getStatusCode(), + RunnerTaskState.NONE, + status.getDuration(), + location, + null, + status.getErrorMsg() + ) + ) + ); } @Override - public synchronized String run(String taskId, MSQWorkerTask task) + public synchronized ListenableFuture<Void> runTask(String taskId, Object taskObject) { + final MSQWorkerTask task = (MSQWorkerTask) taskObject; + allTasks.add(task.getId()); if (runningTasks.size() < numSlots) { runningTasks.add(task.getId()); } - return task.getId(); + return Futures.immediateFuture(null); } @Override - public synchronized void cancel(String workerId) + public synchronized ListenableFuture<Void> cancelTask(String workerId) { runningTasks.remove(workerId); canceledTasks.add(workerId); - } - - @Override - public void close() - { - // do nothing + return Futures.immediateFuture(null); } } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/QueryValidatorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/QueryValidatorTest.java index d9cbb48d986c..d7364124483a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/QueryValidatorTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/QueryValidatorTest.java @@ -106,6 +106,8 @@ public void testMoreInputFiles() 0, 0, Collections.singletonList(() -> inputFiles), // Slice with a large number of inputFiles + null, + null, null ); @@ -125,8 +127,7 @@ public void testMoreInputFiles() private static QueryDefinition createQueryDefinition(int numColumns, int numWorkers) { - QueryDefinitionBuilder builder = QueryDefinition.builder(); - builder.queryId(UUID.randomUUID().toString()); + QueryDefinitionBuilder builder = QueryDefinition.builder(UUID.randomUUID().toString()); StageDefinitionBuilder stageBuilder = StageDefinition.builder(0); builder.add(stageBuilder); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java index 592fd089ef4e..cba8ede156ce 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java @@ -24,7 +24,6 @@ import com.google.common.collect.ImmutableSortedMap; import com.google.common.util.concurrent.Futures; import
org.apache.druid.java.util.common.ISE; -import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.controller.ControllerQueryKernel; @@ -44,7 +43,6 @@ import static org.easymock.EasyMock.mock; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; @@ -56,7 +54,7 @@ public class WorkerSketchFetcherTest private CompleteKeyStatisticsInformation completeKeyStatisticsInformation; @Mock - private MSQWorkerTaskLauncher workerTaskLauncher; + private WorkerManager workerManager; @Mock private ControllerQueryKernel kernel; @@ -82,7 +80,10 @@ public void setUp() doReturn(ImmutableSortedMap.of(123L, ImmutableSet.of(1, 2))).when(completeKeyStatisticsInformation) .getTimeSegmentVsWorkerMap(); - doReturn(true).when(workerTaskLauncher).isTaskLatest(any()); + doReturn(0).when(workerManager).getWorkerNumber(TASK_0); + doReturn(1).when(workerManager).getWorkerNumber(TASK_1); + doReturn(2).when(workerManager).getWorkerNumber(TASK_2); + doReturn(true).when(workerManager).isWorkerActive(any()); } @After @@ -100,13 +101,13 @@ public void test_submitFetcherTask_parallelFetch() throws InterruptedException final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); // When fetching snapshots, return a mock and add it to queue doAnswer(invocation -> { ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class); return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt()); + }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any()); target.inMemoryFullSketchMerging((kernelConsumer) -> { kernelConsumer.accept(kernel); @@ -123,13 +124,13 @@ public void test_submitFetcherTask_sequentialFetch() throws InterruptedException doReturn(true).when(completeKeyStatisticsInformation).isComplete(); final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); // When fetching snapshots, return a mock and add it to queue doAnswer(invocation -> { ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class); return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyInt(), anyLong()); + }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong()); target.sequentialTimeChunkMerging( (kernelConsumer) -> { @@ -151,7 +152,7 @@ public void test_sequentialMerge_nonCompleteInformation() { doReturn(false).when(completeKeyStatisticsInformation).isComplete(); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); Assert.assertThrows(ISE.class, () -> target.sequentialTimeChunkMerging( (ignore) -> {}, completeKeyStatisticsInformation, @@ -166,7 +167,7 @@ public void test_inMemoryRetryEnabled_retryInvoked() throws InterruptedException { final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = 
spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); workersWithFailedFetchParallel(ImmutableSet.of(TASK_1)); @@ -185,8 +186,8 @@ public void test_inMemoryRetryEnabled_retryInvoked() throws InterruptedException }) ); - Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); - Assert.assertTrue(retryLatch.await(5, TimeUnit.SECONDS)); + Assert.assertTrue(latch.await(500, TimeUnit.SECONDS)); + Assert.assertTrue(retryLatch.await(500, TimeUnit.SECONDS)); } @Test @@ -195,7 +196,7 @@ public void test_SequentialRetryEnabled_retryInvoked() throws InterruptedExcepti doReturn(true).when(completeKeyStatisticsInformation).isComplete(); final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); workersWithFailedFetchSequential(ImmutableSet.of(TASK_1)); CountDownLatch retryLatch = new CountDownLatch(1); @@ -222,7 +223,7 @@ public void test_SequentialRetryEnabled_retryInvoked() throws InterruptedExcepti public void test_InMemoryRetryDisabled_multipleFailures() throws InterruptedException { - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); workersWithFailedFetchParallel(ImmutableSet.of(TASK_1, TASK_0)); @@ -251,7 +252,7 @@ public void test_InMemoryRetryDisabled_multipleFailures() throws InterruptedExce public void test_InMemoryRetryDisabled_singleFailure() throws InterruptedException { - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); workersWithFailedFetchParallel(ImmutableSet.of(TASK_1)); @@ -282,7 +283,7 @@ public void test_SequentialRetryDisabled_multipleFailures() throws InterruptedEx { doReturn(true).when(completeKeyStatisticsInformation).isComplete(); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); workersWithFailedFetchSequential(ImmutableSet.of(TASK_1, TASK_0)); @@ -314,7 +315,7 @@ public void test_SequentialRetryDisabled_multipleFailures() throws InterruptedEx public void test_SequentialRetryDisabled_singleFailure() throws InterruptedException { doReturn(true).when(completeKeyStatisticsInformation).isComplete(); - target = spy(new WorkerSketchFetcher(workerClient, workerTaskLauncher, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); workersWithFailedFetchSequential(ImmutableSet.of(TASK_1)); @@ -351,7 +352,7 @@ private void workersWithFailedFetchSequential(Set<String> failedTasks) return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0))); } return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyInt(), anyLong()); + }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong()); } private void workersWithFailedFetchParallel(Set<String> failedTasks) @@ -362,7 +363,7 @@ private void workersWithFailedFetchParallel(Set<String> failedTasks) return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0))); } return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt()); +
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any()); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java index bc3f24065aea..25ab33f76f94 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java @@ -21,7 +21,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.rpc.indexing.OverlordClient; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -40,7 +40,7 @@ public void setUp() target = new MSQWorkerTaskLauncher( "controller-id", "foo", - Mockito.mock(ControllerContext.class), + Mockito.mock(OverlordClient.class), (task, fault) -> {}, ImmutableMap.of(), TimeUnit.SECONDS.toMillis(5) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/TaskReportQueryListenerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/TaskReportQueryListenerTest.java new file mode 100644 index 000000000000..11c33d215170 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/TaskReportQueryListenerTest.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.indexing; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.indexer.TaskState; +import org.apache.druid.indexer.report.TaskContextReport; +import org.apache.druid.indexer.report.TaskReport; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.guice.MSQIndexingModule; +import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; +import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.indexing.report.MSQStagesReport; +import org.apache.druid.msq.indexing.report.MSQStatusReport; +import org.apache.druid.msq.indexing.report.MSQTaskReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; +import org.apache.druid.msq.indexing.report.MSQTaskReportTest; +import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ColumnType; +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class TaskReportQueryListenerTest +{ + private static final String TASK_ID = "mytask"; + private static final Map<String, Object> TASK_CONTEXT = ImmutableMap.of("foo", "bar"); + private static final List<MSQResultsReport.ColumnAndType> SIGNATURE = ImmutableList.of( + new MSQResultsReport.ColumnAndType("x", ColumnType.STRING) + ); + private static final List<SqlTypeName> SQL_TYPE_NAMES = ImmutableList.of(SqlTypeName.VARCHAR); + private static final ObjectMapper JSON_MAPPER = + TestHelper.makeJsonMapper().registerModules(new MSQIndexingModule().getJacksonModules()); + + private final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + @Test + public void test_taskReportDestination() throws IOException + { + final TaskReportQueryListener listener = new TaskReportQueryListener( + TaskReportMSQDestination.instance(), + Suppliers.ofInstance(baos)::get, + JSON_MAPPER, + TASK_ID, + TASK_CONTEXT + ); + + Assert.assertTrue(listener.readResults()); + listener.onResultsStart(SIGNATURE, SQL_TYPE_NAMES); + Assert.assertTrue(listener.onResultRow(new Object[]{"foo"})); + Assert.assertTrue(listener.onResultRow(new Object[]{"bar"})); + listener.onResultsComplete(); + listener.onQueryComplete( + new MSQTaskReportPayload( + new MSQStatusReport( + TaskState.SUCCESS, + null, + Collections.emptyList(), + null, + 0, + new HashMap<>(), + 1, + 2, + null, + null + ), + MSQStagesReport.create( + MSQTaskReportTest.QUERY_DEFINITION, + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of() + ), + new CounterSnapshotsTree(), + null + ) + ); + + final TaskReport.ReportMap reportMap = + JSON_MAPPER.readValue( + baos.toByteArray(), + new TypeReference<TaskReport.ReportMap>() {} + ); + + Assert.assertEquals(ImmutableSet.of("multiStageQuery", TaskContextReport.REPORT_KEY), reportMap.keySet()); + Assert.assertEquals(TASK_CONTEXT, ((TaskContextReport) reportMap.get(TaskContextReport.REPORT_KEY)).getPayload()); + 
final MSQTaskReport report = (MSQTaskReport) reportMap.get("multiStageQuery"); + final List<List<Object>> results = + report.getPayload().getResults().getResults().stream().map(Arrays::asList).collect(Collectors.toList()); + + Assert.assertEquals( + ImmutableList.of( + ImmutableList.of("foo"), + ImmutableList.of("bar") + ), + results + ); + + Assert.assertFalse(report.getPayload().getResults().isResultsTruncated()); + Assert.assertEquals(TaskState.SUCCESS, report.getPayload().getStatus().getStatus()); + } + + @Test + public void test_durableDestination() throws IOException + { + final TaskReportQueryListener listener = new TaskReportQueryListener( + DurableStorageMSQDestination.instance(), + Suppliers.ofInstance(baos)::get, + JSON_MAPPER, + TASK_ID, + TASK_CONTEXT + ); + + Assert.assertTrue(listener.readResults()); + listener.onResultsStart(SIGNATURE, SQL_TYPE_NAMES); + for (int i = 0; i < Limits.MAX_SELECT_RESULT_ROWS - 1; i++) { + Assert.assertTrue("row #" + i, listener.onResultRow(new Object[]{"foo"})); + } + Assert.assertFalse(listener.onResultRow(new Object[]{"foo"})); + listener.onQueryComplete( + new MSQTaskReportPayload( + new MSQStatusReport( + TaskState.SUCCESS, + null, + Collections.emptyList(), + null, + 0, + new HashMap<>(), + 1, + 2, + null, + null + ), + MSQStagesReport.create( + MSQTaskReportTest.QUERY_DEFINITION, + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of() + ), + new CounterSnapshotsTree(), + null + ) + ); + + final TaskReport.ReportMap reportMap = + JSON_MAPPER.readValue( + baos.toByteArray(), + new TypeReference<TaskReport.ReportMap>() {} + ); + + Assert.assertEquals(ImmutableSet.of("multiStageQuery", TaskContextReport.REPORT_KEY), reportMap.keySet()); + Assert.assertEquals(TASK_CONTEXT, ((TaskContextReport) reportMap.get(TaskContextReport.REPORT_KEY)).getPayload()); + + final MSQTaskReport report = (MSQTaskReport) reportMap.get("multiStageQuery"); + final List<List<Object>> results = + report.getPayload().getResults().getResults().stream().map(Arrays::asList).collect(Collectors.toList()); + + Assert.assertEquals( + IntStream.range(0, (int) Limits.MAX_SELECT_RESULT_ROWS) + .mapToObj(i -> ImmutableList.of("foo")) + .collect(Collectors.toList()), + results + ); + + Assert.assertTrue(report.getPayload().getResults().isResultsTruncated()); + Assert.assertEquals(TaskState.SUCCESS, report.getPayload().getStatus().getStatus()); + } +}
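These two tests pin down the listener's streaming contract: results are pushed row by row while the query runs, onResultRow's boolean return acts as a continue/stop signal (false once Limits.MAX_SELECT_RESULT_ROWS is reached for a destination that truncates), and the final report is written when onQueryComplete arrives. A minimal sketch of a driver loop against that contract, inferred from the assertions above; the helper name feedRows and the payload argument are illustrative, not part of the patch:

    // Push rows until the listener declines more; then finish the query.
    static void feedRows(TaskReportQueryListener listener, Iterable<Object[]> rows, MSQTaskReportPayload payload)
    {
      if (listener.readResults()) {
        listener.onResultsStart(SIGNATURE, SQL_TYPE_NAMES);
        boolean more = true;
        for (Object[] row : rows) {
          if (!(more = listener.onResultRow(row))) {
            break; // listener has hit its retained-results limit
          }
        }
        if (more) {
          listener.onResultsComplete(); // only reached if every row was accepted
        }
      }
      listener.onQueryComplete(payload);
    }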
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java index 9fe32cc8c8c1..3fd346f4db42 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java @@ -22,9 +22,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.indexer.TaskStatus; -import org.apache.druid.indexer.report.TaskReport; -import org.apache.druid.indexer.report.TaskReportFileWriter; import org.apache.druid.indexing.common.TaskToolbox; +import org.apache.druid.indexing.common.task.NoopTestTaskReportFileWriter; import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.java.util.common.ISE; import org.apache.druid.msq.counters.CounterSnapshotsTree; @@ -83,22 +82,7 @@ public void setUp() toolbox = builder.authorizerMapper(CalciteTests.TEST_AUTHORIZER_MAPPER) .indexIO(indexIO) .indexMergerV9(indexMerger) - .taskReportFileWriter( - new TaskReportFileWriter() - { - @Override - public void write(String taskId, TaskReport.ReportMap reports) - { - - } - - @Override - public void setObjectMapper(ObjectMapper objectMapper) - { - - } - } - ) + .taskReportFileWriter(new NoopTestTaskReportFileWriter()) .build(); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/ControllerChatHandlerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/ControllerChatHandlerTest.java index 10a724f4b7ed..93436c84eadd 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/ControllerChatHandlerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/ControllerChatHandlerTest.java @@ -21,9 +21,7 @@ import org.apache.druid.indexer.report.KillTaskReport; import org.apache.druid.indexer.report.TaskReport; -import org.apache.druid.indexing.common.TaskToolbox; import org.apache.druid.msq.exec.Controller; -import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.server.security.AuthorizerMapper; import org.junit.Assert; import org.junit.Test; @@ -35,6 +33,8 @@ public class ControllerChatHandlerTest { + private static final String DATASOURCE = "wiki"; + @Test public void testHttpGetLiveReports() { @@ -46,17 +46,8 @@ public void testHttpGetLiveReports() Mockito.when(controller.liveReports()) .thenReturn(reportMap); - MSQControllerTask task = Mockito.mock(MSQControllerTask.class); - Mockito.when(task.getDataSource()) - .thenReturn("wiki"); - Mockito.when(controller.task()) - .thenReturn(task); - - TaskToolbox toolbox = Mockito.mock(TaskToolbox.class); - Mockito.when(toolbox.getAuthorizerMapper()) - .thenReturn(new AuthorizerMapper(null)); - - ControllerChatHandler chatHandler = new ControllerChatHandler(toolbox, controller); + final AuthorizerMapper authorizerMapper = new AuthorizerMapper(null); + ControllerChatHandler chatHandler = new ControllerChatHandler(controller, DATASOURCE, authorizerMapper); HttpServletRequest httpRequest = Mockito.mock(HttpServletRequest.class); Mockito.when(httpRequest.getAttribute(ArgumentMatchers.anyString())) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java index 158f65a05940..4ab992aec096 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java @@ -30,9 +30,6 @@ import org.apache.druid.indexer.report.SingleFileTaskReportFileWriter; import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.java.util.common.DateTimes; -import org.apache.druid.java.util.common.guava.Sequences; -import org.apache.druid.java.util.common.guava.Yielder; -import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.exec.SegmentLoadStatusFetcher; import org.apache.druid.msq.guice.MSQIndexingModule; @@ -52,10 +49,10 @@ import java.io.File; import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import
java.util.List; +import java.util.UUID; public class MSQTaskReportTest { @@ -63,7 +60,7 @@ public class MSQTaskReportTest private static final String HOST = "example.com:1234"; public static final QueryDefinition QUERY_DEFINITION = QueryDefinition - .builder() + .builder(UUID.randomUUID().toString()) .add( StageDefinition .builder(0) @@ -112,13 +109,14 @@ public void testSerdeResultsReport() throws Exception ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), + ImmutableMap.of(), ImmutableMap.of() ), new CounterSnapshotsTree(), new MSQResultsReport( Collections.singletonList(new MSQResultsReport.ColumnAndType("s", ColumnType.STRING)), ImmutableList.of(SqlTypeName.VARCHAR), - Yielders.each(Sequences.simple(results)), + results, null ) ) @@ -139,13 +137,7 @@ public void testSerdeResultsReport() throws Exception Assert.assertEquals(report.getPayload().getStatus().getPendingTasks(), report2.getPayload().getStatus().getPendingTasks()); Assert.assertEquals(report.getPayload().getStages(), report2.getPayload().getStages()); - Yielder<Object[]> yielder = report2.getPayload().getResults().getResultYielder(); - final List<Object[]> results2 = new ArrayList<>(); - - while (!yielder.isDone()) { - results2.add(yielder.get()); - yielder = yielder.next(null); - } + final List<Object[]> results2 = report2.getPayload().getResults().getResults(); Assert.assertEquals(results.size(), results2.size()); for (int i = 0; i < results.size(); i++) { Assert.assertArrayEquals(results.get(i), results2.get(i)); @@ -177,6 +169,7 @@ public void testSerdeErrorReport() throws Exception ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), + ImmutableMap.of(), ImmutableMap.of() ), new CounterSnapshotsTree(), @@ -225,6 +218,7 @@ public void testWriteTaskReport() throws Exception ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), + ImmutableMap.of(), ImmutableMap.of() ), new CounterSnapshotsTree(), diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSliceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSliceTest.java index 3b2705c8ba6c..d550fef84c77 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSliceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSliceTest.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.segment.TestHelper; @@ -37,7 +38,8 @@ public void testSerde() throws Exception final StageInputSlice slice = new StageInputSlice( 2, - ReadablePartitions.striped(2, 3, 4) + ReadablePartitions.striped(2, 3, 4), + OutputChannelMode.MEMORY ); Assert.assertEquals( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSpecSlicerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSpecSlicerTest.java index 43d89e7fc690..024ad956cf29 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSpecSlicerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StageInputSpecSlicerTest.java @@ -24,6 +24,7 @@ import
it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; +import org.apache.druid.msq.exec.OutputChannelMode; import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import org.junit.Assert; @@ -43,12 +44,21 @@ public class StageInputSpecSlicerTest .build() ); + private static final Int2ObjectMap<OutputChannelMode> STAGE_OUTPUT_MODE_MAP = + new Int2ObjectOpenHashMap<>( + ImmutableMap.<Integer, OutputChannelMode>builder() + .put(0, OutputChannelMode.LOCAL_STORAGE) + .put(1, OutputChannelMode.LOCAL_STORAGE) + .put(2, OutputChannelMode.LOCAL_STORAGE) + .build() + ); + private StageInputSpecSlicer slicer; @Before public void setUp() { - slicer = new StageInputSpecSlicer(STAGE_PARTITIONS_MAP); + slicer = new StageInputSpecSlicer(STAGE_PARTITIONS_MAP, STAGE_OUTPUT_MODE_MAP); } @Test @@ -64,7 +74,8 @@ public void test_sliceStatic_stageZeroOneSlice() Collections.singletonList( new StageInputSlice( 0, - ReadablePartitions.striped(0, 2, 2) + ReadablePartitions.striped(0, 2, 2), + OutputChannelMode.LOCAL_STORAGE ) ), slicer.sliceStatic(new StageInputSpec(0), 1) @@ -78,11 +89,13 @@ public void test_sliceStatic_stageZeroTwoSlices() ImmutableList.of( new StageInputSlice( 0, - new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0})) + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0})), + OutputChannelMode.LOCAL_STORAGE ), new StageInputSlice( 0, - new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{1})) + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{1})), + OutputChannelMode.LOCAL_STORAGE ) ), slicer.sliceStatic(new StageInputSpec(0), 2) @@ -96,11 +109,13 @@ public void test_sliceStatic_stageOneTwoSlices() ImmutableList.of( new StageInputSlice( 1, - new StripedReadablePartitions(1, 2, new IntAVLTreeSet(new int[]{0, 2})) + new StripedReadablePartitions(1, 2, new IntAVLTreeSet(new int[]{0, 2})), + OutputChannelMode.LOCAL_STORAGE ), new StageInputSlice( 1, - new StripedReadablePartitions(1, 2, new IntAVLTreeSet(new int[]{1, 3})) + new StripedReadablePartitions(1, 2, new IntAVLTreeSet(new int[]{1, 3})), + OutputChannelMode.LOCAL_STORAGE ) ), slicer.sliceStatic(new StageInputSpec(1), 2) @@ -115,6 +130,6 @@ public void test_sliceStatic_notAvailable() () -> slicer.sliceStatic(new StageInputSpec(3), 1) ); - MatcherAssert.assertThat(e.getMessage(), CoreMatchers.equalTo("Stage [3] not available")); + MatcherAssert.assertThat(e.getMessage(), CoreMatchers.equalTo("Stage[3] output partitions not available")); } }
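Worth noting from the two slicer tests above: a StageInputSlice now carries the upstream stage's OutputChannelMode alongside its partitions, so a consumer can tell from the slice alone whether the producing stage wrote its output to an in-memory channel or to storage. A minimal construction, following the serde test pattern above (the values are illustrative):

    // Slice of stage 2's output, striped across partitions, produced into an in-memory channel.
    final StageInputSlice slice = new StageInputSlice(
        2,                                    // producer stage number
        ReadablePartitions.striped(2, 3, 4),  // which partitions this slice covers
        OutputChannelMode.MEMORY              // how the producer wrote its output
    );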
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java index 634427d01a9b..a27ae7d97804 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/table/TableInputSpecSlicerTest.java @@ -19,15 +19,20 @@ package org.apache.druid.msq.input.table; +import com.google.common.collect.FluentIterable; import com.google.common.collect.ImmutableList; import org.apache.druid.data.input.StringTuple; +import org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction; +import org.apache.druid.indexing.common.actions.TaskAction; +import org.apache.druid.indexing.common.actions.TaskActionClient; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.msq.exec.SegmentSource; import org.apache.druid.msq.input.NilInputSlice; -import org.apache.druid.msq.querykit.DataSegmentTimelineView; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.testing.InitializedNullHandlingTest; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.SegmentTimeline; +import org.apache.druid.timeline.VersionedIntervalTimeline; import org.apache.druid.timeline.partition.DimensionRangeShardSpec; import org.apache.druid.timeline.partition.TombstoneShardSpec; import org.junit.Assert; @@ -35,7 +40,6 @@ import org.junit.Test; import java.util.Collections; -import java.util.Optional; public class TableInputSpecSlicerTest extends InitializedNullHandlingTest { @@ -94,19 +98,44 @@ public class TableInputSpecSlicerTest extends InitializedNullHandlingTest ); private SegmentTimeline timeline; private TableInputSpecSlicer slicer; + private TaskActionClient taskActionClient; @Before public void setUp() { timeline = SegmentTimeline.forSegments(ImmutableList.of(SEGMENT1, SEGMENT2, SEGMENT3)); - DataSegmentTimelineView timelineView = (dataSource, intervals) -> { - if (DATASOURCE.equals(dataSource)) { - return Optional.of(timeline); - } else { - return Optional.empty(); + taskActionClient = new TaskActionClient() + { + @Override + @SuppressWarnings("unchecked") + public <RetType> RetType submit(TaskAction<RetType> taskAction) + { + if (taskAction instanceof RetrieveUsedSegmentsAction) { + final RetrieveUsedSegmentsAction retrieveUsedSegmentsAction = (RetrieveUsedSegmentsAction) taskAction; + final String dataSource = retrieveUsedSegmentsAction.getDataSource(); + + if (DATASOURCE.equals(dataSource)) { + return (RetType) FluentIterable + .from(retrieveUsedSegmentsAction.getIntervals()) + .transformAndConcat( + interval -> + VersionedIntervalTimeline.getAllObjects(timeline.lookup(interval)) + ) + .toList(); + } else { + return (RetType) Collections.emptyList(); + } + } + + throw new UnsupportedOperationException(); + } }; - slicer = new TableInputSpecSlicer(timelineView); + + slicer = new TableInputSpecSlicer( + null /* not used for SegmentSource.NONE */, + taskActionClient, + SegmentSource.NONE + ); } @Test diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java index 8a5533d22cb0..857584127a9a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/QueryDefinitionTest.java @@ -33,6 +33,8 @@ import org.junit.Assert; import org.junit.Test; +import java.util.UUID; + public class QueryDefinitionTest { @Test @@ -40,7 +42,7 @@ public void testSerde() throws Exception { final QueryDefinition queryDef = QueryDefinition - .builder() + .builder(UUID.randomUUID().toString()) .add( StageDefinition .builder(0) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/BaseControllerQueryKernelTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/BaseControllerQueryKernelTest.java index 6ae18dda1e1d..2365b5cf86bc 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/BaseControllerQueryKernelTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/BaseControllerQueryKernelTest.java @@ -19,6 +19,7 @@ package
org.apache.druid.msq.kernel.controller; +import com.google.common.base.Function; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -28,6 +29,7 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; import org.apache.druid.msq.indexing.error.MSQFault; import org.apache.druid.msq.indexing.error.UnknownFault; import org.apache.druid.msq.input.InputSpecSlicerFactory; @@ -47,15 +49,28 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; -import java.util.stream.IntStream; public class BaseControllerQueryKernelTest extends InitializedNullHandlingTest { public static final UnknownFault RETRIABLE_FAULT = UnknownFault.forMessage(""); - public ControllerQueryKernelTester testControllerQueryKernel(int numWorkers) + public ControllerQueryKernelTester testControllerQueryKernel() { - return new ControllerQueryKernelTester(numWorkers); + return testControllerQueryKernel(ControllerQueryKernelConfig.Builder::build); + } + + public ControllerQueryKernelTester testControllerQueryKernel( + final Function<ControllerQueryKernelConfig.Builder, ControllerQueryKernelConfig> configFn + ) + { + return new ControllerQueryKernelTester( + configFn.apply( + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(100_000_000) + .destination(TaskReportMSQDestination.instance()) + ) + ); } /** @@ -69,34 +84,29 @@ public static class ControllerQueryKernelTester private boolean initialized = false; private QueryDefinition queryDefinition = null; private ControllerQueryKernel controllerQueryKernel = null; - private InputSpecSlicerFactory inputSlicerFactory = - stagePartitionsMap -> + private final InputSpecSlicerFactory inputSlicerFactory = + (stagePartitionsMap, stageOutputChannelModeMap) -> new MapInputSpecSlicer( ImmutableMap.of( - StageInputSpec.class, new StageInputSpecSlicer(stagePartitionsMap), + StageInputSpec.class, new StageInputSpecSlicer(stagePartitionsMap, stageOutputChannelModeMap), ControllerTestInputSpec.class, new ControllerTestInputSpecSlicer() ) ); - private final int numWorkers; + private final ControllerQueryKernelConfig config; Set<Integer> setupStages = new HashSet<>(); - private ControllerQueryKernelTester(int numWorkers) + private ControllerQueryKernelTester(ControllerQueryKernelConfig config) { - this.numWorkers = numWorkers; + this.config = config; } public ControllerQueryKernelTester queryDefinition(QueryDefinition queryDefinition) { this.queryDefinition = Preconditions.checkNotNull(queryDefinition); - this.controllerQueryKernel = new ControllerQueryKernel( - queryDefinition, - 100_000_000, - true - ); + this.controllerQueryKernel = new ControllerQueryKernel(queryDefinition, config); return this; } - public ControllerQueryKernelTester setupStage( int stageNumber, ControllerStagePhase controllerStagePhase @@ -275,11 +285,17 @@ public void startWorkOrder(int stageNumber) { StageId stageId = new StageId(queryDefinition.getQueryId(), stageNumber); Preconditions.checkArgument(initialized); - IntStream.range(0, queryDefinition.getStageDefinition(stageId).getMaxWorkerCount()) - .forEach(n -> controllerQueryKernel.workOrdersSentForWorker(stageId, n)); - + controllerQueryKernel.getWorkerInputsForStage(stageId).workers() + .forEach(n -> controllerQueryKernel.workOrdersSentForWorker(stageId, n)); } + public void doneReadingInput(int stageNumber) + { + StageId stageId = new
StageId(queryDefinition.getQueryId(), stageNumber); + Preconditions.checkArgument(initialized); + controllerQueryKernel.getWorkerInputsForStage(stageId).workers() + .forEach(n -> controllerQueryKernel.setDoneReadingInputForStageAndWorker(stageId, n)); + } public void finishStage(int stageNumber) { @@ -353,22 +369,27 @@ public void failStage(int stageNumber) public void assertStagePhase(int stageNumber, ControllerStagePhase expectedControllerStagePhase) { Preconditions.checkArgument(initialized); - ControllerStageTracker controllerStageTracker = Preconditions.checkNotNull( - controllerQueryKernel.getControllerStageKernel(stageNumber), + ControllerStageTracker controllerStageKernel = Preconditions.checkNotNull( + controllerQueryKernel.getControllerStageTracker(stageNumber), StringUtils.format("Stage kernel for stage number %d is not initialized yet", stageNumber) ); - if (controllerStageTracker.getPhase() != expectedControllerStagePhase) { + if (controllerStageKernel.getPhase() != expectedControllerStagePhase) { throw new ISE( StringUtils.format( "Stage kernel for stage number %d is in %s phase which is different from the expected phase %s", stageNumber, - controllerStageTracker.getPhase(), + controllerStageKernel.getPhase(), expectedControllerStagePhase ) ); } } + public ControllerQueryKernelConfig getConfig() + { + return config; + } + /** * Checks if the state of the BaseControllerQueryKernel is initialized properly. Currently, this is just stubbed to * return true irrespective of the actual state diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelTest.java index 8e47c470bf82..03f963b133bf 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelTest.java @@ -19,9 +19,13 @@ package org.apache.druid.msq.kernel.controller; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.ShuffleKind; import org.apache.druid.msq.kernel.worker.WorkerStagePhase; import org.junit.Assert; import org.junit.Test; @@ -34,7 +38,7 @@ public class ControllerQueryKernelTest extends BaseControllerQueryKernelTest @Test public void testCompleteDAGExecutionForSingleWorker() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 1 // | / | // 2 / 3 @@ -44,13 +48,13 @@ public void testCompleteDAGExecutionForSingleWorker() // 6 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(7) - .addVertex(0, 2) - .addVertex(1, 2) - .addVertex(1, 3) - .addVertex(2, 4) - .addVertex(3, 5) - .addVertex(4, 6) - .addVertex(5, 6) + .addEdge(0, 2) + .addEdge(1, 2) + .addEdge(1, 3) + .addEdge(2, 4) + .addEdge(3, 5) + .addEdge(4, 6) + .addEdge(5, 6) .getQueryDefinitionBuilder() .build() ); @@ -62,79 +66,196 @@ public void testCompleteDAGExecutionForSingleWorker() newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); 
effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(0, 1), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(0), newStageNumbers); Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + // Mark 0 as done. Next up will be 1. + transitionNewToResultsComplete(controllerQueryKernelTester, 0); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(1), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + // Mark 1 as done and fetch the new kernels. Next up will be 2. transitionNewToResultsComplete(controllerQueryKernelTester, 1); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); - Assert.assertEquals(ImmutableSet.of(0, 3), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(2), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + // Mark 2 as done and fetch the new kernels. Next up will be 3. + transitionNewToResultsComplete(controllerQueryKernelTester, 2); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(3), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0), effectivelyFinishedStageNumbers); - // Mark 3 as done and fetch the new kernels. 5 should be unblocked along with 0. + // Mark 3 as done and fetch the new kernels. Next up will be 4. transitionNewToResultsComplete(controllerQueryKernelTester, 3); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); - Assert.assertEquals(ImmutableSet.of(0, 5), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(4), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + Assert.assertEquals(ImmutableSet.of(0, 1), effectivelyFinishedStageNumbers); + // Mark 4 as done and fetch new kernels. Next up will be 5. + transitionNewToResultsComplete(controllerQueryKernelTester, 4); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(5), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0, 1, 2), effectivelyFinishedStageNumbers); + + // Mark 0, 1, 2 finished together. + effectivelyFinishedStageNumbers.forEach(controllerQueryKernelTester::finishStage); - // Mark 5 as done and fetch the new kernels. Only 0 is still unblocked, but 3 can now be cleaned + // Mark 5 as done and fetch new kernels. Next up will be 6, and 3 will be ready to finish. 
transitionNewToResultsComplete(controllerQueryKernelTester, 5); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); - Assert.assertEquals(ImmutableSet.of(0), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(6), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); Assert.assertEquals(ImmutableSet.of(3), effectivelyFinishedStageNumbers); - // Mark 0 as done and fetch the new kernels. This should unblock 2 + // Mark 6 as done. No more kernels left, but we can clean up 4, 5, 6 along with 3. + transitionNewToResultsComplete(controllerQueryKernelTester, 6); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(3, 4, 5, 6), effectivelyFinishedStageNumbers); + effectivelyFinishedStageNumbers.forEach(controllerQueryKernelTester::finishStage); + } + + @Test + public void testCompleteDAGExecutionForSingleWorkerWithPipelining() + { + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel( + configBuilder -> + configBuilder.maxConcurrentStages(2).pipeline(true).build() + ); + // 0 [HLS] 1 [HLS] + // | / | + // 2 [none] 3 [HLS] + // | | + // 4 [mix] 5 [HLS] + // \ / + // \ / + // 6 [none] + + final QueryDefinition queryDef = new MockQueryDefinitionBuilder(7) + .addEdge(0, 2) + .addEdge(1, 2) + .addEdge(1, 3) + .addEdge(2, 4) + .addEdge(3, 5) + .addEdge(4, 6) + .addEdge(5, 6) + .defineStage(0, ShuffleKind.HASH_LOCAL_SORT) + .defineStage(1, ShuffleKind.HASH_LOCAL_SORT) + .defineStage(3, ShuffleKind.HASH_LOCAL_SORT) + .defineStage(4, ShuffleKind.MIX) + .defineStage(5, ShuffleKind.HASH_LOCAL_SORT) + .getQueryDefinitionBuilder() + .build(); + + controllerQueryKernelTester.queryDefinition(queryDef); + controllerQueryKernelTester.init(); + + Assert.assertEquals( + ImmutableList.of( + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2, 4), + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 3), + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 5), + ControllerQueryKernelUtilsTest.makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 6) + ), + ControllerQueryKernelUtils.computeStageGroups(queryDef, controllerQueryKernelTester.getConfig()) + ); + + Set<Integer> newStageNumbers; + Set<Integer> effectivelyFinishedStageNumbers; + + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0, 1), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + + + transitionNewToResultsComplete(controllerQueryKernelTester, 1); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(),
effectivelyFinishedStageNumbers); + + + // Mark 0 as done and fetch the new kernels. 2 should be unblocked along with 4. transitionNewToResultsComplete(controllerQueryKernelTester, 0); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); - Assert.assertEquals(ImmutableSet.of(2), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(2, 4), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(3), effectivelyFinishedStageNumbers); + Assert.assertEquals(ImmutableSet.of(), effectivelyFinishedStageNumbers); + - // Mark 2 as done and fetch new kernels. This should clear up 0 and 1 alongside 3 (which is not marked as FINISHED yet) + // Mark 2 as done and fetch the new kernels. 4 is still ready, 0 can now be cleaned, and 3 can be launched transitionNewToResultsComplete(controllerQueryKernelTester, 2); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); - Assert.assertEquals(ImmutableSet.of(4), newStageNumbers); + Assert.assertEquals(ImmutableSet.of(3, 4), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0), effectivelyFinishedStageNumbers); + + // Mark 4 as done and fetch the new kernels. 3 is still ready, and 2 becomes cleanable + transitionNewToResultsComplete(controllerQueryKernelTester, 4); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(3), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(0, 1, 3), effectivelyFinishedStageNumbers); + Assert.assertEquals(ImmutableSet.of(0, 2), effectivelyFinishedStageNumbers); - // Mark 0, 1, 3 finished together + // Mark 3 as post-reading and fetch new kernels. This makes 1 cleanable, and 5 ready to run + transitionNewToDoneReadingInput(controllerQueryKernelTester, 3); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(5), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(0, 1, 2), effectivelyFinishedStageNumbers); + + // Mark 0, 1, 2 finished together effectivelyFinishedStageNumbers.forEach(controllerQueryKernelTester::finishStage); - // Mark 4 as done and fetch new kernels. This should unblock 6 and clear up 2 - transitionNewToResultsComplete(controllerQueryKernelTester, 4); + // Mark 5 as post-reading and fetch new kernels. Nothing is ready, since 6 is waiting for 5 to finish + // However, this does clear up 3 to become cleanable + transitionNewToDoneReadingInput(controllerQueryKernelTester, 5); + newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); + Assert.assertEquals(ImmutableSet.of(), newStageNumbers); + effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); + Assert.assertEquals(ImmutableSet.of(3), effectivelyFinishedStageNumbers); + + // Mark 5 as done. 
This makes 6 ready to go + transitionDoneReadingInputToResultsComplete(controllerQueryKernelTester, 5); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); Assert.assertEquals(ImmutableSet.of(6), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(2), effectivelyFinishedStageNumbers); + Assert.assertEquals(ImmutableSet.of(3), effectivelyFinishedStageNumbers); - // Mark 6 as done. No more kernels left, but we can clean up 4 and 5 alongwith 2 + // Mark 6 as done. No more kernels left, but we can clean up 4, 5, 6 along with 3 transitionNewToResultsComplete(controllerQueryKernelTester, 6); newStageNumbers = controllerQueryKernelTester.createAndGetNewStageNumbers(); Assert.assertEquals(ImmutableSet.of(), newStageNumbers); effectivelyFinishedStageNumbers = controllerQueryKernelTester.getEffectivelyFinishedStageNumbers(); - Assert.assertEquals(ImmutableSet.of(2, 4, 5), effectivelyFinishedStageNumbers); + Assert.assertEquals(ImmutableSet.of(3, 4, 5, 6), effectivelyFinishedStageNumbers); effectivelyFinishedStageNumbers.forEach(controllerQueryKernelTester::finishStage); } @Test public void testCompleteDAGExecutionForMultipleWorkers() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(2); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 -> 1 -> 2 -> 3 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(4) - .addVertex(0, 1) - .addVertex(1, 2) - .addVertex(2, 3) - .defineStage(0, true, 1) // Ingestion only on one worker - .defineStage(1, true, 2) - .defineStage(3, true, 2) + .addEdge(0, 1) + .addEdge(1, 2) + .addEdge(2, 3) + .defineStage(0, ShuffleKind.GLOBAL_SORT, 1) // Ingestion only on one worker + .defineStage(1, ShuffleKind.GLOBAL_SORT, 2) + .defineStage(3, ShuffleKind.GLOBAL_SORT, 2) .getQueryDefinitionBuilder() .build() ); @@ -233,12 +354,12 @@ public void testCompleteDAGExecutionForMultipleWorkers() @Test public void testTransitionsInShufflingStagesAndMultipleWorkers() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(2); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // Single stage query definition controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(1) - .defineStage(0, true, 2) + .defineStage(0, ShuffleKind.GLOBAL_SORT, 2) .getQueryDefinitionBuilder() .build() ); @@ -275,12 +396,12 @@ public void testTransitionsInShufflingStagesAndMultipleWorkers() @Test public void testPrematureResultsComplete() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(2); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // Single stage query definition controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(1) - .defineStage(0, true, 2) + .defineStage(0, ShuffleKind.GLOBAL_SORT, 2) .getQueryDefinitionBuilder() .build() ); @@ -311,15 +432,18 @@ public void testPrematureResultsComplete() @Test public void testKernelFailed() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel( + configBuilder -> + configBuilder.maxConcurrentStages(2).build() + ); // 0 1 // \ / // 2 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(3) - .addVertex(0, 2) -
.addVertex(1, 2) + .addEdge(0, 2) + .addEdge(1, 2) .getQueryDefinitionBuilder() .build() ); @@ -340,16 +464,16 @@ public void testKernelFailed() @Test(expected = IllegalStateException.class) public void testCycleInvalidQueryThrowsException() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 - 1 // \ / // 2 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(3) - .addVertex(0, 1) - .addVertex(1, 2) - .addVertex(2, 0) + .addEdge(0, 1) + .addEdge(1, 2) + .addEdge(2, 0) .getQueryDefinitionBuilder() .build() ); @@ -358,13 +482,13 @@ public void testCycleInvalidQueryThrowsException() @Test(expected = IllegalStateException.class) public void testSelfLoopInvalidQueryThrowsException() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 _ // |__| controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(1) - .addVertex(0, 0) + .addEdge(0, 0) .getQueryDefinitionBuilder() .build() ); @@ -373,15 +497,15 @@ public void testSelfLoopInvalidQueryThrowsException() @Test(expected = IllegalStateException.class) public void testLoopInvalidQueryThrowsException() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 - 1 // | | // --- controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(2) - .addVertex(0, 1) - .addVertex(1, 0) + .addEdge(0, 1) + .addEdge(1, 0) .getQueryDefinitionBuilder() .build() ); @@ -390,15 +514,15 @@ public void testLoopInvalidQueryThrowsException() @Test public void testMarkSuccessfulTerminalStagesAsFinished() { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(1); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(); // 0 1 // \ / // 2 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(3) - .addVertex(0, 2) - .addVertex(1, 2) + .addEdge(0, 2) + .addEdge(1, 2) .getQueryDefinitionBuilder() .build() ); @@ -409,8 +533,8 @@ public void testMarkSuccessfulTerminalStagesAsFinished() controllerQueryKernelTester.init(); - Assert.assertTrue(controllerQueryKernelTester.isDone()); - Assert.assertTrue(controllerQueryKernelTester.isSuccess()); + Assert.assertFalse(controllerQueryKernelTester.isDone()); + Assert.assertFalse(controllerQueryKernelTester.isSuccess()); controllerQueryKernelTester.assertStagePhase(0, ControllerStagePhase.FINISHED); controllerQueryKernelTester.assertStagePhase(1, ControllerStagePhase.RESULTS_READY); @@ -430,4 +554,18 @@ private static void transitionNewToResultsComplete(ControllerQueryKernelTester q queryKernelTester.setResultsCompleteForStageAndWorkers(stageNumber, 0); } + private static void transitionNewToDoneReadingInput(ControllerQueryKernelTester queryKernelTester, int stageNumber) + { + queryKernelTester.startStage(stageNumber); + queryKernelTester.startWorkOrder(stageNumber); + queryKernelTester.doneReadingInput(stageNumber); + } + + private static void transitionDoneReadingInputToResultsComplete( + ControllerQueryKernelTester queryKernelTester, + int stageNumber + ) + { + queryKernelTester.setResultsCompleteForStageAndWorkers(stageNumber, 0); + } } diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtilsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtilsTest.java new file mode 100644 index 000000000000..b6bb5bb3d4e0 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ControllerQueryKernelUtilsTest.java @@ -0,0 +1,551 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.kernel.controller; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; +import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.ShuffleKind; +import org.apache.druid.msq.kernel.StageId; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; +import java.util.stream.Collectors; + +public class ControllerQueryKernelUtilsTest +{ + @Test + public void test_computeStageGroups_multiPronged() + { + final QueryDefinition queryDef = makeMultiProngedQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 4), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 5), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 6) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(1) + .pipeline(false) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_multiPronged_pipeline_twoAtOnce() + { + final QueryDefinition queryDef = makeMultiProngedQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2, 4), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3, 5), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 6) + ), + 
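In these expected groups, each StageGroup pairs the stages that run as one unit with the OutputChannelMode the group writes on its way out; under maxConcurrentStages(2) with pipelining, consecutive stages such as 2 and 4 (or 3 and 5) fuse into a single group. A toy rendition of the fusing idea for a purely linear, non-sort chain, far simpler than the real ControllerQueryKernelUtils#computeStageGroups and using a hypothetical helper name:

    import java.util.ArrayList;
    import java.util.List;

    // Fuse a linear chain 0 -> 1 -> ... -> (numStages - 1) into groups of at
    // most maxConcurrentStages stages; stages inside a group are assumed to
    // pipe to each other in memory.
    static List<List<Integer>> groupLinearChain(final int numStages, final int maxConcurrentStages)
    {
      final List<List<Integer>> groups = new ArrayList<>();
      List<Integer> current = new ArrayList<>();
      for (int stage = 0; stage < numStages; stage++) {
        current.add(stage);
        if (current.size() == maxConcurrentStages) {
          groups.add(current);
          current = new ArrayList<>();
        }
      }
      if (!current.isEmpty()) {
        groups.add(current);
      }
      return groups;
    }

Under these assumptions, groupLinearChain(4, 2) yields [[0, 1], [2, 3]], the same shape the linearWithoutShuffle_pipeline_twoAtOnce case below expects; sort-based chains leapfrog instead, which this toy deliberately ignores.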
ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(true) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithoutShuffle() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithoutShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(1) + .pipeline(false) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithShuffle() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(1) + .pipeline(false) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithoutShuffle_faultTolerant() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithoutShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + .faultTolerance(true) + .durableStorage(true) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithShuffle_faultTolerant() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + 
.faultTolerance(true) + .durableStorage(true) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithoutShuffle_faultTolerant_durableResults() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithoutShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + .faultTolerance(true) + .durableStorage(true) + .destination(DurableStorageMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithShuffle_faultTolerant_durableResults() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithShuffle(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + .faultTolerance(true) + .durableStorage(true) + .destination(DurableStorageMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithoutShuffle_pipeline_twoAtOnce() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithoutShuffle(); + + Assert.assertEquals( + // Without a sort-based shuffle, we can't leapfrog, so we launch two groups broken up by LOCAL_STORAGE + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 2, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(true) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_linearWithShuffle_pipeline_twoAtOnce() + { + final QueryDefinition queryDef = makeLinearQueryDefinitionWithShuffle(); + + Assert.assertEquals( + // With sort-based shuffle, we can leapfrog 4 stages, all of them being in-memory + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(true) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) 
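Taken together, the fault-tolerant and pipelined cases here pin down a small decision rule for output modes. A hedged restatement (hypothetical helper; the real choice is made inside ControllerQueryKernelUtils and weighs more inputs than shown):

    // MEMORY when producer and consumer can run concurrently; durable flavors
    // under fault tolerance (QUERY_RESULTS only for a final stage whose
    // destination is itself durable); otherwise plain local storage.
    static OutputChannelMode pickMode(
        final boolean canPipeline,
        final boolean faultTolerant,
        final boolean finalStage,
        final boolean durableDestination
    )
    {
      if (canPipeline) {
        return OutputChannelMode.MEMORY;
      } else if (faultTolerant) {
        return finalStage && durableDestination
               ? OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS
               : OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE;
      } else {
        return OutputChannelMode.LOCAL_STORAGE;
      }
    }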
+ .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanIn() + { + final QueryDefinition queryDef = makeFanInQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(1) + .pipeline(false) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanInWithBroadcast() + { + final QueryDefinition queryDef = makeFanInQueryDefinitionWithBroadcast(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(1) + .pipeline(false) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanIn_faultTolerant() + { + final QueryDefinition queryDef = makeFanInQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + .faultTolerance(true) + .durableStorage(true) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanIn_faultTolerant_durableResults() + { + final QueryDefinition queryDef = makeFanInQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(false) + .faultTolerance(true) + .durableStorage(true) + .destination(DurableStorageMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanIn_pipeline_twoAtOnce() + { + final QueryDefinition queryDef = makeFanInQueryDefinition(); + + Assert.assertEquals( + ImmutableList.of( + 
makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 2, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(true) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + @Test + public void test_computeStageGroups_fanInWithBroadcast_pipeline_twoAtOnce() + { + final QueryDefinition queryDef = makeFanInQueryDefinitionWithBroadcast(); + + Assert.assertEquals( + // Output of stage 1 is broadcast, so it must run first; then stages 0 and 2 may be launched together + ImmutableList.of( + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 1), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.LOCAL_STORAGE, 0, 2), + makeStageGroup(queryDef.getQueryId(), OutputChannelMode.MEMORY, 3) + ), + ControllerQueryKernelUtils.computeStageGroups( + queryDef, + ControllerQueryKernelConfig + .builder() + .maxRetainedPartitionSketchBytes(1) + .maxConcurrentStages(2) + .pipeline(true) + .faultTolerance(false) + .destination(TaskReportMSQDestination.instance()) + .build() + ) + ); + } + + private static QueryDefinition makeLinearQueryDefinitionWithShuffle() + { + // 0 -> 1 -> 2 -> 3 + + return new MockQueryDefinitionBuilder(4) + .addEdge(0, 1) + .addEdge(1, 2) + .addEdge(2, 3) + .defineStage(0, ShuffleKind.GLOBAL_SORT) + .defineStage(1, ShuffleKind.GLOBAL_SORT) + .defineStage(2, ShuffleKind.GLOBAL_SORT) + .defineStage(3, ShuffleKind.GLOBAL_SORT) + .getQueryDefinitionBuilder() + .build(); + } + + private static QueryDefinition makeLinearQueryDefinitionWithoutShuffle() + { + // 0 -> 1 -> 2 -> 3 + + return new MockQueryDefinitionBuilder(4) + .addEdge(0, 1) + .addEdge(1, 2) + .addEdge(2, 3) + .getQueryDefinitionBuilder() + .build(); + } + + private static QueryDefinition makeFanInQueryDefinition() + { + // 0 -> 2 -> 3 + // / + // 1 + + return new MockQueryDefinitionBuilder(4) + .addEdge(0, 2) + .addEdge(1, 2) + .addEdge(2, 3) + .getQueryDefinitionBuilder() + .build(); + } + + private static QueryDefinition makeFanInQueryDefinitionWithBroadcast() + { + // 0 -> 2 -> 3 + // / < broadcast + // 1 + + return new MockQueryDefinitionBuilder(4) + .addEdge(0, 2) + .addEdge(1, 2, true) + .addEdge(2, 3) + .getQueryDefinitionBuilder() + .build(); + } + + private static QueryDefinition makeMultiProngedQueryDefinition() + { + // 0 1 + // | / | + // 2 / 3 + // | | + // 4 5 + // \ / + // 6 + + return new MockQueryDefinitionBuilder(7) + .addEdge(0, 2) + .addEdge(1, 2) + .addEdge(1, 3) + .addEdge(2, 4) + .addEdge(3, 5) + .addEdge(4, 6) + .addEdge(5, 6) + .getQueryDefinitionBuilder() + .build(); + } + + public static StageGroup makeStageGroup( + final String queryId, + final OutputChannelMode outputChannelMode, + final int... 
stageNumbers + ) + { + return new StageGroup( + Arrays.stream(stageNumbers).mapToObj(n -> new StageId(queryId, n)).collect(Collectors.toList()), + outputChannelMode + ); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java index f16e35e6e283..6ac399f70e4a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/MockQueryDefinitionBuilder.java @@ -21,6 +21,9 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntBooleanPair; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.KeyColumn; import org.apache.druid.frame.key.KeyOrder; @@ -30,21 +33,26 @@ import org.apache.druid.msq.input.stage.StageInputSpec; import org.apache.druid.msq.kernel.FrameProcessorFactory; import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec; +import org.apache.druid.msq.kernel.HashShuffleSpec; +import org.apache.druid.msq.kernel.MixShuffleSpec; import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.kernel.QueryDefinitionBuilder; +import org.apache.druid.msq.kernel.ShuffleKind; import org.apache.druid.msq.kernel.ShuffleSpec; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import org.mockito.Mockito; +import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; +import java.util.UUID; public class MockQueryDefinitionBuilder { @@ -56,14 +64,14 @@ public class MockQueryDefinitionBuilder private final int numStages; // Maps a stage to all the other stages on which it has dependency, i.e. for an edge like A -> B, the adjacency list - // would have an entry like B : [ A, ... ] - private final Map<Integer, Set<Integer>> adjacencyList = new HashMap<>(); + // would have an entry like B : [ <A, broadcast>, ... ]
+ private final Map<Integer, Set<IntBooleanPair>> adjacencyList = new HashMap<>(); // Keeps a collection of those stages that have been already defined private final Set<Integer> definedStages = new HashSet<>(); // Query definition builder corresponding to this mock builder - private final QueryDefinitionBuilder queryDefinitionBuilder = QueryDefinition.builder(); + private final QueryDefinitionBuilder queryDefinitionBuilder = QueryDefinition.builder(UUID.randomUUID().toString()); public MockQueryDefinitionBuilder(final int numStages) @@ -71,35 +79,40 @@ public MockQueryDefinitionBuilder(final int numStages) this.numStages = numStages; } - public MockQueryDefinitionBuilder addVertex(final int outEdge, final int inEdge) + public MockQueryDefinitionBuilder addEdge(final int outVertex, final int inVertex) + { + return addEdge(outVertex, inVertex, false); + } + + public MockQueryDefinitionBuilder addEdge(final int outVertex, final int inVertex, final boolean broadcast) { Preconditions.checkArgument( - outEdge < numStages, + outVertex < numStages, "vertex number can only be from 0 to one less than the total number of stages" ); Preconditions.checkArgument( - inEdge < numStages, + inVertex < numStages, "vertex number can only be from 0 to one less than the total number of stages" ); Preconditions.checkArgument( - !definedStages.contains(inEdge), - StringUtils.format("%s is already defined, cannot create more connections from it", inEdge) + !definedStages.contains(inVertex), + StringUtils.format("%s is already defined, cannot create more connections from it", inVertex) ); Preconditions.checkArgument( - !definedStages.contains(outEdge), - StringUtils.format("%s is already defined, cannot create more connections to it", outEdge) + !definedStages.contains(outVertex), + StringUtils.format("%s is already defined, cannot create more connections to it", outVertex) ); - adjacencyList.computeIfAbsent(inEdge, k -> new HashSet<>()).add(outEdge); + adjacencyList.computeIfAbsent(inVertex, k -> new HashSet<>()).add(IntBooleanPair.of(outVertex, broadcast)); return this; } public MockQueryDefinitionBuilder defineStage( int stageNumber, - boolean shuffling, + @Nullable ShuffleKind shuffleKind, int maxWorkers ) { @@ -113,27 +126,60 @@ public MockQueryDefinitionBuilder defineStage( ); definedStages.add(stageNumber); - ShuffleSpec shuffleSpec; + ShuffleSpec shuffleSpec = null; - if (shuffling) { - shuffleSpec = new GlobalSortMaxCountShuffleSpec( - new ClusterBy( - ImmutableList.of( - new KeyColumn(SHUFFLE_KEY_COLUMN, KeyOrder.ASCENDING) + if (shuffleKind != null) { + switch (shuffleKind) { + case GLOBAL_SORT: + shuffleSpec = new GlobalSortMaxCountShuffleSpec( + new ClusterBy( + ImmutableList.of( + new KeyColumn(SHUFFLE_KEY_COLUMN, KeyOrder.ASCENDING) + ), + 0 ), - 0 - ), - MAX_NUM_PARTITIONS, - false - ); - } else { - shuffleSpec = null; + MAX_NUM_PARTITIONS, + false + ); + break; + + case HASH_LOCAL_SORT: + case HASH: + shuffleSpec = new HashShuffleSpec( + new ClusterBy( + ImmutableList.of( + new KeyColumn( + SHUFFLE_KEY_COLUMN, + shuffleKind == ShuffleKind.HASH ?
KeyOrder.NONE : KeyOrder.ASCENDING + ) + ), + 0 + ), + MAX_NUM_PARTITIONS + ); + break; + + case MIX: + shuffleSpec = MixShuffleSpec.instance(); + break; + } + + if (shuffleSpec == null || shuffleKind != shuffleSpec.kind()) { + throw new ISE("Oops, created an incorrect shuffleSpec[%s] for kind[%s]", shuffleSpec, shuffleKind); + } } - final List inputSpecs = - adjacencyList.getOrDefault(stageNumber, new HashSet<>()) - .stream() - .map(StageInputSpec::new).collect(Collectors.toList()); + final List inputSpecs = new ArrayList<>(); + final IntSet broadcastInputNumbers = new IntOpenHashSet(); + + int inputNumber = 0; + for (final IntBooleanPair pair : adjacencyList.getOrDefault(stageNumber, Collections.emptySet())) { + inputSpecs.add(new StageInputSpec(pair.leftInt())); + if (pair.rightBoolean()) { + broadcastInputNumbers.add(inputNumber); + } + inputNumber++; + } if (inputSpecs.isEmpty()) { for (int i = 0; i < maxWorkers; i++) { @@ -144,6 +190,7 @@ public MockQueryDefinitionBuilder defineStage( queryDefinitionBuilder.add( StageDefinition.builder(stageNumber) .inputs(inputSpecs) + .broadcastInputs(broadcastInputNumbers) .processorFactory(Mockito.mock(FrameProcessorFactory.class)) .shuffleSpec(shuffleSpec) .signature(RowSignature.builder().add(SHUFFLE_KEY_COLUMN, ColumnType.STRING).build()) @@ -153,14 +200,14 @@ public MockQueryDefinitionBuilder defineStage( return this; } - public MockQueryDefinitionBuilder defineStage(int stageNumber, boolean shuffling) + public MockQueryDefinitionBuilder defineStage(int stageNumber, @Nullable ShuffleKind shuffleKind) { - return defineStage(stageNumber, shuffling, 1); + return defineStage(stageNumber, shuffleKind, 1); } public MockQueryDefinitionBuilder defineStage(int stageNumber) { - return defineStage(stageNumber, false); + return defineStage(stageNumber, null); } public QueryDefinitionBuilder getQueryDefinitionBuilder() @@ -205,8 +252,8 @@ private boolean checkAcyclic(int node, Map visited) return false; } else { visited.put(node, StageState.VISITING); - for (int neighbour : adjacencyList.getOrDefault(node, Collections.emptySet())) { - if (!checkAcyclic(neighbour, visited)) { + for (IntBooleanPair neighbour : adjacencyList.getOrDefault(node, Collections.emptySet())) { + if (!checkAcyclic(neighbour.leftInt(), visited)) { return false; } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/NonShufflingWorkersWithRetryKernelTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/NonShufflingWorkersWithRetryKernelTest.java index d408b47da8e1..fb5af9e7c4f6 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/NonShufflingWorkersWithRetryKernelTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/NonShufflingWorkersWithRetryKernelTest.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.kernel.controller; +import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; import org.junit.Assert; import org.junit.Test; @@ -318,13 +319,20 @@ public void testMultipleWorkersFailedBeforeAllResultsRecieved() @Nonnull private ControllerQueryKernelTester getSimpleQueryDefinition(int numWorkers) { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(numWorkers); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel( + configBuilder -> + configBuilder + .destination(DurableStorageMSQDestination.instance()) + 
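With the builder extensions above, a test can express a broadcast topology in a few lines. A usage sketch (the three-stage topology is invented; addEdge, defineStage, and ShuffleKind come from this patch, and stages left undefined are assumed to receive default definitions when the builder finalizes):

    final QueryDefinition queryDef = new MockQueryDefinitionBuilder(3)
        .addEdge(0, 2)                    // stage 0 feeds stage 2
        .addEdge(1, 2, true)              // stage 1 is broadcast into stage 2
        .defineStage(2, ShuffleKind.HASH) // stage 2 hash-partitions its output
        .getQueryDefinitionBuilder()
        .build();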
.durableStorage(true) + .faultTolerance(true) + .build() + ); // 0 -> 1 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(2) - .addVertex(0, 1) - .defineStage(0, false, numWorkers) - .defineStage(1, false, numWorkers) + .addEdge(0, 1) + .defineStage(0, null, numWorkers) + .defineStage(1, null, numWorkers) .getQueryDefinitionBuilder() .build() ); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ShufflingWorkersWithRetryKernelTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ShufflingWorkersWithRetryKernelTest.java index 824c23b4fb11..81addb183ba3 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ShufflingWorkersWithRetryKernelTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/ShufflingWorkersWithRetryKernelTest.java @@ -19,6 +19,8 @@ package org.apache.druid.msq.kernel.controller; +import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; +import org.apache.druid.msq.kernel.ShuffleKind; import org.junit.Assert; import org.junit.Test; @@ -1071,13 +1073,20 @@ public void testMultipleWorkersFailedBeforeAllResultsReceived() @Nonnull private ControllerQueryKernelTester getSimpleQueryDefinition(int numWorkers) { - ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel(numWorkers); + ControllerQueryKernelTester controllerQueryKernelTester = testControllerQueryKernel( + configBuilder -> + configBuilder + .destination(DurableStorageMSQDestination.instance()) + .durableStorage(true) + .faultTolerance(true) + .build() + ); // 0 -> 1 controllerQueryKernelTester.queryDefinition( new MockQueryDefinitionBuilder(2) - .addVertex(0, 1) - .defineStage(0, true, numWorkers) - .defineStage(1, true, numWorkers) + .addEdge(0, 1) + .defineStage(0, ShuffleKind.GLOBAL_SORT, numWorkers) + .defineStage(1, ShuffleKind.GLOBAL_SORT, numWorkers) .getQueryDefinitionBuilder() .build() ); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java index 00ccfdee6c19..605e0bf2de74 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java @@ -29,6 +29,7 @@ import it.unimi.dsi.fastutil.longs.LongList; import nl.jqno.equalsverifier.EqualsVerifier; import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.msq.input.InputSpec; import org.apache.druid.msq.input.InputSpecSlicer; @@ -238,7 +239,8 @@ public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_fourWorkerMa stageDef, new Int2IntAVLTreeMap(ImmutableMap.of(0, 2)), new StageInputSpecSlicer( - new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, ReadablePartitions.striped(0, 2, 3))) + new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, ReadablePartitions.striped(0, 2, 3))), + new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)) ), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER @@ -251,7 +253,8 @@ public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_fourWorkerMa Collections.singletonList( new 
StageInputSlice( 0, - new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0, 2})) + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0, 2})), + OutputChannelMode.LOCAL_STORAGE ) ) ) @@ -260,7 +263,8 @@ public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_fourWorkerMa Collections.singletonList( new StageInputSlice( 0, - new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{1})) + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{1})), + OutputChannelMode.LOCAL_STORAGE ) ) ) @@ -283,7 +287,8 @@ public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_oneWorkerMax stageDef, new Int2IntAVLTreeMap(ImmutableMap.of(0, 2)), new StageInputSpecSlicer( - new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, ReadablePartitions.striped(0, 2, 3))) + new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, ReadablePartitions.striped(0, 2, 3))), + new Int2ObjectAVLTreeMap<>(ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE)) ), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER @@ -296,7 +301,8 @@ public void test_auto_oneInputStageWithThreePartitionsAndTwoWorkers_oneWorkerMax Collections.singletonList( new StageInputSlice( 0, - new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0, 1, 2})) + new StripedReadablePartitions(0, 2, new IntAVLTreeSet(new int[]{0, 1, 2})), + OutputChannelMode.LOCAL_STORAGE ) ) ) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java index b80e59223f78..2da5fd42caf1 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java @@ -37,8 +37,6 @@ import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; -import org.apache.druid.java.util.common.guava.Sequences; -import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.http.client.response.StringFullResponseHolder; import org.apache.druid.msq.counters.ChannelCounters; import org.apache.druid.msq.counters.CounterSnapshots; @@ -247,8 +245,8 @@ public class SqlStatementResourceTest extends MSQTestBase ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(0, 1), - ImmutableMap.of(0, 1) - + ImmutableMap.of(0, 1), + ImmutableMap.of() ), CounterSnapshotsTree.fromMap(ImmutableMap.of( 0, @@ -287,9 +285,7 @@ public class SqlStatementResourceTest extends MSQTestBase SqlTypeName.VARCHAR, SqlTypeName.VARCHAR ), - Yielders.each( - Sequences.simple( - RESULT_ROWS)), + RESULT_ROWS, null ) ) @@ -315,6 +311,7 @@ public class SqlStatementResourceTest extends MSQTestBase ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(), + ImmutableMap.of(), ImmutableMap.of() ), new CounterSnapshotsTree(), diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/SendPartialKeyStatisticsInformationSerdeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/SendPartialKeyStatisticsInformationSerdeTest.java new file mode 100644 index 000000000000..254eb8a23210 --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/SendPartialKeyStatisticsInformationSerdeTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.statistics; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableSet; +import org.apache.druid.msq.guice.MSQIndexingModule; +import org.apache.druid.segment.TestHelper; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class SendPartialKeyStatisticsInformationSerdeTest +{ + private ObjectMapper objectMapper; + + @Before + public void setUp() + { + objectMapper = TestHelper.makeJsonMapper(); + objectMapper.registerModules(new MSQIndexingModule().getJacksonModules()); + objectMapper.enable(JsonParser.Feature.STRICT_DUPLICATE_DETECTION); + } + + @Test + public void testSerde() throws JsonProcessingException + { + PartialKeyStatisticsInformation partialInformation = new PartialKeyStatisticsInformation( + ImmutableSet.of(2L, 3L), + false, + 0.0 + ); + + final String json = objectMapper.writeValueAsString(partialInformation); + final PartialKeyStatisticsInformation deserializedKeyStatistics = objectMapper.readValue( + json, + PartialKeyStatisticsInformation.class + ); + Assert.assertEquals(json, partialInformation.getTimeSegments(), deserializedKeyStatistics.getTimeSegments()); + Assert.assertEquals(json, partialInformation.hasMultipleValues(), deserializedKeyStatistics.hasMultipleValues()); + Assert.assertEquals(json, partialInformation.getBytesRetained(), deserializedKeyStatistics.getBytesRetained(), 0); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java index fc7cfe5d9bea..2ebe975c39d9 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteMSQTestsHelper.java @@ -32,6 +32,7 @@ import org.apache.druid.data.input.impl.LongDimensionSchema; import org.apache.druid.data.input.impl.StringDimensionSchema; import org.apache.druid.discovery.NodeRole; +import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.guice.GuiceInjectors; import org.apache.druid.guice.IndexingServiceTuningConfigModule; import org.apache.druid.guice.JoinableFactoryModule; @@ -175,6 +176,7 @@ public String getFormatString() groupByBuffers ).getGroupingEngine(); binder.bind(GroupingEngine.class).toInstance(groupingEngine); + 
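The serde test above is a round-trip check: serialize, reparse with strict duplicate detection enabled, and compare field by field. The same idea as a small reusable helper (a hypothetical convenience, not part of this patch):

    // Round-trips a value through JSON so callers can assert on the fields
    // of the reparsed instance.
    static <T> T roundTrip(final ObjectMapper mapper, final T value, final Class<T> clazz)
        throws JsonProcessingException
    {
      final String json = mapper.writeValueAsString(value);
      return mapper.readValue(json, clazz);
    }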
binder.bind(Bouncer.class).toInstance(new Bouncer(1)); }; return ImmutableList.of( customBindings, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java index d59bf6f027be..c249df61ebab 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java @@ -137,7 +137,14 @@ public SqlEngine createEngine( ) { final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance(WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, 2, 10, 2, 0, 0); + WorkerMemoryParameters.createInstance( + WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, + 2, + 10, + 2, + 0, + 0 + ); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 3b5e14cb2f55..fe78b481bee4 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -46,6 +46,7 @@ import org.apache.druid.discovery.BrokerClient; import org.apache.druid.discovery.NodeRole; import org.apache.druid.frame.channel.FrameChannelSequence; +import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.frame.testutil.FrameTestUtil; import org.apache.druid.guice.DruidInjectorBuilder; import org.apache.druid.guice.DruidSecondaryModule; @@ -58,7 +59,6 @@ import org.apache.druid.guice.SegmentWranglerModule; import org.apache.druid.guice.StartupInjectorBuilder; import org.apache.druid.guice.annotations.EscalatedGlobal; -import org.apache.druid.guice.annotations.MSQ; import org.apache.druid.guice.annotations.Self; import org.apache.druid.hll.HyperLogLogCollector; import org.apache.druid.indexing.common.SegmentCacheManagerFactory; @@ -73,7 +73,6 @@ import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; -import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.http.client.Request; @@ -86,6 +85,7 @@ import org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.DataServerQueryHandler; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.ResultsContext; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.guice.MSQDurableStorageModule; import org.apache.druid.msq.guice.MSQExternalDataSourceModule; @@ -504,7 +504,9 @@ public String getFormatString() // following bindings are overriding other bindings that end up needing a lot more dependencies. // We replace the bindings with something that returns null to make things more brittle in case they // actually are used somewhere in the test. 
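This "brittle on purpose" pattern is worth isolating: Providers.of(null) keeps Guice's object graph satisfied while guaranteeing a fast NullPointerException if a test unexpectedly exercises the dependency. In isolation (the bound type is a stand-in, not a real Druid class):

    import com.google.inject.AbstractModule;
    import com.google.inject.util.Providers;

    public class NullStubModule extends AbstractModule
    {
      // Placeholder for a heavyweight dependency the test never expects to use.
      interface ExpensiveDependency {}

      @Override
      protected void configure()
      {
        // Injection succeeds; any later use of the injected instance fails
        // fast, pointing straight at the unexpected access.
        bind(ExpensiveDependency.class).toProvider(Providers.of(null));
      }
    }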
- binder.bind(SqlStatementFactory.class).annotatedWith(MSQ.class).toProvider(Providers.of(null)); + binder.bind(SqlStatementFactory.class) + .annotatedWith(MultiStageQuery.class) + .toProvider(Providers.of(null)); binder.bind(SqlToolbox.class).toProvider(Providers.of(null)); binder.bind(MSQTaskSqlEngine.class).toProvider(Providers.of(null)); } @@ -514,7 +516,8 @@ public String getFormatString() new LookylooModule(), new SegmentWranglerModule(), new HllSketchModule(), - binder -> binder.bind(BrokerClient.class).toInstance(brokerClient) + binder -> binder.bind(BrokerClient.class).toInstance(brokerClient), + binder -> binder.bind(Bouncer.class).toInstance(new Bouncer(1)) ); // adding node role injection to the modules, since CliPeon would also do that through run method Injector injector = new CoreInjectorBuilder(new StartupInjectorBuilder().build(), ImmutableSet.of(NodeRole.PEON)) @@ -835,20 +838,7 @@ public static List getRows(@Nullable MSQResultsReport resultsReport) if (resultsReport == null) { return null; } else { - Yielder yielder = resultsReport.getResultYielder(); - List rows = new ArrayList<>(); - while (!yielder.isDone()) { - rows.add(yielder.get()); - yielder = yielder.next(null); - } - try { - yielder.close(); - } - catch (IOException e) { - throw new ISE("Unable to get results from the report"); - } - - return rows; + return resultsReport.getResults(); } } @@ -1436,9 +1426,10 @@ public Pair, List>> pageInformation.getWorker() == null ? 0 : pageInformation.getWorker(), pageInformation.getPartition() == null ? 0 : pageInformation.getPartition() )).flatMap(frame -> SqlStatementResourceHelper.getResultSequence( - msqControllerTask, - finalStage, frame, + finalStage.getFrameReader(), + msqControllerTask.getQuerySpec().getColumnMappings(), + new ResultsContext(msqControllerTask.getSqlTypeNames(), msqControllerTask.getSqlResultsContext()), objectMapper )).withBaggage(closer).toList()); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java index 3e78e477bda9..96e26cba77e1 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java @@ -54,6 +54,12 @@ public void postPartialKeyStatistics( } } + @Override + public void postDoneReadingInput(StageId stageId, int workerNumber) + { + controller.doneReadingInput(stageId.getStageNumber(), workerNumber); + } + @Override public void postCounters(String workerId, CounterSnapshotsTree snapshotsTree) { diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java index 45de3d7c4f50..20d31fbd4cfe 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.test; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -28,33 +29,47 @@ import 
com.google.inject.Injector; import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.client.coordinator.CoordinatorClient; +import org.apache.druid.client.indexing.NoopOverlordClient; +import org.apache.druid.client.indexing.TaskStatusResponse; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.indexer.RunnerTaskState; import org.apache.druid.indexer.TaskLocation; import org.apache.druid.indexer.TaskState; import org.apache.druid.indexer.TaskStatus; -import org.apache.druid.indexer.report.TaskReport; +import org.apache.druid.indexer.TaskStatusPlus; import org.apache.druid.indexing.common.actions.TaskActionClient; +import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.ControllerMemoryParameters; import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.exec.WorkerClient; +import org.apache.druid.msq.exec.WorkerFailureListener; import org.apache.druid.msq.exec.WorkerImpl; -import org.apache.druid.msq.exec.WorkerManagerClient; +import org.apache.druid.msq.exec.WorkerManager; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.exec.WorkerStorageParameters; +import org.apache.druid.msq.indexing.IndexerControllerContext; +import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.MSQWorkerTask; +import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; +import org.apache.druid.msq.input.InputSpecSlicer; +import org.apache.druid.msq.input.table.TableInputSpecSlicer; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.QueryContext; +import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.server.DruidNode; -import org.apache.druid.server.metrics.NoopServiceEmitter; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; import javax.annotation.Nullable; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -72,7 +87,8 @@ public class MSQTestControllerContext implements ControllerContext private final ConcurrentMap statusMap = new ConcurrentHashMap<>(); private final ListeningExecutorService executor = MoreExecutors.listeningDecorator(Execs.multiThreaded( NUM_WORKERS, - "MultiStageQuery-test-controller-client")); + "MultiStageQuery-test-controller-client" + )); private final CoordinatorClient coordinatorClient; private final DruidNode node = new DruidNode( "controller", @@ -85,18 +101,18 @@ public class MSQTestControllerContext implements ControllerContext ); private final Injector injector; private final ObjectMapper mapper; - private final ServiceEmitter emitter = new NoopServiceEmitter(); private Controller controller; - private TaskReport.ReportMap report = null; private final WorkerMemoryParameters workerMemoryParameters; + private final QueryContext queryContext; public MSQTestControllerContext( ObjectMapper mapper, Injector injector, TaskActionClient taskActionClient, WorkerMemoryParameters workerMemoryParameters, - List 
loadedSegments + List loadedSegments, + QueryContext queryContext ) { this.mapper = mapper; @@ -105,8 +121,8 @@ public MSQTestControllerContext( coordinatorClient = Mockito.mock(CoordinatorClient.class); Mockito.when(coordinatorClient.fetchServerViewSegments( - ArgumentMatchers.anyString(), - ArgumentMatchers.any() + ArgumentMatchers.anyString(), + ArgumentMatchers.any() ) ).thenAnswer(invocation -> loadedSegments.stream() .filter(immutableSegmentLoadInfo -> @@ -116,13 +132,15 @@ public MSQTestControllerContext( .collect(Collectors.toList()) ); this.workerMemoryParameters = workerMemoryParameters; + this.queryContext = queryContext; } - WorkerManagerClient workerManagerClient = new WorkerManagerClient() + OverlordClient overlordClient = new NoopOverlordClient() { @Override - public String run(String taskId, MSQWorkerTask task) + public ListenableFuture runTask(String taskId, Object taskObject) { + final MSQWorkerTask task = (MSQWorkerTask) taskObject; if (controller == null) { throw new ISE("Controller needs to be set using the register method"); } @@ -137,13 +155,26 @@ public String run(String taskId, MSQWorkerTask task) Worker worker = new WorkerImpl( task, - new MSQTestWorkerContext(inMemoryWorkers, controller, mapper, injector, workerMemoryParameters), + new MSQTestWorkerContext( + inMemoryWorkers, + controller, + mapper, + injector, + workerMemoryParameters + ), workerStorageParameters ); inMemoryWorkers.put(task.getId(), worker); statusMap.put(task.getId(), TaskStatus.running(task.getId())); - ListenableFuture future = executor.submit(worker::run); + ListenableFuture future = executor.submit(() -> { + try { + return worker.run(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + }); Futures.addCallback(future, new FutureCallback() { @@ -161,11 +192,11 @@ public void onFailure(Throwable t) } }, MoreExecutors.directExecutor()); - return task.getId(); + return Futures.immediateFuture(null); } @Override - public Map statuses(Set taskIds) + public ListenableFuture> taskStatuses(Set taskIds) { Map result = new HashMap<>(); for (String taskId : taskIds) { @@ -188,40 +219,63 @@ public Map statuses(Set taskIds) } } } - return result; + return Futures.immediateFuture(result); } @Override - public TaskLocation location(String workerId) + public ListenableFuture taskStatus(String taskId) { - final TaskStatus status = statusMap.get(workerId); - if (status != null && status.getStatusCode().equals(TaskState.RUNNING) && inMemoryWorkers.containsKey(workerId)) { - return TaskLocation.create("host-" + workerId, 1, -1); + final Map taskStatusMap = + FutureUtils.getUnchecked(taskStatuses(Collections.singleton(taskId)), true); + + final TaskStatus taskStatus = taskStatusMap.get(taskId); + if (taskStatus == null) { + return Futures.immediateFuture(new TaskStatusResponse(taskId, null)); } else { - return TaskLocation.unknown(); + return Futures.immediateFuture( + new TaskStatusResponse( + taskId, + new TaskStatusPlus( + taskStatus.getId(), + null, + null, + DateTimes.utc(0), + DateTimes.utc(0), + taskStatus.getStatusCode(), + taskStatus.getStatusCode(), + taskStatus.getStatusCode().isRunnable() ? RunnerTaskState.RUNNING : RunnerTaskState.NONE, + null, + taskStatus.getStatusCode().isRunnable() + ? 
TaskLocation.create("host-" + taskId, 1, -1) + : TaskLocation.unknown(), + null, + taskStatus.getErrorMsg() + ) + ) + ); } } @Override - public void cancel(String workerId) + public ListenableFuture<Void> cancelTask(String workerId) { final Worker worker = inMemoryWorkers.remove(workerId); if (worker != null) { worker.stopGracefully(); } - } - - @Override - public void close() - { - //do nothing + return Futures.immediateFuture(null); } }; @Override - public ServiceEmitter emitter() + public ControllerQueryKernelConfig queryKernelConfig(MSQSpec querySpec, QueryDefinition queryDef) + { + return IndexerControllerContext.makeQueryKernelConfig(querySpec, new ControllerMemoryParameters(100_000_000)); + } + + @Override + public void emitMetric(String metric, Number value) { - return emitter; } @Override @@ -243,21 +297,37 @@ public DruidNode selfNode() } @Override - public CoordinatorClient coordinatorClient() + public TaskActionClient taskActionClient() { - return coordinatorClient; + return taskActionClient; } @Override - public TaskActionClient taskActionClient() + public InputSpecSlicer newTableInputSpecSlicer() { - return taskActionClient; + return new TableInputSpecSlicer( + coordinatorClient, + taskActionClient, + MultiStageQueryContext.getSegmentSources(queryContext) + ); } @Override - public WorkerManagerClient workerManager() + public WorkerManager newWorkerManager( + String queryId, + MSQSpec querySpec, + ControllerQueryKernelConfig queryKernelConfig, + WorkerFailureListener workerFailureListener + ) { - return workerManagerClient; + return new MSQWorkerTaskLauncher( + controller.queryId(), + "test-datasource", + overlordClient, + workerFailureListener, + IndexerControllerContext.makeTaskContext(querySpec, queryKernelConfig, ImmutableMap.of()), + 0 + ); } @Override @@ -267,21 +337,8 @@ public void registerController(Controller controller, Closer closer) } @Override - public WorkerClient taskClientFor(Controller controller) + public WorkerClient newWorkerClient() { return new MSQTestWorkerClient(inMemoryWorkers); } - - @Override - public void writeReports(String controllerTaskId, TaskReport.ReportMap taskReport) - { - if (controller != null && controller.id().equals(controllerTaskId)) { - report = taskReport; - } - } - - public TaskReport.ReportMap getAllReports() - { - return report; - } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java index 4a5ac7e84e64..a565283154fd 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java @@ -20,10 +20,13 @@ package org.apache.druid.msq.test; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.inject.Injector; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.client.ImmutableSegmentLoadInfo; import org.apache.druid.client.indexing.NoopOverlordClient; import org.apache.druid.client.indexing.TaskPayloadResponse; @@ -36,11 +39,19 @@ import org.apache.druid.java.util.common.ISE; import
org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.ControllerImpl; +import org.apache.druid.msq.exec.QueryListener; +import org.apache.druid.msq.exec.ResultsContext; import org.apache.druid.msq.exec.WorkerMemoryParameters; import org.apache.druid.msq.indexing.MSQControllerTask; +import org.apache.druid.msq.indexing.destination.MSQDestination; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.indexing.report.MSQStatusReport; +import org.apache.druid.msq.indexing.report.MSQTaskReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; import org.joda.time.DateTime; import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -80,33 +91,53 @@ public MSQTestOverlordServiceClient( @Override public ListenableFuture<Void> runTask(String taskId, Object taskObject) { + TestQueryListener queryListener = null; ControllerImpl controller = null; - MSQTestControllerContext msqTestControllerContext = null; + MSQTestControllerContext msqTestControllerContext; try { + MSQControllerTask cTask = objectMapper.convertValue(taskObject, MSQControllerTask.class); + msqTestControllerContext = new MSQTestControllerContext( objectMapper, injector, taskActionClient, workerMemoryParameters, - loadedSegmentMetadata + loadedSegmentMetadata, + cTask.getQuerySpec().getQuery().context() ); - MSQControllerTask cTask = objectMapper.convertValue(taskObject, MSQControllerTask.class); inMemoryControllerTask.put(cTask.getId(), cTask); - controller = new ControllerImpl(cTask, msqTestControllerContext); + controller = new ControllerImpl( + cTask.getId(), + cTask.getQuerySpec(), + new ResultsContext(cTask.getSqlTypeNames(), cTask.getSqlResultsContext()), + msqTestControllerContext + ); + + inMemoryControllers.put(controller.queryId(), controller); - inMemoryControllers.put(controller.id(), controller); + queryListener = + new TestQueryListener( + cTask.getId(), + cTask.getQuerySpec().getDestination() + ); - inMemoryTaskStatus.put(taskId, controller.run()); + try { + controller.run(queryListener); + inMemoryTaskStatus.put(taskId, queryListener.getStatusReport().toTaskStatus(cTask.getId())); + } + catch (Exception e) { + inMemoryTaskStatus.put(taskId, TaskStatus.failure(cTask.getId(), e.toString())); + } return Futures.immediateFuture(null); } catch (Exception e) { throw new ISE(e, "Unable to run"); } finally { - if (controller != null && msqTestControllerContext != null) { - reports.put(controller.id(), msqTestControllerContext.getAllReports()); + if (controller != null && queryListener != null) { + reports.put(controller.queryId(), queryListener.getReportMap()); } } } @@ -114,7 +145,7 @@ public ListenableFuture<Void> runTask(String taskId, Object taskObject) @Override public ListenableFuture<Void> cancelTask(String taskId) { - inMemoryControllers.get(taskId).stopGracefully(); + inMemoryControllers.get(taskId).stop(); return Futures.immediateFuture(null); } @@ -166,4 +197,96 @@ MSQControllerTask getMSQControllerTask(String id) { return inMemoryControllerTask.get(id); } + + /** + * Listener that captures a report and makes it available through {@link #getReportMap()}.
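+ * Rows are buffered up to the destination's {@link MSQDestination#getRowsInTaskReport()} limit, so the captured + * report matches what a real task report would contain.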
+ */ + static class TestQueryListener implements QueryListener + { + private final String taskId; + private final MSQDestination destination; + private final List<Object[]> results = new ArrayList<>(); + + private List<MSQResultsReport.ColumnAndType> signature; + private List<SqlTypeName> sqlTypeNames; + private boolean resultsTruncated = true; + private TaskReport.ReportMap reportMap; + + public TestQueryListener(final String taskId, final MSQDestination destination) + { + this.taskId = taskId; + this.destination = destination; + } + + @Override + public boolean readResults() + { + return destination.getRowsInTaskReport() == MSQDestination.UNLIMITED || destination.getRowsInTaskReport() > 0; + } + + @Override + public void onResultsStart(List<MSQResultsReport.ColumnAndType> signature, @Nullable List<SqlTypeName> sqlTypeNames) + { + this.signature = signature; + this.sqlTypeNames = sqlTypeNames; + } + + @Override + public boolean onResultRow(Object[] row) + { + if (destination.getRowsInTaskReport() == MSQDestination.UNLIMITED + || results.size() < destination.getRowsInTaskReport()) { + results.add(row); + return true; + } else { + return false; + } + } + + @Override + public void onResultsComplete() + { + resultsTruncated = false; + } + + @Override + public void onQueryComplete(MSQTaskReportPayload report) + { + final MSQResultsReport resultsReport; + + if (signature != null) { + resultsReport = new MSQResultsReport( + signature, + sqlTypeNames, + results, + resultsTruncated + ); + } else { + resultsReport = null; + } + + final MSQTaskReport taskReport = new MSQTaskReport( + taskId, + new MSQTaskReportPayload( + report.getStatus(), + report.getStages(), + report.getCounters(), + resultsReport + ) + ); + + reportMap = TaskReport.buildTaskReports(taskReport); + } + + public TaskReport.ReportMap getReportMap() + { + return Preconditions.checkNotNull(reportMap, "reportMap"); + } + + public MSQStatusReport getStatusReport() + { + final MSQTaskReport taskReport = (MSQTaskReport) Iterables.getOnlyElement(getReportMap().values()); + return taskReport.getPayload().getStatus(); + } + } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java index ae892c34500a..72cb246a43e1 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java @@ -54,24 +54,22 @@ public ListenableFuture<Void> postWorkOrder(String workerTaskId, WorkOrder workO @Override public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot( String workerTaskId, - String queryId, - int stageNumber + StageId stageId ) { - StageId stageId = new StageId(queryId, stageNumber); return Futures.immediateFuture(inMemoryWorkers.get(workerTaskId).fetchStatisticsSnapshot(stageId)); } @Override public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk( String workerTaskId, - String queryId, - int stageNumber, + StageId stageId, long timeChunk ) { - StageId stageId = new StageId(queryId, stageNumber); - return Futures.immediateFuture(inMemoryWorkers.get(workerTaskId).fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk)); + return Futures.immediateFuture( + inMemoryWorkers.get(workerTaskId).fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk) + ); } @Override @@ -123,20 +121,19 @@ public ListenableFuture<Boolean> fetchChannelData( final ReadableByteChunksFrameChannel channel ) { - try (InputStream inputStream =
inMemoryWorkers.get(workerTaskId).readChannel( - stageId.getQueryId(), - stageId.getStageNumber(), - partitionNumber, - offset - )) { + try (InputStream inputStream = + inMemoryWorkers.get(workerTaskId) + .readChannel(stageId.getQueryId(), stageId.getStageNumber(), partitionNumber, offset)) { byte[] buffer = new byte[8 * 1024]; + boolean didRead = false; int bytesRead; while ((bytesRead = inputStream.read(buffer)) != -1) { channel.addChunk(Arrays.copyOf(buffer, bytesRead)); + didRead = true; } inputStream.close(); - return Futures.immediateFuture(true); + return Futures.immediateFuture(!didRead); } catch (Exception e) { throw new ISE(e, "Error reading frame file channel"); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java index d2283a94be04..ad05c20b5829 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java @@ -22,9 +22,9 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Injector; import org.apache.druid.frame.processor.Bouncer; -import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.indexer.report.TaskReportFileWriter; import org.apache.druid.indexing.common.TaskToolbox; +import org.apache.druid.indexing.common.task.NoopTestTaskReportFileWriter; import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.msq.exec.Controller; @@ -123,20 +123,7 @@ public FrameContext frameContext(QueryDefinition queryDef, int stageNumber) OffHeapMemorySegmentWriteOutMediumFactory.instance(), true ); - final TaskReportFileWriter reportFileWriter = new TaskReportFileWriter() - { - @Override - public void write(String taskId, TaskReport.ReportMap reports) - { - - } - - @Override - public void setObjectMapper(ObjectMapper objectMapper) - { - - } - }; + final TaskReportFileWriter reportFileWriter = new NoopTestTaskReportFileWriter(); return new IndexerFrameContext( new IndexerWorkerContext( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/NoopQueryListener.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/NoopQueryListener.java new file mode 100644 index 000000000000..fe38819a4519 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/NoopQueryListener.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.test; + +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.msq.exec.QueryListener; +import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; + +import javax.annotation.Nullable; +import java.util.List; + +public class NoopQueryListener implements QueryListener +{ + @Override + public boolean readResults() + { + return false; + } + + @Override + public void onResultsStart(List<MSQResultsReport.ColumnAndType> signature, @Nullable List<SqlTypeName> sqlTypeNames) + { + // Do nothing. + } + + @Override + public boolean onResultRow(Object[] row) + { + return true; + } + + @Override + public void onResultsComplete() + { + // Do nothing. + } + + @Override + public void onQueryComplete(MSQTaskReportPayload report) + { + // Do nothing. + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java index 3c14f4f1cd9f..1966d1e5b10a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java @@ -27,6 +27,7 @@ import org.apache.druid.msq.counters.ChannelCounters; import org.apache.druid.msq.counters.CounterSnapshots; import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination; import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; @@ -77,7 +78,8 @@ public void testDistinctPartitionsOnEachWorker() ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(0, 3), - ImmutableMap.of(0, 15) + ImmutableMap.of(0, 15), + ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE) ), counterSnapshots, null); Optional<List<PageInformation>> pages = SqlStatementResourceHelper.populatePageList( @@ -117,7 +119,8 @@ public void testOnePartitionOnEachWorker() ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(0, 4), - ImmutableMap.of(0, 4) + ImmutableMap.of(0, 4), + ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE) ), counterSnapshots, null); Optional<List<PageInformation>> pages = SqlStatementResourceHelper.populatePageList( @@ -158,7 +161,8 @@ public void testCommonPartitionsOnEachWorker() ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(0, 4), - ImmutableMap.of(0, 21) + ImmutableMap.of(0, 21), + ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE) ), counterSnapshots, null); Optional<List<PageInformation>> pages = @@ -197,7 +201,8 @@ public void testNullChannelCounters() ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(0, 4), - ImmutableMap.of(0, 21) + ImmutableMap.of(0, 21), + ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE) ), counterSnapshots, null); Optional<List<PageInformation>> pages = SqlStatementResourceHelper.populatePageList( @@ -237,7 +242,8 @@ public void testConsecutivePartitionsOnEachWorker() ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(0, 4), - ImmutableMap.of(0, 13) + ImmutableMap.of(0, 13), + ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE) ), counterSnapshots, null); Optional<List<PageInformation>> pages = SqlStatementResourceHelper.populatePageList( @@ -278,7 +284,8 @@ public void testEmptyCountersForDurableStorageDestination() ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(0, 1), - ImmutableMap.of(0, 1) +
ImmutableMap.of(0, 1), + ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE) ), counterSnapshots, null @@ -315,7 +322,8 @@ public void testEmptyCountersForTaskReportDestination() ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(0, 1), - ImmutableMap.of(0, 1) + ImmutableMap.of(0, 1), + ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE) ), counterSnapshots, null @@ -354,7 +362,8 @@ public void testEmptyCountersForDataSourceDestination() ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of(0, 1), - ImmutableMap.of(0, 1) + ImmutableMap.of(0, 1), + ImmutableMap.of(0, OutputChannelMode.LOCAL_STORAGE) ), counterSnapshots, null diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriter.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriter.java index 28cc1ae2af5e..865a8593a7d3 100644 --- a/indexing-service/src/main/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriter.java +++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriter.java @@ -24,10 +24,13 @@ import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.indexer.report.TaskReportFileWriter; import org.apache.druid.java.util.common.FileUtils; +import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.logger.Logger; import java.io.File; -import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Files; import java.util.HashMap; import java.util.Map; @@ -41,26 +44,29 @@ public class MultipleFileTaskReportFileWriter implements TaskReportFileWriter @Override public void write(String taskId, TaskReport.ReportMap reports) + { + try (final OutputStream outputStream = openReportOutputStream(taskId)) { + SingleFileTaskReportFileWriter.writeReportToStream(objectMapper, outputStream, reports); + } + catch (Exception e) { + log.error(e, "Encountered exception in write()."); + } + } + + @Override + public OutputStream openReportOutputStream(String taskId) throws IOException { final File reportsFile = taskReportFiles.get(taskId); if (reportsFile == null) { - log.error("Could not find report file for task[%s]", taskId); - return; + throw new ISE("Could not find report file for task[%s]", taskId); } - try { - final File reportsFileParent = reportsFile.getParentFile(); - if (reportsFileParent != null) { - FileUtils.mkdirp(reportsFileParent); - } - - try (final FileOutputStream outputStream = new FileOutputStream(reportsFile)) { - SingleFileTaskReportFileWriter.writeReportToStream(objectMapper, outputStream, reports); - } - } - catch (Exception e) { - log.error(e, "Encountered exception in write()."); + final File reportsFileParent = reportsFile.getParentFile(); + if (reportsFileParent != null) { + FileUtils.mkdirp(reportsFileParent); } + + return Files.newOutputStream(reportsFile.toPath()); } @Override diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriterTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriterTest.java new file mode 100644 index 000000000000..2e51973ec241 --- /dev/null +++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/MultipleFileTaskReportFileWriterTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.indexing.common; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.indexer.report.IngestionStatsAndErrorsTaskReport; +import org.apache.druid.indexer.report.TaskReport; +import org.apache.druid.segment.TestHelper; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.Map; + +public class MultipleFileTaskReportFileWriterTest +{ + private static final String TASK_ID = "mytask"; + + @Rule + public final TemporaryFolder tempFolder = new TemporaryFolder(); + + @Test + public void testReport() throws IOException + { + final ObjectMapper mapper = TestHelper.makeJsonMapper(); + final File file = tempFolder.newFile(); + final MultipleFileTaskReportFileWriter writer = new MultipleFileTaskReportFileWriter(); + writer.setObjectMapper(mapper); + writer.add(TASK_ID, file); + + final TaskReport.ReportMap reportsMap = TaskReport.buildTaskReports( + new IngestionStatsAndErrorsTaskReport(TASK_ID, null) + ); + + writer.write(TASK_ID, reportsMap); + + Assert.assertEquals( + reportsMap, + mapper.readValue(Files.readAllBytes(file.toPath()), new TypeReference<Map<String, TaskReport>>() {}) + ); + } +} diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/SingleFileTaskReportFileWriterTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/SingleFileTaskReportFileWriterTest.java new file mode 100644 index 000000000000..1381a7483cb2 --- /dev/null +++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/SingleFileTaskReportFileWriterTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.druid.indexing.common; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.indexer.report.IngestionStatsAndErrorsTaskReport; +import org.apache.druid.indexer.report.SingleFileTaskReportFileWriter; +import org.apache.druid.indexer.report.TaskReport; +import org.apache.druid.segment.TestHelper; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.Map; + +public class SingleFileTaskReportFileWriterTest +{ + private static final String TASK_ID = "mytask"; + + @Rule + public final TemporaryFolder tempFolder = new TemporaryFolder(); + + @Test + public void testReport() throws IOException + { + final ObjectMapper mapper = TestHelper.makeJsonMapper(); + final File file = tempFolder.newFile(); + final SingleFileTaskReportFileWriter writer = new SingleFileTaskReportFileWriter(file); + writer.setObjectMapper(mapper); + final TaskReport.ReportMap reportsMap = TaskReport.buildTaskReports( + new IngestionStatsAndErrorsTaskReport(TASK_ID, null) + ); + writer.write(TASK_ID, reportsMap); + Assert.assertEquals( + reportsMap, + mapper.readValue(Files.readAllBytes(file.toPath()), new TypeReference<Map<String, TaskReport>>() {}) + ); + } +} diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/NoopTestTaskReportFileWriter.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/NoopTestTaskReportFileWriter.java index 7e7860e9d8e2..ad175faeb49b 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/NoopTestTaskReportFileWriter.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/NoopTestTaskReportFileWriter.java @@ -23,6 +23,9 @@ import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.indexer.report.TaskReportFileWriter; +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; + public class NoopTestTaskReportFileWriter implements TaskReportFileWriter { @Override @@ -30,6 +33,13 @@ public void write(String id, TaskReport.ReportMap reports) { } + @Override + public OutputStream openReportOutputStream(String taskId) + { + // Stream to nowhere.
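+ // The buffer is discarded along with the stream, so reports written in tests are accepted but never persisted.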
+ return new ByteArrayOutputStream(); + } + @Override public void setObjectMapper(ObjectMapper objectMapper) { diff --git a/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/msq/ITMultiStageQuery.java b/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/msq/ITMultiStageQuery.java index 0fe486407db8..7ede24cd8f9a 100644 --- a/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/msq/ITMultiStageQuery.java +++ b/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/msq/ITMultiStageQuery.java @@ -26,7 +26,6 @@ import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.msq.indexing.report.MSQResultsReport; import org.apache.druid.msq.indexing.report.MSQTaskReport; import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; @@ -253,13 +252,10 @@ public void testExport() throws Exception "Results report for the task id is empty" ); - Yielder<Object[]> yielder = resultsReport.getResultYielder(); List<List<Object>> actualResults = new ArrayList<>(); - while (!yielder.isDone()) { - Object[] row = yielder.get(); + for (final Object[] row : resultsReport.getResults()) { actualResults.add(Arrays.asList(row)); - yielder = yielder.next(null); } ImmutableList<List<Object>> expectedResults = ImmutableList.of( diff --git a/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java b/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java index c5fc437fc9c7..2a2386869e4a 100644 --- a/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java +++ b/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java @@ -32,7 +32,6 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.RetryUtils; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.http.client.response.StatusResponseHolder; import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.msq.indexing.report.MSQResultsReport; @@ -215,17 +214,14 @@ private void compareResults(String taskId, MsqQueryWithResults expectedQueryWith List<Map<String, Object>> actualResults = new ArrayList<>(); - Yielder<Object[]> yielder = resultsReport.getResultYielder(); List<MSQResultsReport.ColumnAndType> rowSignature = resultsReport.getSignature(); - while (!yielder.isDone()) { - Object[] row = yielder.get(); + for (final Object[] row : resultsReport.getResults()) { Map<String, Object> rowWithFieldNames = new LinkedHashMap<>(); for (int i = 0; i < row.length; ++i) { rowWithFieldNames.put(rowSignature.get(i).getName(), row[i]); } actualResults.add(rowWithFieldNames); - yielder = yielder.next(null); } QueryResultVerifier.ResultVerificationObject resultsComparison = QueryResultVerifier.compareResults( diff --git a/processing/src/main/java/org/apache/druid/frame/util/DurableStorageUtils.java b/processing/src/main/java/org/apache/druid/frame/util/DurableStorageUtils.java index 3e5f2fe00a1f..168d96fc20ab 100644 --- a/processing/src/main/java/org/apache/druid/frame/util/DurableStorageUtils.java +++ b/processing/src/main/java/org/apache/druid/frame/util/DurableStorageUtils.java @@ -126,7 +126,11 @@ public static String getTaskIdOutputsFolderName( { return StringUtils.format( "%s/taskId_%s", - getWorkerOutputFolderName(controllerTaskId, stageNumber, workerNumber), + getWorkerOutputFolderName(
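+ // Validate the controller task ID before it is embedded in a durable storage path; malformed IDs fail fast.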
IdUtils.validateId("controller task ID", controllerTaskId), + stageNumber, + workerNumber + ), taskId ); } diff --git a/processing/src/main/java/org/apache/druid/indexer/report/SingleFileTaskReportFileWriter.java b/processing/src/main/java/org/apache/druid/indexer/report/SingleFileTaskReportFileWriter.java index d862b224d86e..9012f4e83a15 100644 --- a/processing/src/main/java/org/apache/druid/indexer/report/SingleFileTaskReportFileWriter.java +++ b/processing/src/main/java/org/apache/druid/indexer/report/SingleFileTaskReportFileWriter.java @@ -24,8 +24,9 @@ import org.apache.druid.java.util.common.logger.Logger; import java.io.File; -import java.io.FileOutputStream; +import java.io.IOException; import java.io.OutputStream; +import java.nio.file.Files; public class SingleFileTaskReportFileWriter implements TaskReportFileWriter { @@ -42,21 +43,25 @@ public SingleFileTaskReportFileWriter(File reportsFile) @Override public void write(String taskId, TaskReport.ReportMap reports) { - try { - final File reportsFileParent = reportsFile.getParentFile(); - if (reportsFileParent != null) { - FileUtils.mkdirp(reportsFileParent); - } - - try (final FileOutputStream outputStream = new FileOutputStream(reportsFile)) { - writeReportToStream(objectMapper, outputStream, reports); - } + try (final OutputStream outputStream = openReportOutputStream(taskId)) { + writeReportToStream(objectMapper, outputStream, reports); } catch (Exception e) { log.error(e, "Encountered exception in write()."); } } + @Override + public OutputStream openReportOutputStream(String taskId) throws IOException + { + final File reportsFileParent = reportsFile.getParentFile(); + if (reportsFileParent != null) { + FileUtils.mkdirp(reportsFileParent); + } + + return Files.newOutputStream(reportsFile.toPath()); + } + @Override public void setObjectMapper(ObjectMapper objectMapper) { diff --git a/processing/src/main/java/org/apache/druid/indexer/report/TaskReportFileWriter.java b/processing/src/main/java/org/apache/druid/indexer/report/TaskReportFileWriter.java index bb3ebcd0394a..0cdd02493662 100644 --- a/processing/src/main/java/org/apache/druid/indexer/report/TaskReportFileWriter.java +++ b/processing/src/main/java/org/apache/druid/indexer/report/TaskReportFileWriter.java @@ -21,9 +21,14 @@ import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.io.OutputStream; + public interface TaskReportFileWriter { void write(String taskId, TaskReport.ReportMap reports); + OutputStream openReportOutputStream(String taskId) throws IOException; + void setObjectMapper(ObjectMapper objectMapper); } diff --git a/server/src/main/java/org/apache/druid/rpc/RequestBuilder.java b/server/src/main/java/org/apache/druid/rpc/RequestBuilder.java index 224cfc78ed11..d6dde12fd6db 100644 --- a/server/src/main/java/org/apache/druid/rpc/RequestBuilder.java +++ b/server/src/main/java/org/apache/druid/rpc/RequestBuilder.java @@ -32,8 +32,6 @@ import org.joda.time.Duration; import javax.ws.rs.core.MediaType; -import java.net.MalformedURLException; -import java.net.URL; import java.util.Arrays; import java.util.Map; import java.util.Objects; @@ -77,11 +75,11 @@ public RequestBuilder content(final String contentType, final byte[] content) return this; } - public RequestBuilder jsonContent(final ObjectMapper jsonMapper, final Object content) + public RequestBuilder objectContent(final ObjectMapper objectMapper, final String contentType, final Object content) { try { - this.contentType = MediaType.APPLICATION_JSON; - this.content = 
jsonMapper.writeValueAsBytes(Preconditions.checkNotNull(content, "content")); + this.contentType = contentType; + this.content = objectMapper.writeValueAsBytes(Preconditions.checkNotNull(content, "content")); return this; } catch (JsonProcessingException e) { @@ -89,16 +87,14 @@ public RequestBuilder jsonContent(final ObjectMapper jsonMapper, final Object co } } + public RequestBuilder jsonContent(final ObjectMapper jsonMapper, final Object content) + { + return objectContent(jsonMapper, MediaType.APPLICATION_JSON, content); + } + public RequestBuilder smileContent(final ObjectMapper smileMapper, final Object content) { - try { - this.contentType = SmileMediaTypes.APPLICATION_JACKSON_SMILE; - this.content = smileMapper.writeValueAsBytes(Preconditions.checkNotNull(content, "content")); - return this; - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } + return objectContent(smileMapper, SmileMediaTypes.APPLICATION_JACKSON_SMILE, content); } public RequestBuilder timeout(final Duration timeout) @@ -121,8 +117,7 @@ public Duration getTimeout() public Request build(ServiceLocation serviceLocation) { // It's expected that our encodedPathAndQueryString starts with '/' and the service base path doesn't end with one. - final String path = serviceLocation.getBasePath() + encodedPathAndQueryString; - final Request request = new Request(method, makeURL(serviceLocation, path)); + final Request request = new Request(method, serviceLocation.toURL(encodedPathAndQueryString)); for (final Map.Entry<String, String> entry : headers.entries()) { request.addHeader(entry.getKey(), entry.getValue()); @@ -135,29 +130,6 @@ public Request build(ServiceLocation serviceLocation) return request; } - private URL makeURL(final ServiceLocation serviceLocation, final String encodedPathAndQueryString) - { - final String scheme; - final int portToUse; - - if (serviceLocation.getTlsPort() > 0) { - // Prefer HTTPS if available. - scheme = "https"; - portToUse = serviceLocation.getTlsPort(); - } else { - scheme = "http"; - portToUse = serviceLocation.getPlaintextPort(); - } - - // Use URL constructor, not URI, since the path is already encoded. - try { - return new URL(scheme, serviceLocation.getHost(), portToUse, encodedPathAndQueryString); - } - catch (MalformedURLException e) { - throw new IllegalArgumentException(e); - } - } - @Override public boolean equals(Object o) { diff --git a/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java b/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java index 3a092d7cb8dd..aeaa24318e93 100644 --- a/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java +++ b/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java @@ -26,7 +26,10 @@ import org.apache.druid.server.DruidNode; import org.apache.druid.server.coordination.DruidServerMetadata; +import javax.annotation.Nullable; import javax.validation.constraints.NotNull; +import java.net.MalformedURLException; +import java.net.URL; import java.util.Iterator; import java.util.Objects; @@ -35,11 +38,24 @@ */ public class ServiceLocation { + private static final String HTTP_SCHEME = "http"; + private static final String HTTPS_SCHEME = "https"; + private static final Splitter HOST_SPLITTER = Splitter.on(":").limit(2); + private final String host; private final int plaintextPort; private final int tlsPort; private final String basePath; + /** + * Create a service location.
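+ * {@link #toURL(String)} composes the location into a full request URL, preferring the TLS port when one is set.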
+ * + * @param host hostname or address + * @param plaintextPort plaintext port + * @param tlsPort TLS port + * @param basePath base path; must be encoded and must not include trailing "/". In particular, to use root as + * the base path, pass "" for this parameter. + */ public ServiceLocation(final String host, final int plaintextPort, final int tlsPort, final String basePath) { this.host = Preconditions.checkNotNull(host, "host"); @@ -48,13 +64,19 @@ public ServiceLocation(final String host, final int plaintextPort, final int tls this.basePath = Preconditions.checkNotNull(basePath, "basePath"); } + /** + * Create a service location based on a {@link DruidNode}, without a base path. + */ public static ServiceLocation fromDruidNode(final DruidNode druidNode) { return new ServiceLocation(druidNode.getHost(), druidNode.getPlaintextPort(), druidNode.getTlsPort(), ""); } - private static final Splitter SPLITTER = Splitter.on(":").limit(2); - + /** + * Create a service location based on a {@link DruidServerMetadata}. + * + * @throws IllegalArgumentException if the server metadata cannot be mapped to a service location. + */ public static ServiceLocation fromDruidServerMetadata(final DruidServerMetadata druidServerMetadata) { final String host = getHostFromString( @@ -71,7 +93,7 @@ public static ServiceLocation fromDruidServerMetadata(final DruidServerMetadata private static String getHostFromString(@NotNull String s) { - Iterator<String> iterator = SPLITTER.split(s).iterator(); + Iterator<String> iterator = HOST_SPLITTER.split(s).iterator(); ImmutableList<String> strings = ImmutableList.copyOf(iterator); return strings.get(0); } @@ -81,7 +103,7 @@ private static int getPortFromString(String s) { if (s == null) { return -1; } - Iterator<String> iterator = SPLITTER.split(s).iterator(); + Iterator<String> iterator = HOST_SPLITTER.split(s).iterator(); ImmutableList<String> strings = ImmutableList.copyOf(iterator); try { return Integer.parseInt(strings.get(1)); @@ -111,6 +133,33 @@ public String getBasePath() return basePath; } + public URL toURL(@Nullable final String encodedPathAndQueryString) + { + final String scheme; + final int portToUse; + + if (tlsPort > 0) { + // Prefer HTTPS if available. + scheme = HTTPS_SCHEME; + portToUse = tlsPort; + } else { + scheme = HTTP_SCHEME; + portToUse = plaintextPort; + } + + try { + return new URL( + scheme, + host, + portToUse, + basePath + (encodedPathAndQueryString == null ? "" : encodedPathAndQueryString) + ); + } + catch (MalformedURLException e) { + throw new IllegalArgumentException(e); + } + } + @Override public boolean equals(Object o) { @@ -143,4 +192,5 @@ public String toString() ", basePath='" + basePath + '\'' + '}'; } + }
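For reviewers: a minimal usage sketch of the new ServiceLocation#toURL, showing the behavior moved here from RequestBuilder#makeURL. This snippet is not part of the patch; the host names, ports, and base paths are invented for illustration.

import org.apache.druid.rpc.ServiceLocation;

public class ServiceLocationToUrlSketch
{
  public static void main(String[] args)
  {
    // TLS port set: toURL prefers https and appends the (already-encoded) path to the base path.
    final ServiceLocation withTls = new ServiceLocation("broker.example.com", 8082, 8282, "/druid/v2");
    System.out.println(withTls.toURL("/sql?pretty"));
    // -> https://broker.example.com:8282/druid/v2/sql?pretty

    // No TLS port: falls back to plaintext http; a null path yields the bare base path.
    final ServiceLocation plainOnly = new ServiceLocation("historical-1.example.com", 8083, -1, "");
    System.out.println(plainOnly.toURL(null));
    // -> http://historical-1.example.com:8083
  }
}

Keeping URL assembly on ServiceLocation means any client that already holds a location can build URLs the same way RequestBuilder does, with the same https-over-http preference.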